/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "iree/builtins/ukernel/elementwise.h"

#include <math.h>
#include <riscv_vector.h>

//===----------------------------------------------------------------------===//
// Helpers for defining generic implementations of elementwise functions.
// Since it affords the best code size tradeoff options, the entrypoint
// is dispatched based on an opcode.
//===----------------------------------------------------------------------===//

// Opcodes for generic functions operating on 32-bit operands and result.
// Since the outer dispatcher only differentiates based on width, all other
// type specificity is carried by the opcode.
// Binary opcodes are named "X32B" and unary opcodes "X32U".
// The initial list was sorted, and it is encouraged to sort extensions, but
// each opcode must be numerically stable, so the list is not expected to
// be sorted over time.
typedef enum {
  IREE_UK_X32B_ADDF = 0,
  IREE_UK_X32B_ADDI = 1,
  IREE_UK_X32B_ANDI = 2,
  IREE_UK_X32B_DIVF = 3,
  IREE_UK_X32B_DIVSI = 4,
  IREE_UK_X32B_DIVUI = 5,
  IREE_UK_X32B_MULF = 6,
  IREE_UK_X32B_MULI = 7,
  IREE_UK_X32B_ORI = 8,
  IREE_UK_X32B_SHLI = 9,
  IREE_UK_X32B_SHRSI = 10,
  IREE_UK_X32B_SHRUI = 11,
  IREE_UK_X32B_SUBF = 12,
  IREE_UK_X32B_SUBI = 13,
  // Historical, misspelled name ("UKENREL") of the XORI opcode. It is still
  // referenced throughout this file, so it is kept as the primary spelling.
  IREE_UKENREL_X32B_XORI = 14,
  // Correctly spelled alias for new code. It has the same numeric value as
  // the name above, so it must never appear as a separate `case` label in a
  // switch that also lists IREE_UKENREL_X32B_XORI.
  IREE_UK_X32B_XORI = IREE_UKENREL_X32B_XORI,
} iree_uk_x32b_opcode_t;

// Classifies an x32b opcode by the RVV code path (if any) that implements it.
typedef enum {
  // Lowered to RVV using unsigned 32-bit integer vectors.
  IREE_UK_X32B_UI = 0,
  // Lowered to RVV using signed 32-bit integer vectors.
  IREE_UK_X32B_SI = 1,
  // No RVV lowering ("not available"); the scalar fallback is used instead.
  IREE_UK_X32B_NA = 2,
} iree_uk_x32b_opcode_type_t;

// Opcodes for unary 32-bit elementwise functions. As with the binary opcodes,
// every value is spelled out explicitly so that it stays numerically stable;
// new opcodes must be given fresh values rather than inserted in the middle.
typedef enum {
  IREE_UK_X32U_ABSF = 0,
  IREE_UK_X32U_CEILF = 1,
  IREE_UK_X32U_CTLZ = 2,
  IREE_UK_X32U_EXPF = 3,
  IREE_UK_X32U_FLOORF = 4,
  IREE_UK_X32U_LOGF = 5,
  IREE_UK_X32U_NEGF = 6,
  IREE_UK_X32U_RSQRTF = 7,
} iree_uk_x32u_opcode_t;

// Macros to access various typed, dereferenced pointers. Both the argument
// and the whole expansion are parenthesized, so `ASF32(base + n)` casts the
// full expression (not just `base`) and the result composes like a plain
// dereference in larger expressions.
#define ASF32(ptr) (*((float*)(ptr)))
#define ASUI32(ptr) (*((iree_uk_uint32_t*)(ptr)))
#define ASSI32(ptr) (*((iree_uk_int32_t*)(ptr)))

//===----------------------------------------------------------------------===//
// Implementation macros.
//===----------------------------------------------------------------------===//

// Defines the exported entry point iree_uk_{CATEGORY}_{NAME}_2d for a binary
// elementwise op. The generated function forwards all of its operands, plus
// the fixed opcode value |OPCODE|, to the shared per-width implementation
// iree_uk_{CATEGORY}_2d (e.g. iree_uk_x32b_2d), and returns its result code.
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_BINARY_2D(NAME, OPCODE, ELEM_T, CATEGORY)   \
  IREE_UK_EXPORT int iree_uk_##CATEGORY##_##NAME##_2d(               \
      const ELEM_T* lhs, iree_uk_ssize_t lhs_offset,                 \
      iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,      \
      const ELEM_T* rhs, iree_uk_ssize_t rhs_offset,                 \
      iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,      \
      ELEM_T* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,      \
      iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,      \
      iree_uk_ssize_t size0, iree_uk_ssize_t size1) {                \
    const int status = iree_uk_##CATEGORY##_2d(                      \
        OPCODE,                                                      \
        lhs, lhs_offset, lhs_stride0, lhs_stride1,                   \
        rhs, rhs_offset, rhs_stride0, rhs_stride1,                   \
        out, out_offset, out_stride0, out_stride1,                   \
        size0, size1);                                               \
    return status;                                                   \
  }

// Defines the exported entry point iree_uk_{CATEGORY}_{NAME}_2d for a unary
// elementwise op. The generated function forwards its operands, plus the
// fixed opcode value |OPCODE|, to iree_uk_generic_{CATEGORY}_2d (e.g.
// iree_uk_generic_x32u_2d), and returns its result code.
// NOTE(review): the original comment named DECLARE_UKERNEL_BINARY_2D here,
// which looks like a copy-paste slip; presumably this mirrors the header's
// unary declaration macro (DECLARE_UKERNEL_UNARY_2D) — confirm.
#define DISPATCH_UKERNEL_UNARY_2D(NAME, OPCODE, ELEM_T, CATEGORY)    \
  IREE_UK_EXPORT int iree_uk_##CATEGORY##_##NAME##_2d(               \
      const ELEM_T* in, iree_uk_ssize_t in_offset,                   \
      iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,        \
      ELEM_T* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,      \
      iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,      \
      iree_uk_ssize_t size0, iree_uk_ssize_t size1) {                \
    const int status = iree_uk_generic_##CATEGORY##_2d(              \
        OPCODE,                                                      \
        in, in_offset, in_stride0, in_stride1,                       \
        out, out_offset, out_stride0, out_stride1,                   \
        size0, size1);                                               \
    return status;                                                   \
  }

//===----------------------------------------------------------------------===//
// Internal helpers.
//===----------------------------------------------------------------------===//

// Reports which RVV code path, if any, implements |opcode|:
//  - IREE_UK_X32B_SI for DIVSI (signed-integer vectors),
//  - IREE_UK_X32B_UI for the other integer ops listed below (unsigned
//    vectors),
//  - IREE_UK_X32B_NA for everything else (scalar fallback).
static iree_uk_x32b_opcode_type_t get_iree_uk_x32b_op_type(
    iree_uk_x32b_opcode_t opcode) {
  iree_uk_x32b_opcode_type_t kind = IREE_UK_X32B_NA;
  switch (opcode) {
    case IREE_UK_X32B_DIVSI:
      kind = IREE_UK_X32B_SI;
      break;
    case IREE_UK_X32B_ADDI:
    case IREE_UK_X32B_ANDI:
    case IREE_UK_X32B_DIVUI:
    case IREE_UK_X32B_MULI:
    case IREE_UK_X32B_ORI:
    case IREE_UK_X32B_SHLI:
    case IREE_UK_X32B_SHRUI:
    case IREE_UK_X32B_SUBI:
    case IREE_UKENREL_X32B_XORI:
      kind = IREE_UK_X32B_UI;
      break;
    default:
      // Float ops and SHRSI have no RVV lowering in this file.
      break;
  }
  return kind;
}

// Applies the binary x32b |opcode| to |vl| consecutive elements (one RVV
// strip) using strided vector loads and stores.
//
// |lhs|, |rhs| and |out| point at the first element of the strip. Each
// *_stride argument is the distance between consecutive elements in BYTES,
// as expected by vlse32/vsse32.
//
// Only opcodes classified as IREE_UK_X32B_UI or IREE_UK_X32B_SI by
// get_iree_uk_x32b_op_type are supported. For any other opcode,
// |*result_code| is set to 1; otherwise |*result_code| is left untouched.
static void iree_uk_rvv_x32b_op(iree_uk_x32b_opcode_t opcode, int* result_code,
                                const iree_uk_uint32_t* lhs,
                                iree_uk_ssize_t lhs_stride,
                                const iree_uk_uint32_t* rhs,
                                iree_uk_ssize_t rhs_stride,
                                iree_uk_uint32_t* out,
                                iree_uk_ssize_t out_stride, size_t vl) {
  iree_uk_x32b_opcode_type_t op_type = get_iree_uk_x32b_op_type(opcode);
  if (op_type == IREE_UK_X32B_UI) {
    vuint32m8_t vx = vlse32_v_u32m8(lhs, lhs_stride, vl);  // load
    vuint32m8_t vy = vlse32_v_u32m8(rhs, rhs_stride, vl);  // load
    switch (opcode) {
      case IREE_UK_X32B_ADDI:
        vx = vadd(vx, vy, vl);
        break;
      case IREE_UK_X32B_ANDI:
        vx = vand(vx, vy, vl);
        break;
      case IREE_UK_X32B_DIVUI:
        vx = vdivu(vx, vy, vl);
        break;
      case IREE_UK_X32B_MULI:
        vx = vmul(vx, vy, vl);
        break;
      case IREE_UK_X32B_ORI:
        vx = vor(vx, vy, vl);
        break;
      case IREE_UK_X32B_SHLI:
        vx = vsll(vx, vy, vl);
        break;
      case IREE_UK_X32B_SHRUI:
        vx = vsrl(vx, vy, vl);
        break;
      case IREE_UKENREL_X32B_XORI:
        // Exclusive-or (this previously used vor, computing OR instead).
        vx = vxor(vx, vy, vl);
        break;
      case IREE_UK_X32B_SUBI:
        vx = vsub(vx, vy, vl);
        break;
      default:
        *result_code = 1;
    }
    vsse32(out, out_stride, vx, vl);  // store
  } else if (op_type == IREE_UK_X32B_SI) {
    vint32m8_t vx =
        vlse32_v_i32m8((const iree_uk_int32_t*)lhs, lhs_stride, vl);  // load
    vint32m8_t vy =
        vlse32_v_i32m8((const iree_uk_int32_t*)rhs, rhs_stride, vl);  // load
    switch (opcode) {
      case IREE_UK_X32B_DIVSI:
        vx = vdiv(vx, vy, vl);
        break;
      default:
        *result_code = 1;
    }
    vsse32((iree_uk_int32_t*)out, out_stride, vx, vl);  // store
  } else {
    *result_code = 1;
  }
}

// Computes a single element of an x32b opcode with scalar C code. On error,
// sets |*result_code| to a non-zero value (and does not touch it otherwise).
//
// Integer cases whose plain C expression would be undefined behavior instead
// produce the results defined by the RVV integer instructions, so the scalar
// and vector paths agree:
//  - shift amounts use only their low 5 bits;
//  - unsigned division by zero yields 0xFFFFFFFF;
//  - signed division by zero yields -1, and INT32_MIN / -1 yields INT32_MIN.
static void iree_uk_generic_x32b_op(iree_uk_x32b_opcode_t opcode,
                                    int* result_code,
                                    const iree_uk_uint32_t* lhs,
                                    const iree_uk_uint32_t* rhs,
                                    iree_uk_uint32_t* out) {
  switch (opcode) {
    case IREE_UK_X32B_ADDF:
      ASF32(out) = ASF32(lhs) + ASF32(rhs);
      return;
    case IREE_UK_X32B_ADDI:
      ASUI32(out) = ASUI32(lhs) + ASUI32(rhs);
      return;
    case IREE_UK_X32B_ANDI:
      ASUI32(out) = ASUI32(lhs) & ASUI32(rhs);
      return;
    case IREE_UK_X32B_DIVF:
      ASF32(out) = ASF32(lhs) / ASF32(rhs);
      return;
    case IREE_UK_X32B_DIVSI: {
      const iree_uk_int32_t divisor = ASSI32(rhs);
      if (divisor == 0) {
        ASSI32(out) = -1;
      } else if (divisor == -1 && ASUI32(lhs) == 0x80000000u) {
        // INT32_MIN / -1 overflows in C; the wrapped quotient is INT32_MIN.
        ASSI32(out) = ASSI32(lhs);
      } else {
        ASSI32(out) = ASSI32(lhs) / divisor;
      }
      return;
    }
    case IREE_UK_X32B_DIVUI: {
      const iree_uk_uint32_t divisor = ASUI32(rhs);
      ASUI32(out) = divisor == 0 ? 0xFFFFFFFFu : ASUI32(lhs) / divisor;
      return;
    }
    case IREE_UK_X32B_MULF:
      ASF32(out) = ASF32(lhs) * ASF32(rhs);
      return;
    case IREE_UK_X32B_MULI:
      ASUI32(out) = ASUI32(lhs) * ASUI32(rhs);
      return;
    case IREE_UK_X32B_ORI:
      ASUI32(out) = ASUI32(lhs) | ASUI32(rhs);
      return;
    case IREE_UK_X32B_SHLI:
      // Masking keeps the shift amount in [0, 31]; shifting by >= 32 is UB.
      ASUI32(out) = ASUI32(lhs) << (ASUI32(rhs) & 31u);
      return;
    case IREE_UK_X32B_SHRSI:
      // Read the amount as unsigned so negative amounts are masked too
      // rather than being UB. Right-shifting a negative lhs is
      // implementation-defined (arithmetic on GCC/Clang).
      ASSI32(out) = ASSI32(lhs) >> (ASUI32(rhs) & 31u);
      return;
    case IREE_UK_X32B_SHRUI:
      ASUI32(out) = ASUI32(lhs) >> (ASUI32(rhs) & 31u);
      return;
    case IREE_UKENREL_X32B_XORI:
      ASUI32(out) = ASUI32(lhs) ^ ASUI32(rhs);
      return;
    case IREE_UK_X32B_SUBF:
      ASF32(out) = ASF32(lhs) - ASF32(rhs);
      return;
    case IREE_UK_X32B_SUBI:
      // Unsigned wrap-around subtraction, stored through an unsigned lvalue.
      ASUI32(out) = ASUI32(lhs) - ASUI32(rhs);
      return;
    default:
      *result_code = 1;
  }
}

// Evaluates one element of the unary x32u |opcode|. It reads the 32-bit value
// at |in| and writes the 32-bit result to |out|. CTLZ treats the bits as an
// unsigned integer; every other opcode treats them as a float.
// An unknown opcode sets |*result_code| to 1 and leaves |*out| untouched.
// Otherwise |*result_code| is not modified.
static void iree_uk_generic_x32u_op(iree_uk_x32u_opcode_t opcode,
                                    int* result_code,
                                    const iree_uk_uint32_t* in,
                                    iree_uk_uint32_t* out) {
  if (opcode == IREE_UK_X32U_CTLZ) {
    ASUI32(out) = iree_uk_count_leading_zeros_u32(ASUI32(in));
    return;
  }

  float value;
  switch (opcode) {
    case IREE_UK_X32U_ABSF:
      value = fabsf(ASF32(in));
      break;
    case IREE_UK_X32U_CEILF:
      value = ceilf(ASF32(in));
      break;
    case IREE_UK_X32U_EXPF:
      value = expf(ASF32(in));
      break;
    case IREE_UK_X32U_FLOORF:
      value = floorf(ASF32(in));
      break;
    case IREE_UK_X32U_LOGF:
      value = logf(ASF32(in));
      break;
    case IREE_UK_X32U_NEGF:
      value = -ASF32(in);
      break;
    case IREE_UK_X32U_RSQRTF:
      value = 1.0f / sqrtf(ASF32(in));
      break;
    default:
      *result_code = 1;
      return;
  }
  ASF32(out) = value;
}

//===----------------------------------------------------------------------===//
// Opcode dispatch entry points.
//===----------------------------------------------------------------------===//

// 32bit binary kernels.
//
// Applies |opcode| elementwise over a size0 x size1 grid. Element (i, j) of
// each operand lives at index i * stride0 + j * stride1, with strides counted
// in elements.
//
// Opcodes with an RVV lowering are vectorized along whichever dimension is
// longer, using strided vector loads and stores. All other opcodes fall back
// to the scalar per-element implementation.
//
// Returns 0 on success, non-zero if any element reported an error.
//
// NOTE(review): |lhs_offset|, |rhs_offset| and |out_offset| are accepted but
// not applied here — confirm that callers pass pre-offset base pointers.
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_x32b_2d(
    iree_uk_x32b_opcode_t opcode,
    // LHS.
    const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
    iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
    // RHS
    const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
    iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
    // OUT.
    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
    // Sizes.
    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
  int result_code = 0;

  if (get_iree_uk_x32b_op_type(opcode) == IREE_UK_X32B_NA) {
    // No RVV lowering for this opcode: scalar fallback, one element at a time.
    for (iree_uk_ssize_t i = 0; i < size0; ++i) {
      for (iree_uk_ssize_t j = 0; j < size1; ++j) {
        iree_uk_generic_x32b_op(opcode, &result_code,
                                &lhs[i * lhs_stride0 + j * lhs_stride1],
                                &rhs[i * rhs_stride0 + j * rhs_stride1],
                                &out[i * out_stride0 + j * out_stride1]);
      }
    }
    return result_code;
  }

  // Vectorize along the longer dimension so that each vsetvl strip is as full
  // as possible. The "outer" dimension is walked one index at a time; the
  // "inner" dimension is processed |vl| elements at a time.
  const int inner_is_dim1 = size0 < size1;
  const iree_uk_ssize_t outer_size = inner_is_dim1 ? size0 : size1;
  const iree_uk_ssize_t inner_size = inner_is_dim1 ? size1 : size0;
  const iree_uk_ssize_t lhs_outer = inner_is_dim1 ? lhs_stride0 : lhs_stride1;
  const iree_uk_ssize_t lhs_inner = inner_is_dim1 ? lhs_stride1 : lhs_stride0;
  const iree_uk_ssize_t rhs_outer = inner_is_dim1 ? rhs_stride0 : rhs_stride1;
  const iree_uk_ssize_t rhs_inner = inner_is_dim1 ? rhs_stride1 : rhs_stride0;
  const iree_uk_ssize_t out_outer = inner_is_dim1 ? out_stride0 : out_stride1;
  const iree_uk_ssize_t out_inner = inner_is_dim1 ? out_stride1 : out_stride0;
  // vlse32/vsse32 take byte strides. The arithmetic stays signed, so a
  // negative element stride becomes a negative byte stride directly instead
  // of wrapping through unsigned size_t.
  const iree_uk_ssize_t elem_bytes = (iree_uk_ssize_t)sizeof(*out);

  for (iree_uk_ssize_t o = 0; o < outer_size; ++o) {
    iree_uk_ssize_t k = 0;
    while (k < inner_size) {
      const size_t vl = vsetvl_e32m8((size_t)(inner_size - k));
      iree_uk_rvv_x32b_op(opcode, &result_code,
                          &lhs[o * lhs_outer + k * lhs_inner],
                          lhs_inner * elem_bytes,
                          &rhs[o * rhs_outer + k * rhs_inner],
                          rhs_inner * elem_bytes,
                          &out[o * out_outer + k * out_inner],
                          out_inner * elem_bytes, vl);
      k += (iree_uk_ssize_t)vl;
    }
  }
  return result_code;
}

// Generic 32bit unary kernels.
//
// Applies the unary |opcode| to every element of a size0 x size1 grid, one
// element at a time, via iree_uk_generic_x32u_op. Element (i, j) of |in| and
// |out| lives at index i * stride0 + j * stride1, with strides in elements.
// Returns 0 on success, or non-zero if iree_uk_generic_x32u_op reported an
// error.
// NOTE(review): the *_offset arguments are not applied — confirm that callers
// pass pre-offset base pointers.
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
    iree_uk_x32u_opcode_t opcode,
    // IN.
    const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
    iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
    // OUT.
    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
    // Sizes.
    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
  int status = 0;
  // TODO: Manually unroll to x4 to trigger vectorization.
  for (iree_uk_ssize_t row = 0; row < size0; ++row) {
    const iree_uk_ssize_t in_row_base = row * in_stride0;
    const iree_uk_ssize_t out_row_base = row * out_stride0;
    for (iree_uk_ssize_t col = 0; col < size1; ++col) {
      iree_uk_generic_x32u_op(opcode, &status,
                              in + (in_row_base + col * in_stride1),
                              out + (out_row_base + col * out_stride1));
    }
  }
  return status;
}

// Exported 32-bit binary entry points, e.g. iree_uk_x32b_addf_2d. Each one
// forwards to iree_uk_x32b_2d with its fixed opcode. The first macro argument
// is token-pasted into the exported symbol name, so changing any spelling
// here changes the exported interface.
DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UK_X32B_ADDF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(addi, IREE_UK_X32B_ADDI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(andi, IREE_UK_X32B_ANDI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(divf, IREE_UK_X32B_DIVF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(divsi, IREE_UK_X32B_DIVSI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(divui, IREE_UK_X32B_DIVUI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(mulf, IREE_UK_X32B_MULF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(muli, IREE_UK_X32B_MULI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(ori, IREE_UK_X32B_ORI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(shli, IREE_UK_X32B_SHLI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(shrsi, IREE_UK_X32B_SHRSI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(shrui, IREE_UK_X32B_SHRUI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(subf, IREE_UK_X32B_SUBF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(subi, IREE_UK_X32B_SUBI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(xori, IREE_UKENREL_X32B_XORI, iree_uk_uint32_t,
                           x32b);

// Exported 32-bit unary entry points, e.g. iree_uk_x32u_absf_2d. Each one
// forwards to iree_uk_generic_x32u_2d with its fixed opcode. The first macro
// argument is token-pasted into the exported symbol name.
DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UK_X32U_ABSF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(ceilf, IREE_UK_X32U_CEILF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(ctlz, IREE_UK_X32U_CTLZ, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(expf, IREE_UK_X32U_EXPF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(floorf, IREE_UK_X32U_FLOORF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(logf, IREE_UK_X32U_LOGF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(negf, IREE_UK_X32U_NEGF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(rsqrtf, IREE_UK_X32U_RSQRTF, iree_uk_uint32_t, x32u);
