/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "iree/builtins/ukernel/mmt4d_internal.h"
#ifdef BUILD_KELVIN
#include <kelvin.h>
static_assert(sizeof(struct vconv_u8_t) == 4);
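// Build the aconv command word field-by-field and hand it to aconv_vxv as
// a single 32-bit value (the static_assert above guards the size).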
union {
struct vconv_u8_t conv;
uint32_t raw;
} cmds;
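// Scratch area for spilling vector lanes before the scalar reduction.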
static iree_uk_int32_t buffer[8] __attribute__((aligned));
// dot product s8xs8->s32
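// The loop widens both s8 operands to s16 (vaddw + vzip), does a widening
// s16->s32 multiply, and accumulates per lane in v0; the scalar tail then
// reduces the spilled lanes. Scalar reference of the intended result (a
// sketch for readability, not part of the build):
//   iree_uk_int32_t sum = 0;
//   for (int i = 0; i < n; ++i) sum += (iree_uk_int32_t)u[i] * w[i];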
static iree_uk_int32_t dot_product_s8s8(const iree_uk_int8_t* u,
const iree_uk_int8_t* w, int n) {
int vl;
getmaxvl_w(vl);
vdup_w_x_m(v0, 0);
while (n) {
int count = n < vl ? n : vl;
vld_b_lp_xx(v8, u, count);
vaddw_h_vx(v8, v8, 0);
vzip_h_vv(v10, v8, v9);
vld_b_lp_xx(v16, w, count);
vaddw_h_vx(v16, v16, 0);
vzip_h_vv(v18, v16, v17);
vmulw_w_vv(v4, v10, v18);
vzip_w_vv(v1, v4, v5);
vadd_w_vv(v0, v0, v1);
n -= count;
}
vst_w_l_xx(v0, buffer, vl);
iree_uk_int32_t sum = 0;
for (int i = 0; i < vl; ++i) sum += buffer[i];
return sum;
}
// dot product s16xs8->s32
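// Same widen-multiply-accumulate pattern as dot_product_s8s8, except the
// s16 lhs is loaded directly as halfwords and only the s8 rhs is widened.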
static iree_uk_int32_t dot_product_s16s8(const iree_uk_int16_t* u,
const iree_uk_int8_t* w, int n) {
int vl;
getmaxvl_w(vl);
vdup_w_x_m(v0, 0);
while (n) {
int count = n < vl ? n : vl;
vld_h_lp_xx(v8, u, count);
vld_b_lp_xx(v16, w, count);
vaddw_h_vx(v16, v16, 0);
vzip_h_vv(v18, v16, v17);
vmulw_w_vv(v4, v8, v18);
vzip_w_vv(v1, v4, v5);
vadd_w_vv(v0, v0, v1);
n -= count;
}
vst_w_l_xx(v0, buffer, vl);
iree_uk_int32_t sum = 0;
for (int i = 0; i < vl; ++i) sum += buffer[i];
return sum;
}
// Matrix multiplication s8xs8->s32.
// lhs (row-major): 8/4/2/1x4, rhs (column-major): 4x8
// out (row-major): 8/4/2/1x8
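// For M0 = 8 this produces an 8x8 output tile: per K step, out[i][j]
// accumulates dot(lhs row i, rhs column j). Scalar reference for one K step
// (a sketch for readability, not part of the build):
//   for (int i = 0; i < M0; ++i)
//     for (int j = 0; j < 8; ++j)
//       for (int k0 = 0; k0 < 4; ++k0)
//         out[i * 8 + j] += lhs[i * 4 + k0] * rhs[j * 4 + k0];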
static void matmul_s8s8(const iree_uk_int8_t* lhs, const iree_uk_int8_t* rhs,
iree_uk_int32_t* out,
const iree_uk_mmt4d_params_t* params) {
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t K = params->K;
IREE_UK_ASSERT(params->N0 == 8 && params->K0 == 4);
IREE_UK_ASSERT(M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1);
IREE_UK_ASSERT(K == 1 || K == 2 || K == 4 || (K % 8 == 0));
cmds.conv.mode = 0;
cmds.conv.start = 0;
cmds.conv.stop = K > 8 ? 7 : K - 1;
cmds.conv.sbias1 = 0;
cmds.conv.sdata1 = true;
cmds.conv.sbias2 = 0;
cmds.conv.sdata2 = true;
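// Inferred from the loop below (which consumes up to eight K steps per
// aconv issue): `stop` selects the last accumulation step, 7 for K >= 8 and
// K - 1 otherwise; sdata1/sdata2 mark both operands as signed s8.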
const iree_uk_int8_t* p_in = lhs;
const iree_uk_int8_t* p_flt = rhs;
iree_uk_int32_t* p_out = out;
// load initial output tile values
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
if (M0 <= 2) {
vld_w_p_x(v56, p_out);
if (M0 == 2) {
vld_w_p_x(v57, p_out);
}
} else {
vld_w_p_x_m(v56, p_out);
}
if (M0 == 8) {
vld_w_p_x_m(v60, p_out);
}
p_out = out;
} else {
vdup_w_x_m(v56, 0);
vdup_w_x_m(v60, 0);
}
iree_uk_int16_t k = K;
while (k > 0) {
k -= 8;
// load LHS and RHS
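// Staged by K: single-register loads cover K <= 2, while the _m forms fill
// four-register groups, so a full K >= 8 iteration appears to stage eight
// 4-byte K steps (lhs in v0..v7, rhs in v8..v15) per aconv issue.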
if (K <= 2) {
vld_b_lp_xx(v0, p_in, 4 * M0);
vld_b_p_x(v8, p_flt);
if (K == 2) {
vld_b_lp_xx(v1, p_in, 4 * M0);
vld_b_p_x(v9, p_flt);
}
} else {
vld_b_sp_xx_m(v0, p_in, 4 * M0);
vld_b_p_x_m(v8, p_flt);
}
if (K >= 8) {
vld_b_sp_xx_m(v4, p_in, 4 * M0);
vld_b_p_x_m(v12, p_flt);
}
// re-arrange LHS
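// The three vzip_w stages below interleave the loaded register groups into
// the lhs operand order aconv expects, acting as an in-register transpose
// of the tile (inferred from the instruction sequence).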
vzip_w_vv_m(v20, v0, v4);
vzip_w_vv_m(v28, v20, v24);
vzip_w_vv_m(v0, v28, v32);
// matrix multiplication
aconv_vxv(v48, v0, cmds, v8);
}
vcget(v48);
// de-interleaving
vzip_w_vv(v20, v48, v49);
vzip_w_vv(v22, v50, v51);
vzip_w_vv(v24, v20, v22);
vzip_w_vv(v26, v21, v23);
if (M0 <= 2) {
vadd_w_vv(v56, v56, v24);
vst_w_p_x(v56, p_out);
if (M0 == 2) {
vadd_w_vv(v57, v57, v25);
vst_w_p_x(v57, p_out);
}
} else {
vadd_w_vv_m(v56, v56, v24);
vst_w_p_x_m(v56, p_out);
}
if (M0 == 8) {
// de-interleaving
vzip_w_vv(v28, v52, v53);
vzip_w_vv(v30, v54, v55);
vzip_w_vv(v32, v28, v30);
vzip_w_vv(v34, v29, v31);
vadd_w_vv_m(v60, v60, v32);
vst_w_p_x_m(v60, p_out);
}
}
// Matrix multiplication s16xs8->s32.
// lhs (row-major): 8/4/2/1x4, rhs (column-major): 4x8
// out (row-major): 8/4/2/1x8
// s16 * s8 = ((s8_hi * s8) << 8) + (u8_lo * s8)
// where s8_hi = s16[15:8] (sign-extended); u8_lo = s16[7:0] (zero-extended)
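// Worked example: the s16 value 0xFF9C (-100) splits into s8_hi = 0xFF (-1)
// and u8_lo = 0x9C (156); for any s8 value w, ((-1 * w) << 8) + 156 * w =
// -256 * w + 156 * w = -100 * w, as required. The high byte is sign-extended
// and the low byte zero-extended, which is why the second pass below clears
// cmds.conv.sdata1.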
static void matmul_s16s8(const iree_uk_int16_t* lhs, const iree_uk_int8_t* rhs,
iree_uk_int32_t* out,
const iree_uk_mmt4d_params_t* params) {
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t K = params->K;
IREE_UK_ASSERT(params->N0 == 8 && params->K0 == 4);
IREE_UK_ASSERT(M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1);
IREE_UK_ASSERT(K == 1 || K == 2 || K == 4 || (K % 8 == 0));
cmds.conv.mode = 0;
cmds.conv.start = 0;
cmds.conv.stop = K > 8 ? 7 : K - 1;
cmds.conv.sbias1 = 0;
cmds.conv.sdata1 = true;
cmds.conv.sbias2 = 0;
cmds.conv.sdata2 = true;
const iree_uk_int16_t* p_in = lhs;
const iree_uk_int8_t* p_flt = rhs;
iree_uk_int32_t* p_out = out;
// load initial output tile values
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
if (M0 <= 2) {
vld_w_p_x(v56, p_out);
if (M0 == 2) {
vld_w_p_x(v57, p_out);
}
} else {
vld_w_p_x_m(v56, p_out);
}
if (M0 == 8) {
vld_w_p_x_m(v60, p_out);
}
} else {
vdup_w_x_m(v56, 0);
vdup_w_x_m(v60, 0);
}
// first pass - high
iree_uk_int16_t k = K;
while (k > 0) {
k -= 8;
// load LHS high bits and RHS
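// On this little-endian target the odd-indexed bytes of each halfword are
// the high bytes, so vodd_b_vv extracts s8_hi here (and vevn_b_vv extracts
// u8_lo in the second pass).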
if (K == 1) {
vld_h_lp_xx(v12, p_in, 2 * M0);
vld_h_lp_xx(v13, p_in, 2 * M0);
vodd_b_vv(v0, v12, v13);
vld_b_p_x(v8, p_flt);
} else if (K == 2) {
vld_h_sp_xx_m(v12, p_in, 2 * M0);
vodd_b_vv(v0, v12, v13);
vodd_b_vv(v1, v14, v15);
vld_b_p_x(v8, p_flt);
vld_b_p_x(v9, p_flt);
}
if (K >= 4) {
vld_h_sp_xx_m(v12, p_in, 2 * M0);
vld_h_sp_xx_m(v16, p_in, 2 * M0);
vodd_b_vv_m(v0, v12, v16);
vld_b_p_x_m(v8, p_flt);
}
if (K >= 8) {
vld_h_sp_xx_m(v20, p_in, 2 * M0);
vld_h_sp_xx_m(v24, p_in, 2 * M0);
vodd_b_vv_m(v4, v20, v24);
vld_b_p_x_m(v12, p_flt);
}
// re-arrange LHS
vzip_w_vv_m(v20, v0, v4);
vzip_w_vv_m(v28, v20, v24);
vzip_w_vv_m(v0, v28, v32);
// matrix multiplication
aconv_vxv(v48, v0, cmds, v8);
}
vcget(v48);
// left shift 8 bits and store to accumulator
vsll_w_vx_m(v48, v48, 8);
vsll_w_vx_m(v52, v52, 8);
acset_v(v48, v48);
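// The first pass accumulated sum(s8_hi * rhs); shifting those partials left
// by 8 scales them by 256, and acset_v seeds the convolution accumulator
// with them so the second (low-byte, unsigned) pass adds on top.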
// second pass - low
cmds.conv.sdata1 = false;
p_in = lhs;
p_flt = rhs;
p_out = out;
k = K;
while (k > 0) {
k -= 8;
// load LHS low bits and RHS
if (K == 1) {
vld_h_lp_xx(v12, p_in, 2 * M0);
vld_h_lp_xx(v13, p_in, 2 * M0);
vevn_b_vv(v0, v12, v13);
vld_b_p_x(v8, p_flt);
} else if (K == 2) {
vld_h_sp_xx_m(v12, p_in, 2 * M0);
vevn_b_vv(v0, v12, v13);
vevn_b_vv(v1, v14, v15);
vld_b_p_x(v8, p_flt);
vld_b_p_x(v9, p_flt);
}
if (K >= 4) {
vld_h_sp_xx_m(v12, p_in, 2 * M0);
vld_h_sp_xx_m(v16, p_in, 2 * M0);
vevn_b_vv_m(v0, v12, v16);
vld_b_p_x_m(v8, p_flt);
}
if (K >= 8) {
vld_h_sp_xx_m(v20, p_in, 2 * M0);
vld_h_sp_xx_m(v24, p_in, 2 * M0);
vevn_b_vv_m(v4, v20, v24);
vld_b_p_x_m(v12, p_flt);
}
// re-arrange LHS
vzip_w_vv_m(v20, v0, v4);
vzip_w_vv_m(v28, v20, v24);
vzip_w_vv_m(v0, v28, v32);
// matrix multiplication
aconv_vxv(v48, v0, cmds, v8);
}
vcget(v48);
// de-interleaving
vzip_w_vv(v20, v48, v49);
vzip_w_vv(v22, v50, v51);
vzip_w_vv(v24, v20, v22);
vzip_w_vv(v26, v21, v23);
if (M0 <= 2) {
vadd_w_vv(v56, v56, v24);
vst_w_p_x(v56, p_out);
if (M0 == 2) {
vadd_w_vv(v57, v57, v25);
vst_w_p_x(v57, p_out);
}
} else {
vadd_w_vv_m(v56, v56, v24);
vst_w_p_x_m(v56, p_out);
}
if (M0 == 8) {
// de-interleaving
vzip_w_vv(v28, v52, v53);
vzip_w_vv(v30, v54, v55);
vzip_w_vv(v32, v28, v30);
vzip_w_vv(v34, v29, v31);
vadd_w_vv_m(v60, v60, v32);
vst_w_p_x_m(v60, p_out);
}
}
#else // RVV implementation in Springbok
#include <riscv_vector.h>
// Calculate the dot product of two int8 vectors using RVV
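// LMUL note: e8m4 inputs are the widest that fit, since vwmul widens them
// to a vint16m8_t product; vwredsum then accumulates the s16 partials into
// a single s32 element.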
static iree_uk_int32_t dot_product_s8s8(const iree_uk_int8_t* u,
const iree_uk_int8_t* w, int n) {
size_t vl;
// auxiliary variables
vint8m4_t vu, vw;
vint16m8_t vx;
vint32m1_t v_sum;
iree_uk_int32_t sum = 0;
for (size_t i = 0; i < (size_t)n; i += vl) {
vl = __riscv_vsetvl_e8m4(n - i);
vu = __riscv_vle8_v_i8m4(u + i, vl); // load
vw = __riscv_vle8_v_i8m4(w + i, vl); // load
vx = __riscv_vwmul(vu, vw, vl); // multiply
v_sum = __riscv_vmv_v_x_i32m1(0, vl); // init
v_sum = __riscv_vwredsum(vx, v_sum, vl); // sum
sum += __riscv_vmv_x(v_sum);
}
return sum;
}
// Calculate the dot product of int16 x int8 vectors using RVV
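// The s8 operand is first widened to s16 via vwadd_vx(vy, 0, vl); the
// widening multiply then produces s32 products in a vint32m8_t, so a plain
// (non-widening) vredsum finishes the reduction.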
static iree_uk_int32_t dot_product_s16s8(const iree_uk_int16_t* u,
const iree_uk_int8_t* w, int n) {
size_t vl;
// auxiliary variables
vint16m4_t vu, vw;
vint8m2_t vy;
vint32m8_t vx;
vint32m1_t v_sum;
iree_uk_int32_t sum = 0;
for (size_t i = 0; i < (size_t)n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vu = __riscv_vle16_v_i16m4(u + i, vl); // load
vy = __riscv_vle8_v_i8m2(w + i, vl); // load
vw = __riscv_vwadd_vx(vy, 0, vl); // widen
vx = __riscv_vwmul(vu, vw, vl); // multiply
v_sum = __riscv_vmv_v_x_i32m1(0, vl); // init
v_sum = __riscv_vredsum(vx, v_sum, vl); // sum
sum += __riscv_vmv_x(v_sum);
}
return sum;
}
#endif  // BUILD_KELVIN
// Custom implementation of matmul tile, s8*s8->s32 case: dispatches to the
// Kelvin kernel when the tile shape matches, otherwise falls back to the
// dot-product loop below.
static void iree_uk_mmt4d_tile_s8s8s32_custom(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
iree_uk_int32_t* out_tile = out_tile_untyped;
const iree_uk_int8_t* lhs_panel = lhs_panel_untyped;
const iree_uk_int8_t* rhs_panel = rhs_panel_untyped;
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t N0 = params->N0;
iree_uk_int16_t K0 = params->K0;
#ifdef BUILD_KELVIN
if (N0 == 8 && K0 == 4 && (M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1)) {
matmul_s8s8(lhs_panel, rhs_panel, out_tile, params);
return;
}
#endif  // BUILD_KELVIN
// Initialize the accumulator tile.
if (!(params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) {
for (iree_uk_int32_t i = 0; i < M0 * N0; i++) out_tile[i] = 0;
}
// Accumulation loop.
for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
out_tile[i0 * N0 + j0] +=
dot_product_s8s8(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0);
}
}
lhs_panel += M0 * K0;
rhs_panel += N0 * K0;
}
}
// Custom implementation of matmul tile, s16*s8->s32 case.
static void iree_uk_mmt4d_tile_s16s8s32_custom(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
iree_uk_int32_t* out_tile = out_tile_untyped;
const iree_uk_int16_t* lhs_panel = lhs_panel_untyped;
const iree_uk_int8_t* rhs_panel = rhs_panel_untyped;
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t N0 = params->N0;
iree_uk_int16_t K0 = params->K0;
#ifdef BUILD_KELVIN
if (N0 == 8 && K0 == 4 && (M0 == 8 || M0 == 4 || M0 == 2 || M0 == 1)) {
matmul_s16s8(lhs_panel, rhs_panel, out_tile, params);
return;
}
#endif  // BUILD_KELVIN
// Initialize the accumulator tile.
if (!(params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) {
for (iree_uk_int32_t i = 0; i < M0 * N0; i++) out_tile[i] = 0;
}
// Accumulation loop.
for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
out_tile[i0 * N0 + j0] +=
dot_product_s16s8(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0);
}
}
lhs_panel += M0 * K0;
rhs_panel += N0 * K0;
}
}
// Generic implementation of matmul tile, f32*f32->f32 case.
static void iree_uk_mmt4d_tile_f32f32f32_generic(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
float* out_tile = out_tile_untyped;
const float* lhs_panel = lhs_panel_untyped;
const float* rhs_panel = rhs_panel_untyped;
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t N0 = params->N0;
iree_uk_int16_t K0 = params->K0;
for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
float acc = (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)
? out_tile[i0 * N0 + j0]
: 0.f;
for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
float lhs_f32 = lhs_panel[k * M0 * K0 + i0 * K0 + k0];
float rhs_f32 = rhs_panel[k * N0 * K0 + j0 * K0 + k0];
acc += lhs_f32 * rhs_f32;
}
}
out_tile[i0 * N0 + j0] = acc;
}
}
}
iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
const iree_uk_mmt4d_params_t* params) {
switch (iree_uk_mmt4d_type(params->flags)) {
case iree_uk_mmt4d_type_s8s8s32:
return iree_uk_mmt4d_tile_s8s8s32_custom;
case iree_uk_mmt4d_type_s16s8s32:
return iree_uk_mmt4d_tile_s16s8s32_custom;
case iree_uk_mmt4d_type_f32f32f32:
return iree_uk_mmt4d_tile_f32f32f32_generic;
default:
// Shouldn't happen, validated earlier.
return 0;
}
}