Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 1 | /* |
| 2 | * Copyright 2023 Google LLC |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 17 | #include <riscv_vector.h> |
| 18 | #include <string.h> |
| 19 | |
Lun Dong | f20445c | 2023-04-24 12:01:48 -0700 | [diff] [blame] | 20 | #include "iree/builtins/ukernel/mmt4d_internal.h" |
| 21 | |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 22 | // Calculate the dot product of two int8 vectors using RVV |
| 23 | static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u, |
| 24 | const iree_uk_int8_t* w, int n) { |
| 25 | size_t vl; |
| 26 | // auxiliary variables |
| 27 | vint8m4_t vu, vw; |
| 28 | vint16m8_t vx; |
| 29 | vint32m1_t v_sum; |
| 30 | iree_uk_int32_t sum = 0; |
| 31 | for (size_t i = 0; i < n; i += vl) { |
| 32 | vl = vsetvl_e8m4(n - i); |
| 33 | vu = vle8_v_i8m4(u + i, vl); // load |
| 34 | vw = vle8_v_i8m4(w + i, vl); // load |
| 35 | vx = vwmul(vu, vw, vl); // multiply |
| 36 | v_sum = vmv_s(v_sum, 0, vl); // init |
| 37 | v_sum = vwredsum(v_sum, vx, v_sum, vl); // sum |
| 38 | sum += vmv_x(v_sum); |
| 39 | } |
| 40 | return sum; |
| 41 | } |
| 42 | |
| 43 | // RVV implementation of matmul tile, i8*i8->i32 case. |
| 44 | static void iree_uk_mmt4d_tile_i8i8i32_rvv( |
| 45 | void* out_tile_untyped, const void* lhs_panel_untyped, |
| 46 | const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags, |
| 47 | const iree_uk_mmt4d_params_t* params) { |
| 48 | iree_uk_int32_t* out_tile = out_tile_untyped; |
| 49 | const iree_uk_int8_t* lhs_panel = lhs_panel_untyped; |
| 50 | const iree_uk_int8_t* rhs_panel = rhs_panel_untyped; |
| 51 | iree_uk_int16_t M0 = params->M0; |
| 52 | iree_uk_int16_t N0 = params->N0; |
| 53 | iree_uk_int16_t K0 = params->K0; |
| 54 | // Initialize the accumulator tile. |
Lun Dong | 4e239a0 | 2023-04-25 09:32:17 -0700 | [diff] [blame] | 55 | if (!(flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)) { |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 56 | memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t)); |
| 57 | } |
| 58 | // Accumulation loop. |
Lun Dong | 8914dd6 | 2023-06-02 12:58:22 -0700 | [diff] [blame] | 59 | for (iree_uk_index_t k = 0; k < K; ++k) { |
| 60 | for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| 61 | for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 62 | out_tile[i0 * N0 + j0] += |
| 63 | dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0); |
| 64 | } |
| 65 | } |
| 66 | lhs_panel += M0 * K0; |
| 67 | rhs_panel += N0 * K0; |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | // Generic implementation of matmul tile, f32*f32->f32 case. |
| 72 | static void iree_uk_mmt4d_tile_f32f32f32_generic( |
| 73 | void* out_tile_untyped, const void* lhs_panel_untyped, |
| 74 | const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags, |
| 75 | const iree_uk_mmt4d_params_t* params) { |
| 76 | float* out_tile = out_tile_untyped; |
| 77 | const float* lhs_panel = lhs_panel_untyped; |
| 78 | const float* rhs_panel = rhs_panel_untyped; |
| 79 | iree_uk_int16_t M0 = params->M0; |
| 80 | iree_uk_int16_t N0 = params->N0; |
| 81 | iree_uk_int16_t K0 = params->K0; |
| 82 | // Initialize the local accumulator tile. |
| 83 | float acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; |
Lun Dong | 4e239a0 | 2023-04-25 09:32:17 -0700 | [diff] [blame] | 84 | if (flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) { |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 85 | for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; |
| 86 | } else { |
| 87 | for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; |
| 88 | } |
| 89 | // Accumulation loop. |
Lun Dong | 8914dd6 | 2023-06-02 12:58:22 -0700 | [diff] [blame] | 90 | for (iree_uk_index_t k = 0; k < K; ++k) { |
| 91 | for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { |
| 92 | for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { |
| 93 | for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) { |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 94 | float lhs_val = lhs_panel[i0 * K0 + k0]; |
| 95 | float rhs_val = rhs_panel[j0 * K0 + k0]; |
| 96 | acc[i0 * N0 + j0] += lhs_val * rhs_val; |
| 97 | } |
| 98 | } |
| 99 | } |
| 100 | lhs_panel += M0 * K0; |
| 101 | rhs_panel += N0 * K0; |
| 102 | } |
| 103 | // Store the local accumulator tile to the destination. |
| 104 | for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; |
| 105 | } |
| 106 | |
| 107 | iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func( |
| 108 | const iree_uk_mmt4d_params_t* params) { |
| 109 | // TODO(lundong): to be replaced with Kelvin |
Lun Dong | 4e239a0 | 2023-04-25 09:32:17 -0700 | [diff] [blame] | 110 | switch (iree_uk_mmt4d_type(params->flags)) { |
Lun Dong | fc805b3 | 2023-02-01 09:52:49 -0800 | [diff] [blame] | 111 | case iree_uk_mmt4d_type_f32f32f32: |
| 112 | return iree_uk_mmt4d_tile_f32f32f32_generic; |
| 113 | case iree_uk_mmt4d_type_i8i8i32: |
| 114 | return iree_uk_mmt4d_tile_i8i8i32_rvv; |
| 115 | default: |
| 116 | // shouldn't happen, validated earlier. |
| 117 | IREE_UK_ASSUME_UNREACHABLE; |
| 118 | return 0; |
| 119 | } |
| 120 | } |