blob: 10105b66b72e192ff57329404fe8fe64a5a191b7 [file] [log] [blame]
/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Lun Dongfc805b32023-02-01 09:52:49 -080017#include <riscv_vector.h>
18#include <string.h>
19
Lun Dongf20445c2023-04-24 12:01:48 -070020#include "iree/builtins/ukernel/mmt4d_internal.h"
21
Lun Dongfc805b32023-02-01 09:52:49 -080022// Calculate the dot product of two int8 vectors using RVV
23static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u,
24 const iree_uk_int8_t* w, int n) {
25 size_t vl;
26 // auxiliary variables
27 vint8m4_t vu, vw;
28 vint16m8_t vx;
29 vint32m1_t v_sum;
30 iree_uk_int32_t sum = 0;
31 for (size_t i = 0; i < n; i += vl) {
32 vl = vsetvl_e8m4(n - i);
33 vu = vle8_v_i8m4(u + i, vl); // load
34 vw = vle8_v_i8m4(w + i, vl); // load
35 vx = vwmul(vu, vw, vl); // multiply
36 v_sum = vmv_s(v_sum, 0, vl); // init
37 v_sum = vwredsum(v_sum, vx, v_sum, vl); // sum
38 sum += vmv_x(v_sum);
39 }
40 return sum;
41}
42
43// RVV implementation of matmul tile, i8*i8->i32 case.
44static void iree_uk_mmt4d_tile_i8i8i32_rvv(
45 void* out_tile_untyped, const void* lhs_panel_untyped,
46 const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags,
47 const iree_uk_mmt4d_params_t* params) {
48 iree_uk_int32_t* out_tile = out_tile_untyped;
49 const iree_uk_int8_t* lhs_panel = lhs_panel_untyped;
50 const iree_uk_int8_t* rhs_panel = rhs_panel_untyped;
51 iree_uk_int16_t M0 = params->M0;
52 iree_uk_int16_t N0 = params->N0;
53 iree_uk_int16_t K0 = params->K0;
54 // Initialize the accumulator tile.
Lun Dong4e239a02023-04-25 09:32:17 -070055 if (!(flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) {
Lun Dongfc805b32023-02-01 09:52:49 -080056 memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t));
57 }
58 // Accumulation loop.
Lun Dong8914dd62023-06-02 12:58:22 -070059 for (iree_uk_index_t k = 0; k < K; ++k) {
60 for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
61 for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
Lun Dongfc805b32023-02-01 09:52:49 -080062 out_tile[i0 * N0 + j0] +=
63 dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0);
64 }
65 }
66 lhs_panel += M0 * K0;
67 rhs_panel += N0 * K0;
68 }
69}
70
71// Generic implementation of matmul tile, f32*f32->f32 case.
72static void iree_uk_mmt4d_tile_f32f32f32_generic(
73 void* out_tile_untyped, const void* lhs_panel_untyped,
74 const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags,
75 const iree_uk_mmt4d_params_t* params) {
76 float* out_tile = out_tile_untyped;
77 const float* lhs_panel = lhs_panel_untyped;
78 const float* rhs_panel = rhs_panel_untyped;
79 iree_uk_int16_t M0 = params->M0;
80 iree_uk_int16_t N0 = params->N0;
81 iree_uk_int16_t K0 = params->K0;
82 // Initialize the local accumulator tile.
83 float acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
Lun Dong4e239a02023-04-25 09:32:17 -070084 if (flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
Lun Dongfc805b32023-02-01 09:52:49 -080085 for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
86 } else {
87 for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
88 }
89 // Accumulation loop.
Lun Dong8914dd62023-06-02 12:58:22 -070090 for (iree_uk_index_t k = 0; k < K; ++k) {
91 for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
92 for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
93 for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
Lun Dongfc805b32023-02-01 09:52:49 -080094 float lhs_val = lhs_panel[i0 * K0 + k0];
95 float rhs_val = rhs_panel[j0 * K0 + k0];
96 acc[i0 * N0 + j0] += lhs_val * rhs_val;
97 }
98 }
99 }
100 lhs_panel += M0 * K0;
101 rhs_panel += N0 * K0;
102 }
103 // Store the local accumulator tile to the destination.
104 for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
105}
106
107iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func(
108 const iree_uk_mmt4d_params_t* params) {
109 // TODO(lundong): to be replaced with Kelvin
Lun Dong4e239a02023-04-25 09:32:17 -0700110 switch (iree_uk_mmt4d_type(params->flags)) {
Lun Dongfc805b32023-02-01 09:52:49 -0800111 case iree_uk_mmt4d_type_f32f32f32:
112 return iree_uk_mmt4d_tile_f32f32f32_generic;
113 case iree_uk_mmt4d_type_i8i8i32:
114 return iree_uk_mmt4d_tile_i8i8i32_rvv;
115 default:
116 // shouldn't happen, validated earlier.
117 IREE_UK_ASSUME_UNREACHABLE;
118 return 0;
119 }
120}