| /* |
| * Copyright 2023 Google LLC |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "iree/builtins/ukernel/mmt4d_internal.h" |
| |
| #ifdef BUILD_KELVIN |
| |
| #include <kelvin.h> |
| |
| static_assert(sizeof(struct vconv_u8_t) == 4); |
| union { |
| struct vconv_u8_t conv; |
| uint32_t raw; |
| } cmds; |
| |
| static iree_uk_int32_t buffer[8] __attribute__((aligned)); |
| |
| // dot product s8xx8->s32 |
| static iree_uk_int32_t dot_product_s8s8(const iree_uk_int8_t* u, |
| const iree_uk_int8_t* w, int n) { |
| int vl; |
| getmaxvl_w(vl); |
| vdup_w_x_m(v0, 0); |
| |
| while (n) { |
| int count = n < vl ? n : vl; |
| |
| vld_b_lp_xx(v8, u, count); |
| vaddw_h_vx(v8, v8, 0); |
| vzip_h_vv(v10, v8, v9); |
| |
| vld_b_lp_xx(v16, w, count); |
| vaddw_h_vx(v16, v16, 0); |
| vzip_h_vv(v18, v16, v17); |
| |
| vmulw_w_vv(v4, v10, v18); |
| vzip_w_vv(v1, v4, v5); |
| |
| vadd_w_vv(v0, v0, v1); |
| n -= count; |
| } |
| |
| vst_w_l_xx(v0, buffer, vl); |
| |
| iree_uk_int32_t sum = 0; |
| for (int i = 0; i < vl; ++i) sum += buffer[i]; |
| return sum; |
| } |
| |
| // dot product s16xs8->s32 |
| static iree_uk_int32_t dot_product_s16s8(const iree_uk_int16_t* u, |
| const iree_uk_int8_t* w, int n) { |
| int vl; |
| getmaxvl_w(vl); |
| vdup_w_x_m(v0, 0); |
| |
| while (n) { |
| int count = n < vl ? n : vl; |
| |
| vld_h_lp_xx(v8, u, count); |
| |
| vld_b_lp_xx(v16, w, count); |
| vaddw_h_vx(v16, v16, 0); |
| vzip_h_vv(v18, v16, v17); |
| |
| vmulw_w_vv(v4, v8, v18); |
| vzip_w_vv(v1, v4, v5); |
| |
| vadd_w_vv(v0, v0, v1); |
| n -= count; |
| } |
| |
| vst_w_l_xx(v0, buffer, vl); |
| |
| iree_uk_int32_t sum = 0; |
| for (int i = 0; i < vl; ++i) sum += buffer[i]; |
| return sum; |
| } |
| |
| // Matrix multiplication s8xs8->s32. |
| // lhs (row-major): 8/4/2/1x4, rhs (column-major): 4x8 |
| // out (row-major): 8/4/2/1x8 |
| static void matmul_s8s8(const iree_uk_int8_t* lhs, const iree_uk_int8_t* rhs, |
| iree_uk_int32_t* out, |
| const iree_uk_mmt4d_params_t* params) { |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t K = params->K; |
| IREE_UK_ASSERT(params->N0 == 8 && params->K0 == 4); |
| IREE_UK_ASSERT(M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1); |
| IREE_UK_ASSERT(K == 1 || K == 2 || K == 4 || (K % 8 == 0)); |
| |
| cmds.conv.mode = 0; |
| cmds.conv.start = 0; |
| cmds.conv.stop = K > 8 ? 7 : K - 1; |
| cmds.conv.sbias1 = 0; |
| cmds.conv.sdata1 = true; |
| cmds.conv.sbias2 = 0; |
| cmds.conv.sdata2 = true; |
| |
| const iree_uk_int8_t* p_in = lhs; |
| const iree_uk_int8_t* p_flt = rhs; |
| iree_uk_int32_t* p_out = out; |
| |
| // load initial output tile values |
| if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) { |
| if (M0 <= 2) { |
| vld_w_p_x(v56, p_out); |
| if (M0 == 2) { |
| vld_w_p_x(v57, p_out); |
| } |
| } else { |
| vld_w_p_x_m(v56, p_out); |
| } |
| if (M0 == 8) { |
| vld_w_p_x_m(v60, p_out); |
| } |
| p_out = out; |
| } else { |
| vdup_w_x_m(v56, 0); |
| vdup_w_x_m(v60, 0); |
| } |
| |
| iree_uk_int16_t k = K; |
| while (k > 0) { |
| k -= 8; |
| // load LHS and RHS |
| if (K <= 2) { |
| vld_b_lp_xx(v0, p_in, 4 * M0); |
| vld_b_p_x(v8, p_flt); |
| if (K == 2) { |
| vld_b_lp_xx(v1, p_in, 4 * M0); |
| vld_b_p_x(v9, p_flt); |
| } |
| } else { |
| vld_b_sp_xx_m(v0, p_in, 4 * M0); |
| vld_b_p_x_m(v8, p_flt); |
| } |
| if (K >= 8) { |
| vld_b_sp_xx_m(v4, p_in, 4 * M0); |
| vld_b_p_x_m(v12, p_flt); |
| } |
| |
| // re-arrange LHS |
| vzip_w_vv_m(v20, v0, v4); |
| vzip_w_vv_m(v28, v20, v24); |
| vzip_w_vv_m(v0, v28, v32); |
| |
| // matrix multiplication |
| aconv_vxv(v48, v0, cmds, v8); |
| } |
| |
| vcget(v48); |
| |
| // de-interleaving |
| vzip_w_vv(v20, v48, v49); |
| vzip_w_vv(v22, v50, v51); |
| vzip_w_vv(v24, v20, v22); |
| vzip_w_vv(v26, v21, v23); |
| |
| if (M0 <= 2) { |
| vadd_w_vv(v56, v56, v24); |
| vst_w_p_x(v56, p_out); |
| if (M0 == 2) { |
| vadd_w_vv(v57, v57, v25); |
| vst_w_p_x(v57, p_out); |
| } |
| } else { |
| vadd_w_vv_m(v56, v56, v24); |
| vst_w_p_x_m(v56, p_out); |
| } |
| |
| if (M0 == 8) { |
| // de-interleaving |
| vzip_w_vv(v28, v52, v53); |
| vzip_w_vv(v30, v54, v55); |
| vzip_w_vv(v32, v28, v30); |
| vzip_w_vv(v34, v29, v31); |
| |
| vadd_w_vv_m(v60, v60, v32); |
| vst_w_p_x_m(v60, p_out); |
| } |
| } |
| |
| // Matrix multiplication s16xs8->s32. |
| // lhs (row-major): 8/4/2/1x4, rhs (column-major): 4x8 |
| // out (row-major): 8/4/2/1x8 |
| // s16 * s8 = (s8_hi * s8) << 8 + (u8_lo * s8) |
| // where s8_hi = s16[15:8]; u8_lo = s16[7:0] |
| static void matmul_s16s8(const iree_uk_int16_t* lhs, const iree_uk_int8_t* rhs, |
| iree_uk_int32_t* out, |
| const iree_uk_mmt4d_params_t* params) { |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t K = params->K; |
| IREE_UK_ASSERT(params->N0 == 8 && params->K0 == 4); |
| IREE_UK_ASSERT(M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1); |
| IREE_UK_ASSERT(K == 1 || K == 2 || K == 4 || (K % 8 == 0)); |
| |
| cmds.conv.mode = 0; |
| cmds.conv.start = 0; |
| cmds.conv.stop = K > 8 ? 7 : K - 1; |
| cmds.conv.sbias1 = 0; |
| cmds.conv.sdata1 = true; |
| cmds.conv.sbias2 = 0; |
| cmds.conv.sdata2 = true; |
| |
| const iree_uk_int16_t* p_in = lhs; |
| const iree_uk_int8_t* p_flt = rhs; |
| iree_uk_int32_t* p_out = out; |
| |
| // load initial output tile values |
| if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) { |
| if (M0 <= 2) { |
| vld_w_p_x(v56, p_out); |
| if (M0 == 2) { |
| vld_w_p_x(v57, p_out); |
| } |
| } else { |
| vld_w_p_x_m(v56, p_out); |
| } |
| if (M0 == 8) { |
| vld_w_p_x_m(v60, p_out); |
| } |
| } else { |
| vdup_w_x_m(v56, 0); |
| vdup_w_x_m(v60, 0); |
| } |
| |
| // first pass - high |
| iree_uk_int16_t k = K; |
| while (k > 0) { |
| k -= 8; |
| // load LHS high bits and RHS |
| if (K == 1) { |
| vld_h_lp_xx(v12, p_in, 2 * M0); |
| vld_h_lp_xx(v13, p_in, 2 * M0); |
| vodd_b_vv(v0, v12, v13); |
| vld_b_p_x(v8, p_flt); |
| } else if (K == 2) { |
| vld_h_sp_xx_m(v12, p_in, 2 * M0); |
| vodd_b_vv(v0, v12, v13); |
| vodd_b_vv(v1, v14, v15); |
| vld_b_p_x(v8, p_flt); |
| vld_b_p_x(v9, p_flt); |
| } |
| if (K >= 4) { |
| vld_h_sp_xx_m(v12, p_in, 2 * M0); |
| vld_h_sp_xx_m(v16, p_in, 2 * M0); |
| vodd_b_vv_m(v0, v12, v16); |
| vld_b_p_x_m(v8, p_flt); |
| } |
| if (K >= 8) { |
| vld_h_sp_xx_m(v20, p_in, 2 * M0); |
| vld_h_sp_xx_m(v24, p_in, 2 * M0); |
| vodd_b_vv_m(v4, v20, v24); |
| vld_b_p_x_m(v12, p_flt); |
| } |
| |
| // re-arrange LHS |
| vzip_w_vv_m(v20, v0, v4); |
| vzip_w_vv_m(v28, v20, v24); |
| vzip_w_vv_m(v0, v28, v32); |
| |
| // matrix multiplication |
| aconv_vxv(v48, v0, cmds, v8); |
| } |
| |
| vcget(v48); |
| |
| // left shift 8 bits and store to accumulator |
| vsll_w_vx_m(v48, v48, 8); |
| vsll_w_vx_m(v52, v52, 8); |
| acset_v(v48, v48); |
| |
| // second pass - low |
| cmds.conv.sdata1 = false; |
| p_in = lhs; |
| p_flt = rhs; |
| p_out = out; |
| k = K; |
| while (k > 0) { |
| k -= 8; |
| // load LHS low bits and RHS |
| if (K == 1) { |
| vld_h_lp_xx(v12, p_in, 2 * M0); |
| vld_h_lp_xx(v13, p_in, 2 * M0); |
| vevn_b_vv(v0, v12, v13); |
| vld_b_p_x(v8, p_flt); |
| } else if (K == 2) { |
| vld_h_sp_xx_m(v12, p_in, 2 * M0); |
| vevn_b_vv(v0, v12, v13); |
| vevn_b_vv(v1, v14, v15); |
| vld_b_p_x(v8, p_flt); |
| vld_b_p_x(v9, p_flt); |
| } |
| if (K >= 4) { |
| vld_h_sp_xx_m(v12, p_in, 2 * M0); |
| vld_h_sp_xx_m(v16, p_in, 2 * M0); |
| vevn_b_vv_m(v0, v12, v16); |
| vld_b_p_x_m(v8, p_flt); |
| } |
| if (K >= 8) { |
| vld_h_sp_xx_m(v20, p_in, 2 * M0); |
| vld_h_sp_xx_m(v24, p_in, 2 * M0); |
| vevn_b_vv_m(v4, v20, v24); |
| vld_b_p_x_m(v12, p_flt); |
| } |
| |
| // re-arrange LHS |
| vzip_w_vv_m(v20, v0, v4); |
| vzip_w_vv_m(v28, v20, v24); |
| vzip_w_vv_m(v0, v28, v32); |
| |
| // matrix multiplication |
| aconv_vxv(v48, v0, cmds, v8); |
| } |
| |
| vcget(v48); |
| |
| // de-interleaving |
| vzip_w_vv(v20, v48, v49); |
| vzip_w_vv(v22, v50, v51); |
| vzip_w_vv(v24, v20, v22); |
| vzip_w_vv(v26, v21, v23); |
| |
| if (M0 <= 2) { |
| vadd_w_vv(v56, v56, v24); |
| vst_w_p_x(v56, p_out); |
| if (M0 == 2) { |
| vadd_w_vv(v57, v57, v25); |
| vst_w_p_x(v57, p_out); |
| } |
| } else { |
| vadd_w_vv_m(v56, v56, v24); |
| vst_w_p_x_m(v56, p_out); |
| } |
| |
| if (M0 == 8) { |
| // de-interleaving |
| vzip_w_vv(v28, v52, v53); |
| vzip_w_vv(v30, v54, v55); |
| vzip_w_vv(v32, v28, v30); |
| vzip_w_vv(v34, v29, v31); |
| |
| vadd_w_vv_m(v60, v60, v32); |
| vst_w_p_x_m(v60, p_out); |
| } |
| } |
| |
| #else // RVV implementation in Springbok |
| |
| #include <riscv_vector.h> |
| // Calculate the dot product of two int8 vectors using RVV |
| static iree_uk_int32_t dot_product_s8s8(const iree_uk_int8_t* u, |
| const iree_uk_int8_t* w, int n) { |
| size_t vl; |
| // auxiliary variables |
| vint8m4_t vu, vw; |
| vint16m8_t vx; |
| vint32m1_t v_sum; |
| iree_uk_int32_t sum = 0; |
| for (size_t i = 0; i < n; i += vl) { |
| vl = __riscv_vsetvl_e8m4(n - i); |
| vu = __riscv_vle8_v_i8m4(u + i, vl); // load |
| vw = __riscv_vle8_v_i8m4(w + i, vl); // load |
| vx = __riscv_vwmul(vu, vw, vl); // multiply |
| v_sum = __riscv_vmv_v_x_i32m1(0, vl); // init |
| v_sum = __riscv_vwredsum(vx, v_sum, vl); // sum |
| sum += __riscv_vmv_x(v_sum); |
| } |
| return sum; |
| } |
| |
| // Calculate the dot product of int16 x int8 vectors using RVV |
| static iree_uk_int32_t dot_product_s16s8(const iree_uk_int16_t* u, |
| const iree_uk_int8_t* w, int n) { |
| size_t vl; |
| // auxiliary variables |
| vint16m4_t vu, vw; |
| vint8m2_t vy; |
| vint32m8_t vx; |
| vint32m1_t v_sum; |
| iree_uk_int32_t sum = 0; |
| for (size_t i = 0; i < n; i += vl) { |
| vl = __riscv_vsetvl_e16m4(n - i); |
| vu = __riscv_vle16_v_i16m4(u + i, vl); // load |
| vy = __riscv_vle8_v_i8m2(w + i, vl); // load |
| vw = __riscv_vwadd_vx(vy, 0, vl); // widen |
| vx = __riscv_vwmul(vu, vw, vl); // multiply |
| v_sum = __riscv_vmv_v_x_i32m1(0, vl); // init |
| v_sum = __riscv_vredsum(vx, v_sum, vl); // sum |
| sum += __riscv_vmv_x(v_sum); |
| } |
| return sum; |
| } |
| |
| #endif // # ifdef BUILD_KELVIN |
| |
| // RVV implementation of matmul tile, s8*s8->s32 case. |
| static void iree_uk_mmt4d_tile_s8s8s32_custom( |
| void* out_tile_untyped, const void* lhs_panel_untyped, |
| const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { |
| iree_uk_int32_t* out_tile = out_tile_untyped; |
| const iree_uk_int8_t* lhs_panel = lhs_panel_untyped; |
| const iree_uk_int8_t* rhs_panel = rhs_panel_untyped; |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t N0 = params->N0; |
| iree_uk_int16_t K0 = params->K0; |
| |
| #ifdef BUILD_KELVIN |
| if (N0 == 8 && K0 == 4 && (M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1)) { |
| matmul_s8s8(lhs_panel, rhs_panel, out_tile, params); |
| return; |
| } |
| #endif // # ifdef BUILD_KELVIN |
| |
| // Initialize the accumulator tile. |
| if (!(params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) { |
| for (iree_uk_int32_t i = 0; i < M0 * N0; i++) out_tile[i] = 0; |
| } |
| |
| // Accumulation loop. |
| for (iree_uk_index_t k = 0; k < params->K; ++k) { |
| for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
| out_tile[i0 * N0 + j0] += |
| dot_product_s8s8(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0); |
| } |
| } |
| lhs_panel += M0 * K0; |
| rhs_panel += N0 * K0; |
| } |
| } |
| |
| // implementation of matmul tile, s16*s8->s32 case. |
| static void iree_uk_mmt4d_tile_s16s8s32_custom( |
| void* out_tile_untyped, const void* lhs_panel_untyped, |
| const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { |
| iree_uk_int32_t* out_tile = out_tile_untyped; |
| const iree_uk_int16_t* lhs_panel = lhs_panel_untyped; |
| const iree_uk_int8_t* rhs_panel = rhs_panel_untyped; |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t N0 = params->N0; |
| iree_uk_int16_t K0 = params->K0; |
| |
| #ifdef BUILD_KELVIN |
| if (N0 == 8 && K0 == 4 && (M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1)) { |
| matmul_s16s8(lhs_panel, rhs_panel, out_tile, params); |
| return; |
| } |
| #endif // # ifdef BUILD_KELVIN |
| |
| // Initialize the accumulator tile. |
| if (!(params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) { |
| for (iree_uk_int32_t i = 0; i < M0 * N0; i++) out_tile[i] = 0; |
| } |
| |
| // Accumulation loop. |
| for (iree_uk_index_t k = 0; k < params->K; ++k) { |
| for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
| out_tile[i0 * N0 + j0] += |
| dot_product_s16s8(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0); |
| } |
| } |
| lhs_panel += M0 * K0; |
| rhs_panel += N0 * K0; |
| } |
| } |
| |
| // Generic implementation of matmul tile, f32*f32->f32 case. |
| static void iree_uk_mmt4d_tile_f32f32f32_generic( |
| void* out_tile_untyped, const void* lhs_panel_untyped, |
| const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { |
| float* out_tile = out_tile_untyped; |
| const float* lhs_panel = lhs_panel_untyped; |
| const float* rhs_panel = rhs_panel_untyped; |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t N0 = params->N0; |
| iree_uk_int16_t K0 = params->K0; |
| for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
| float acc = (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) |
| ? out_tile[i0 * N0 + j0] |
| : 0.f; |
| for (iree_uk_index_t k = 0; k < params->K; ++k) { |
| for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) { |
| float lhs_f32 = lhs_panel[k * M0 * K0 + i0 * K0 + k0]; |
| float rhs_f32 = rhs_panel[k * N0 * K0 + j0 * K0 + k0]; |
| acc += lhs_f32 * rhs_f32; |
| } |
| } |
| out_tile[i0 * N0 + j0] = acc; |
| } |
| } |
| } |
| |
| iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic( |
| const iree_uk_mmt4d_params_t* params) { |
| switch (iree_uk_mmt4d_type(params->flags)) { |
| case iree_uk_mmt4d_type_s8s8s32: |
| return iree_uk_mmt4d_tile_s8s8s32_custom; |
| case iree_uk_mmt4d_type_s16s8s32: |
| return iree_uk_mmt4d_tile_s16s8s32_custom; |
| case iree_uk_mmt4d_type_f32f32f32: |
| return iree_uk_mmt4d_tile_f32f32f32_generic; |
| default: |
| // Shouldn't happen, validated earlier. |
| return 0; |
| } |
| } |