blob: 92e19ce478feeec6dd1b1c5a046cfc2df04cb877 [file] [log] [blame]
/*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "iree/builtins/ukernel/arch/mmt4d_arch.h"
#include <riscv_vector.h>
#include <string.h>
// Computes the dot product of two int8 vectors using RVV.
//
// `u` and `w` each point to at least `n` readable int8 elements. The inputs
// are processed in strips of `vl` elements (e8, LMUL=4): each strip is
// multiplied with widening to int16 and folded into a single int32 running
// total by a widening reduction sum.
//
// Returns the int32 dot product, or 0 when `n` is not positive.
static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u,
                                       const iree_uk_int8_t* w, int n) {
  // Reject non-positive lengths up front: comparing a negative `n` against a
  // size_t index would convert it to a huge unsigned element count.
  if (n <= 0) return 0;
  const size_t count = (size_t)n;

  // Element 0 carries the running int32 total. It is initialized explicitly
  // instead of passing a not-yet-assigned vector as a destination operand.
  vint32m1_t v_sum = vmv_v_x_i32m1(0, 1);

  size_t vl;
  for (size_t i = 0; i < count; i += vl) {
    vl = vsetvl_e8m4(count - i);             // elements handled this strip
    vint8m4_t vu = vle8_v_i8m4(u + i, vl);   // load strip of u
    vint8m4_t vw = vle8_v_i8m4(w + i, vl);   // load strip of w
    vint16m8_t vx = vwmul(vu, vw, vl);       // int8 * int8 -> int16
    // v_sum[0] = v_sum[0] + sum(vx[0 .. vl-1]). Feeding the previous total
    // back in as the scalar operand keeps the accumulation in the vector
    // register, so only one vector-to-scalar move is needed after the loop.
    v_sum = vwredsum(v_sum, vx, v_sum, vl);
  }
  return vmv_x(v_sum);
}
// RVV matmul tile kernel for the i8 * i8 -> i32 case.
//
// The output tile is treated as an M0 x N0 row-major array of int32. Unless
// IREE_UK_FLAG_ACCUMULATE is set in `flags`, it is zeroed first. Then, for
// each of the K steps, entry (row, col) receives the dot product of the
// K0-element row `row` of the LHS panel and the K0-element row `col` of the
// RHS panel; the panels advance by M0 * K0 and N0 * K0 elements per step.
static void iree_uk_mmt4d_tile_i8i8i32_rvv(
    void* out_tile_untyped, const void* lhs_panel_untyped,
    const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags,
    const iree_uk_mmt4d_params_t* params) {
  const iree_uk_int16_t M0 = params->M0;
  const iree_uk_int16_t N0 = params->N0;
  const iree_uk_int16_t K0 = params->K0;

  iree_uk_int32_t* acc = out_tile_untyped;
  const iree_uk_int8_t* lhs = lhs_panel_untyped;
  const iree_uk_int8_t* rhs = rhs_panel_untyped;

  // Start from a zero tile when the caller did not ask to accumulate.
  if ((flags & IREE_UK_FLAG_ACCUMULATE) == 0) {
    memset(acc, 0, M0 * N0 * sizeof(iree_uk_int32_t));
  }

  // One pass per K step; both panels move forward at the end of each pass.
  for (iree_uk_ssize_t step = 0; step < K; ++step) {
    for (iree_uk_ssize_t row = 0; row < M0; ++row) {
      const iree_uk_int8_t* lhs_row = lhs + row * K0;
      iree_uk_int32_t* acc_row = acc + row * N0;
      for (iree_uk_ssize_t col = 0; col < N0; ++col) {
        acc_row[col] += dot_product_rvv(lhs_row, rhs + col * K0, K0);
      }
    }
    lhs += M0 * K0;
    rhs += N0 * K0;
  }
}
// Selects the architecture-specific mmt4d tile function for `params`.
//
// Returns the RVV i8 * i8 -> i32 kernel when `params->type` is
// iree_uk_mmt4d_type_i8i8i32, and 0 (no function) for every other type.
// NOTE(review): presumably the caller falls back to a generic kernel on 0 --
// confirm against the call site.
iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
    const iree_uk_mmt4d_params_t* params) {
  // TODO(lundong): to be replaced with Kelvin
  if (params->type != iree_uk_mmt4d_type_i8i8i32) {
    return 0;
  }
  return iree_uk_mmt4d_tile_i8i8i32_rvv;
}