| /* |
| * Copyright 2023 Google LLC |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "iree/modules/vmvx/elementwise.h" |
| |
| #include <math.h> |
| #include <riscv_vector.h> |
| |
| //===----------------------------------------------------------------------===// |
| // Helpers for defining generic implementations of elementwise functions. |
| // Since it affords the best code size tradeoff options, the entrypoint |
| // is dispatched based on an opcode. |
| //===----------------------------------------------------------------------===// |
| |
| // Opcodes for generic functions operating on 32-bit operands and result. |
| // Since the outer dispatcher only differentiates based on width, all other |
| // type specificity is carried by the opcode. |
| // Binary opcodes are named "X32B" and unary opcodes "X32U". |
| // The initial list was sorted, and it is encouraged to sort extensions, but |
| // each opcode must be numerically stable, so the list is not expected to |
| // be sorted over time. |
typedef enum {
  // Each value is a numerically stable identifier: new opcodes get a fresh
  // number instead of renumbering the existing ones.
  IREE_UK_X32B_ADDF = 0,    // float add
  IREE_UK_X32B_ADDI = 1,    // integer add
  IREE_UK_X32B_ANDI = 2,    // bitwise and
  IREE_UK_X32B_DIVF = 3,    // float divide
  IREE_UK_X32B_DIVSI = 4,   // signed integer divide
  IREE_UK_X32B_DIVUI = 5,   // unsigned integer divide
  IREE_UK_X32B_MULF = 6,    // float multiply
  IREE_UK_X32B_MULI = 7,    // integer multiply
  IREE_UK_X32B_ORI = 8,     // bitwise or
  IREE_UK_X32B_SHLI = 9,    // shift left
  IREE_UK_X32B_SHRSI = 10,  // shift right on signed operands
  IREE_UK_X32B_SHRUI = 11,  // shift right on unsigned operands
  IREE_UK_X32B_SUBF = 12,   // float subtract
  IREE_UK_X32B_SUBI = 13,   // integer subtract
  // Bitwise xor. NOTE: the "IREE_UKENREL_" prefix (sic) differs from the
  // "IREE_UK_" prefix of every other entry; the rest of this file refers to
  // the opcode by this exact spelling.
  IREE_UKENREL_X32B_XORI = 14,
} iree_uk_x32b_opcode_t;
| |
// Classifies a 32-bit binary opcode by the kind of RVV vector type used to
// implement it.
typedef enum {
  // Handled with unsigned 32-bit integer vectors.
  IREE_UK_X32B_UI = 0,
  // Handled with signed 32-bit integer vectors.
  IREE_UK_X32B_SI = 1,
  // Not available in RVV.
  IREE_UK_X32B_NA = 2,
} iree_uk_x32b_opcode_type_t;
| |
// Opcodes for unary functions operating on a 32-bit operand and result.
// Values are written out explicitly (they equal the implicit enumeration
// order) so that each opcode keeps a fixed number.
typedef enum {
  IREE_UK_X32U_ABSF = 0,    // float absolute value
  IREE_UK_X32U_CEILF = 1,   // float ceiling
  IREE_UK_X32U_CTLZ = 2,    // count leading zeros of an unsigned integer
  IREE_UK_X32U_EXPF = 3,    // float exponential
  IREE_UK_X32U_FLOORF = 4,  // float floor
  IREE_UK_X32U_LOGF = 5,    // float natural logarithm
  IREE_UK_X32U_NEGF = 6,    // float negation
  IREE_UK_X32U_RSQRTF = 7,  // float reciprocal square root
} iree_uk_x32u_opcode_t;
| |
// Macros to access various typed, dereferenced pointers.
// Both the argument and the whole expansion are parenthesized so the macros
// stay correct when handed an expression (e.g. `base + n`, which would
// otherwise be cast first and offset afterwards) and when the result is used
// inside a larger expression. Each expansion is still an lvalue.
#define ASF32(ptr) (*((float*)(ptr)))
#define ASUI32(ptr) (*((iree_uk_uint32_t*)(ptr)))
#define ASSI32(ptr) (*((iree_uk_int32_t*)(ptr)))
| |
| //===----------------------------------------------------------------------===// |
| // Implementation macros. |
| //===----------------------------------------------------------------------===// |
| |
// Emits the exported binary entry point iree_uk_{category}_{opcode}_2d.
// The generated function does no work of its own: it forwards every argument,
// prefixed with the constant |opcode_t|, to iree_uk_{category}_2d and returns
// that function's result.
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category)    \
  IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                 \
      /* LHS */                                                          \
      const dtype* lhs, iree_uk_index_t lhs_offset,                      \
      iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,          \
      /* RHS */                                                          \
      const dtype* rhs, iree_uk_index_t rhs_offset,                      \
      iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,          \
      /* OUT */                                                          \
      dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,           \
      iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,          \
      /* Sizes */                                                        \
      iree_uk_index_t size0, iree_uk_index_t size1) {                    \
    return iree_uk_##category##_2d(                                      \
        opcode_t,                                                        \
        lhs, lhs_offset, lhs_stride0, lhs_stride1,                       \
        rhs, rhs_offset, rhs_stride0, rhs_stride1,                       \
        out, out_offset, out_stride0, out_stride1,                       \
        size0, size1);                                                   \
  }
| |
// Emits the exported unary entry point iree_uk_{category}_{opcode}_2d.
// The generated function forwards every argument, prefixed with the constant
// |opcode_t|, to iree_uk_generic_{category}_2d and returns that function's
// result.
// NOTE(review): this is the unary counterpart of DISPATCH_UKERNEL_BINARY_2D;
// the matching header macro is presumably DECLARE_UKERNEL_UNARY_2D — confirm
// against elementwise.h.
#define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category)     \
  IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                 \
      /* IN */                                                           \
      const dtype* in, iree_uk_index_t in_offset,                        \
      iree_uk_index_t in_stride0, iree_uk_index_t in_stride1,            \
      /* OUT */                                                          \
      dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,           \
      iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,          \
      /* Sizes */                                                        \
      iree_uk_index_t size0, iree_uk_index_t size1) {                    \
    return iree_uk_generic_##category##_2d(                              \
        opcode_t,                                                        \
        in, in_offset, in_stride0, in_stride1,                           \
        out, out_offset, out_stride0, out_stride1,                       \
        size0, size1);                                                   \
  }
| |
| //===----------------------------------------------------------------------===// |
| // Internal helpers. |
| //===----------------------------------------------------------------------===// |
| |
| static iree_uk_x32b_opcode_type_t get_iree_uk_x32b_op_type( |
| iree_uk_x32b_opcode_t opcode) { |
| switch (opcode) { |
| case IREE_UK_X32B_ADDI: |
| case IREE_UK_X32B_ANDI: |
| case IREE_UK_X32B_DIVUI: |
| case IREE_UK_X32B_MULI: |
| case IREE_UK_X32B_ORI: |
| case IREE_UK_X32B_SHLI: |
| case IREE_UK_X32B_SHRUI: |
| case IREE_UKENREL_X32B_XORI: |
| case IREE_UK_X32B_SUBI: |
| return IREE_UK_X32B_UI; |
| case IREE_UK_X32B_DIVSI: |
| return IREE_UK_X32B_SI; |
| default: |
| return IREE_UK_X32B_NA; |
| } |
| } |
| |
| // Computes a single element of an x32b opcode usinbg RVV. |
| static void iree_uk_rvv_x32b_op(iree_uk_x32b_opcode_t opcode, int* result_code, |
| const iree_uk_uint32_t* lhs, |
| iree_uk_index_t lhs_stride, |
| const iree_uk_uint32_t* rhs, |
| iree_uk_index_t rhs_stride, |
| iree_uk_uint32_t* out, |
| iree_uk_index_t out_stride, size_t vl) { |
| iree_uk_x32b_opcode_type_t op_type = get_iree_uk_x32b_op_type(opcode); |
| if (op_type == IREE_UK_X32B_UI) { |
| vuint32m8_t vx = vlse32_v_u32m8(lhs, lhs_stride, vl); // load |
| vuint32m8_t vy = vlse32_v_u32m8(rhs, rhs_stride, vl); // load |
| switch (opcode) { |
| case IREE_UK_X32B_ADDI: |
| vx = vadd(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_ANDI: |
| vx = vand(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_DIVUI: |
| vx = vdivu(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_MULI: |
| vx = vmul(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_ORI: |
| vx = vor(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_SHLI: |
| vx = vsll(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_SHRUI: |
| vx = vsrl(vx, vy, vl); |
| break; |
| case IREE_UKENREL_X32B_XORI: |
| vx = vor(vx, vy, vl); |
| break; |
| case IREE_UK_X32B_SUBI: |
| vx = vsub(vx, vy, vl); |
| break; |
| default: |
| *result_code = 1; |
| } |
| vsse32(out, out_stride, vx, vl); // save |
| } else if (op_type == IREE_UK_X32B_SI) { |
| vint32m8_t vx = |
| vlse32_v_i32m8((iree_uk_int32_t*)lhs, lhs_stride, vl); // load |
| vint32m8_t vy = |
| vlse32_v_i32m8((iree_uk_int32_t*)rhs, rhs_stride, vl); // load |
| switch (opcode) { |
| case IREE_UK_X32B_DIVSI: |
| vx = vdiv(vx, vy, vl); |
| break; |
| default: |
| *result_code = 1; |
| } |
| vsse32((iree_uk_int32_t*)out, out_stride, vx, vl); // save |
| } else { |
| *result_code = 1; |
| } |
| } |
| |
| // Computes a single element of an x32b opcode. On error, should set |
| // |*result_code| to a non-zero value (but should not touch it otherwise). |
| static void iree_uk_generic_x32b_op(iree_uk_x32b_opcode_t opcode, |
| int* result_code, |
| const iree_uk_uint32_t* lhs, |
| const iree_uk_uint32_t* rhs, |
| iree_uk_uint32_t* out) { |
| switch (opcode) { |
| case IREE_UK_X32B_ADDF: |
| ASF32(out) = ASF32(lhs) + ASF32(rhs); |
| return; |
| case IREE_UK_X32B_ADDI: |
| ASUI32(out) = ASUI32(lhs) + ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_ANDI: |
| ASUI32(out) = ASUI32(lhs) & ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_DIVF: |
| ASF32(out) = ASF32(lhs) / ASF32(rhs); |
| return; |
| case IREE_UK_X32B_DIVSI: |
| ASSI32(out) = ASSI32(lhs) / ASSI32(rhs); |
| return; |
| case IREE_UK_X32B_DIVUI: |
| ASUI32(out) = ASUI32(lhs) / ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_MULF: |
| ASF32(out) = ASF32(lhs) * ASF32(rhs); |
| return; |
| case IREE_UK_X32B_MULI: |
| ASUI32(out) = ASUI32(lhs) * ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_ORI: |
| ASUI32(out) = ASUI32(lhs) | ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_SHLI: |
| ASUI32(out) = ASUI32(lhs) << ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_SHRSI: |
| ASSI32(out) = ASSI32(lhs) >> ASSI32(rhs); |
| return; |
| case IREE_UK_X32B_SHRUI: |
| ASUI32(out) = ASUI32(lhs) >> ASUI32(rhs); |
| return; |
| case IREE_UKENREL_X32B_XORI: |
| ASUI32(out) = ASUI32(lhs) ^ ASUI32(rhs); |
| return; |
| case IREE_UK_X32B_SUBF: |
| ASF32(out) = ASF32(lhs) - ASF32(rhs); |
| return; |
| case IREE_UK_X32B_SUBI: |
| ASSI32(out) = ASUI32(lhs) - ASUI32(rhs); |
| return; |
| default: |
| *result_code = 1; |
| } |
| } |
| |
| // Computes a single element of an x32u opcode. Most are float ops. On error, |
| // should set |*result_code| to a non-zero value (but should not touch it |
| // otherwise). |
| static void iree_uk_generic_x32u_op(iree_uk_x32u_opcode_t opcode, |
| int* result_code, |
| const iree_uk_uint32_t* in, |
| iree_uk_uint32_t* out) { |
| switch (opcode) { |
| case IREE_UK_X32U_ABSF: |
| ASF32(out) = fabsf(ASF32(in)); |
| return; |
| case IREE_UK_X32U_CEILF: |
| ASF32(out) = ceilf(ASF32(in)); |
| return; |
| case IREE_UK_X32U_CTLZ: |
| ASUI32(out) = iree_uk_count_leading_zeros_u32(ASUI32(in)); |
| return; |
| case IREE_UK_X32U_EXPF: |
| ASF32(out) = expf(ASF32(in)); |
| return; |
| case IREE_UK_X32U_FLOORF: |
| ASF32(out) = floorf(ASF32(in)); |
| return; |
| case IREE_UK_X32U_LOGF: |
| ASF32(out) = logf(ASF32(in)); |
| return; |
| case IREE_UK_X32U_NEGF: |
| ASF32(out) = -ASF32(in); |
| return; |
| case IREE_UK_X32U_RSQRTF: |
| ASF32(out) = 1.0f / sqrtf(ASF32(in)); |
| return; |
| default: |
| *result_code = 1; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Opcode dispatch entry points. |
| //===----------------------------------------------------------------------===// |
| |
| // 32bit binary kernels. |
| IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_x32b_2d( |
| iree_uk_x32b_opcode_t opcode, |
| // LHS. |
| const iree_uk_uint32_t* lhs, iree_uk_index_t lhs_offset, |
| iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1, |
| // RHS |
| const iree_uk_uint32_t* rhs, iree_uk_index_t rhs_offset, |
| iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1, |
| // OUT. |
| iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset, |
| iree_uk_index_t out_stride0, iree_uk_index_t out_stride1, |
| // Sizes. |
| iree_uk_index_t size0, iree_uk_index_t size1) { |
| int result_code = 0; |
| |
| if (get_iree_uk_x32b_op_type(opcode) != IREE_UK_X32B_NA) { |
| size_t vl; |
| // make most use of vectorization by swiching dimension |
| if (size0 < size1) { |
| for (iree_uk_index_t i = 0; i < size0; ++i) { |
| for (iree_uk_index_t j = 0; j < size1; j += vl) { |
| vl = vsetvl_e32m8(size1 - j); |
| iree_uk_rvv_x32b_op(opcode, &result_code, |
| &lhs[i * lhs_stride0 + j * lhs_stride1], |
| lhs_stride1 * sizeof(uint32_t), |
| &rhs[i * rhs_stride0 + j * rhs_stride1], |
| rhs_stride1 * sizeof(uint32_t), |
| &out[i * out_stride0 + j * out_stride1], |
| out_stride1 * sizeof(uint32_t), vl); |
| } |
| } |
| } else { |
| for (iree_uk_index_t j = 0; j < size1; ++j) { |
| for (iree_uk_index_t i = 0; i < size0; i += vl) { |
| vl = vsetvl_e32m8(size0 - i); |
| iree_uk_rvv_x32b_op(opcode, &result_code, |
| &lhs[i * lhs_stride0 + j * lhs_stride1], |
| lhs_stride0 * sizeof(uint32_t), |
| &rhs[i * rhs_stride0 + j * rhs_stride1], |
| rhs_stride0 * sizeof(uint32_t), |
| &out[i * out_stride0 + j * out_stride1], |
| out_stride0 * sizeof(uint32_t), vl); |
| } |
| } |
| } |
| } else { |
| for (iree_uk_index_t i = 0; i < size0; ++i) { |
| for (iree_uk_index_t j = 0; j < size1; ++j) { |
| iree_uk_generic_x32b_op(opcode, &result_code, |
| &lhs[i * lhs_stride0 + j * lhs_stride1], |
| &rhs[i * rhs_stride0 + j * rhs_stride1], |
| &out[i * out_stride0 + j * out_stride1]); |
| } |
| } |
| } |
| return result_code; |
| } |
| |
| // Generic 32bit unary kernels. |
| IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d( |
| iree_uk_x32u_opcode_t opcode, |
| // IN. |
| const iree_uk_uint32_t* in, iree_uk_index_t in_offset, |
| iree_uk_index_t in_stride0, iree_uk_index_t in_stride1, |
| // OUT. |
| iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset, |
| iree_uk_index_t out_stride0, iree_uk_index_t out_stride1, |
| // Sizes. |
| iree_uk_index_t size0, iree_uk_index_t size1) { |
| int result_code = 0; |
| // TODO: Manually unroll to x4 to trigger vectorization. |
| for (iree_uk_index_t i = 0; i < size0; ++i) { |
| for (iree_uk_index_t j = 0; j < size1; ++j) { |
| iree_uk_generic_x32u_op(opcode, &result_code, |
| &in[i * in_stride0 + j * in_stride1], |
| &out[i * out_stride0 + j * out_stride1]); |
| } |
| } |
| return result_code; |
| } |
| |
// Exported 32-bit binary entry points (iree_uk_x32b_{op}_2d), one per opcode.

// Floating-point ops.
DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UK_X32B_ADDF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(divf, IREE_UK_X32B_DIVF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(mulf, IREE_UK_X32B_MULF, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(subf, IREE_UK_X32B_SUBF, iree_uk_uint32_t, x32b);

// Integer arithmetic ops.
DISPATCH_UKERNEL_BINARY_2D(addi, IREE_UK_X32B_ADDI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(divsi, IREE_UK_X32B_DIVSI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(divui, IREE_UK_X32B_DIVUI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(muli, IREE_UK_X32B_MULI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(subi, IREE_UK_X32B_SUBI, iree_uk_uint32_t, x32b);

// Bitwise and shift ops.
DISPATCH_UKERNEL_BINARY_2D(andi, IREE_UK_X32B_ANDI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(ori, IREE_UK_X32B_ORI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(xori, IREE_UKENREL_X32B_XORI, iree_uk_uint32_t,
                           x32b);
DISPATCH_UKERNEL_BINARY_2D(shli, IREE_UK_X32B_SHLI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(shrsi, IREE_UK_X32B_SHRSI, iree_uk_uint32_t, x32b);
DISPATCH_UKERNEL_BINARY_2D(shrui, IREE_UK_X32B_SHRUI, iree_uk_uint32_t, x32b);
| |
// Exported 32-bit unary entry points (iree_uk_x32u_{op}_2d), one per opcode.

// Floating-point ops.
DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UK_X32U_ABSF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(negf, IREE_UK_X32U_NEGF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(ceilf, IREE_UK_X32U_CEILF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(floorf, IREE_UK_X32U_FLOORF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(expf, IREE_UK_X32U_EXPF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(logf, IREE_UK_X32U_LOGF, iree_uk_uint32_t, x32u);
DISPATCH_UKERNEL_UNARY_2D(rsqrtf, IREE_UK_X32U_RSQRTF, iree_uk_uint32_t, x32u);

// Integer ops.
DISPATCH_UKERNEL_UNARY_2D(ctlz, IREE_UK_X32U_CTLZ, iree_uk_uint32_t, x32u);