| /* |
| * Copyright 2023 Google LLC |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include <riscv_vector.h> |
| #include <string.h> |
| |
| #include "iree/builtins/ukernel/mmt4d_internal.h" |
| |
| // Calculate the dot product of two int8 vectors using RVV |
| static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u, |
| const iree_uk_int8_t* w, int n) { |
| size_t vl; |
| // auxiliary variables |
| vint8m4_t vu, vw; |
| vint16m8_t vx; |
| vint32m1_t v_sum; |
| iree_uk_int32_t sum = 0; |
| for (size_t i = 0; i < n; i += vl) { |
| vl = __riscv_vsetvl_e8m4(n - i); |
| vu = __riscv_vle8_v_i8m4(u + i, vl); // load |
| vw = __riscv_vle8_v_i8m4(w + i, vl); // load |
| vx = __riscv_vwmul(vu, vw, vl); // multiply |
| v_sum = __riscv_vmv_v_x_i32m1(0, vl); // init |
| v_sum = __riscv_vwredsum(vx, v_sum, vl); // sum |
| sum += __riscv_vmv_x(v_sum); |
| } |
| return sum; |
| } |
| |
| // RVV implementation of matmul tile, i8*i8->i32 case. |
| static void iree_uk_mmt4d_tile_s8s8s32_rvv( |
| void* out_tile_untyped, const void* lhs_panel_untyped, |
| const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { |
| iree_uk_int32_t* out_tile = out_tile_untyped; |
| const iree_uk_int8_t* lhs_panel = lhs_panel_untyped; |
| const iree_uk_int8_t* rhs_panel = rhs_panel_untyped; |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t N0 = params->N0; |
| iree_uk_int16_t K0 = params->K0; |
| // Initialize the accumulator tile. |
| if (!(params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) { |
| memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t)); |
| } |
| // Accumulation loop. |
| for (iree_uk_index_t k = 0; k < params->K; ++k) { |
| for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
| out_tile[i0 * N0 + j0] += |
| dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0); |
| } |
| } |
| lhs_panel += M0 * K0; |
| rhs_panel += N0 * K0; |
| } |
| } |
| |
| // Generic implementation of matmul tile, f32*f32->f32 case. |
| static void iree_uk_mmt4d_tile_f32f32f32_generic( |
| void* out_tile_untyped, const void* lhs_panel_untyped, |
| const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { |
| float* out_tile = out_tile_untyped; |
| const float* lhs_panel = lhs_panel_untyped; |
| const float* rhs_panel = rhs_panel_untyped; |
| iree_uk_int16_t M0 = params->M0; |
| iree_uk_int16_t N0 = params->N0; |
| iree_uk_int16_t K0 = params->K0; |
| // Initialize the local accumulator tile. |
| float acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; |
| if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) { |
| for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; |
| } else { |
| for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; |
| } |
| // Accumulation loop. |
| for (iree_uk_index_t k = 0; k < params->K; ++k) { |
| for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
| for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) { |
| float lhs_val = lhs_panel[i0 * K0 + k0]; |
| float rhs_val = rhs_panel[j0 * K0 + k0]; |
| acc[i0 * N0 + j0] += lhs_val * rhs_val; |
| } |
| } |
| } |
| lhs_panel += M0 * K0; |
| rhs_panel += N0 * K0; |
| } |
| // Store the local accumulator tile to the destination. |
| for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; |
| } |
| |
| iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func( |
| const iree_uk_mmt4d_params_t* params) { |
| // TODO(lundong): to be replaced with Kelvin |
| switch (iree_uk_mmt4d_type(params->flags)) { |
| case iree_uk_mmt4d_type_f32f32f32: |
| return iree_uk_mmt4d_tile_f32f32f32_generic; |
| case iree_uk_mmt4d_type_s8s8s32: |
| return iree_uk_mmt4d_tile_s8s8s32_rvv; |
| default: |
| // shouldn't happen, validated earlier. |
| IREE_UK_ASSUME_UNREACHABLE; |
| return 0; |
| } |
| } |