sw:vec_iree: Fix microkernel breakage Fix the ukernel breakage due to https://github.com/iree-org/iree/pull/12015. Change-Id: Id34c67b06f25ebd13445a78e29b277833d105493
diff --git a/vmvx_ukernel/CMakeLists.txt b/vmvx_ukernel/CMakeLists.txt index 3b9c0e8..ccf6791 100644 --- a/vmvx_ukernel/CMakeLists.txt +++ b/vmvx_ukernel/CMakeLists.txt
@@ -2,38 +2,22 @@ iree_cc_library( NAME - arch - HDRS - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/arch/mmt4d_arch.h" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/arch/pack_arch.h" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/arch/query_tile_sizes_arch.h" - SRCS - "mmt4d_arch.c" - "pack_arch.c" - "query_tile_sizes_arch.c" - DEPS - iree::builtins::ukernel::headers - PUBLIC -) - -iree_cc_library( - NAME ukernel HDRS "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/api.h" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d_generic.h" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack_generic.h" + "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d_tile.h" + "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack_tile.h" SRCS "elementwise.c" "elementwise_impl.c.inc" + "mmt4d_tile.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d.c" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/mmt4d_generic.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack.c" - "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack_generic.c" + "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/pack_tile.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/query_tile_sizes.c" "${IREE_RUNTIME_SOURCE_DIR}/builtins/ukernel/unpack.c" DEPS - ::arch + iree::builtins::ukernel::headers PUBLIC )
diff --git a/vmvx_ukernel/mmt4d_arch.c b/vmvx_ukernel/mmt4d_arch.c deleted file mode 100644 index 92e19ce..0000000 --- a/vmvx_ukernel/mmt4d_arch.c +++ /dev/null
@@ -1,78 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "iree/builtins/ukernel/arch/mmt4d_arch.h" - -#include <riscv_vector.h> -#include <string.h> - -// Calculate the dot product of two int8 vectors using RVV -static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u, - const iree_uk_int8_t* w, int n) { - size_t vl; - // auxiliary variables - vint8m4_t vu, vw; - vint16m8_t vx; - vint32m1_t v_sum; - iree_uk_int32_t sum = 0; - for (size_t i = 0; i < n; i += vl) { - vl = vsetvl_e8m4(n - i); - vu = vle8_v_i8m4(u + i, vl); // load - vw = vle8_v_i8m4(w + i, vl); // load - vx = vwmul(vu, vw, vl); // multiply - v_sum = vmv_s(v_sum, 0, vl); // init - v_sum = vwredsum(v_sum, vx, v_sum, vl); // sum - sum += vmv_x(v_sum); - } - return sum; -} - -// RVV implementation of matmul tile, i8*i8->i32 case. -static void iree_uk_mmt4d_tile_i8i8i32_rvv( - void* out_tile_untyped, const void* lhs_panel_untyped, - const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags, - const iree_uk_mmt4d_params_t* params) { - iree_uk_int32_t* out_tile = out_tile_untyped; - const iree_uk_int8_t* lhs_panel = lhs_panel_untyped; - const iree_uk_int8_t* rhs_panel = rhs_panel_untyped; - iree_uk_int16_t M0 = params->M0; - iree_uk_int16_t N0 = params->N0; - iree_uk_int16_t K0 = params->K0; - // Initialize the accumulator tile. - if (!(flags & IREE_UK_FLAG_ACCUMULATE)) { - memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t)); - } - // Accumulation loop. - for (iree_uk_ssize_t k = 0; k < K; ++k) { - for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) { - for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) { - out_tile[i0 * N0 + j0] += - dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0); - } - } - lhs_panel += M0 * K0; - rhs_panel += N0 * K0; - } -} - -iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch( - const iree_uk_mmt4d_params_t* params) { - // TODO(lundong): to be replaced with Kelvin - if (params->type == iree_uk_mmt4d_type_i8i8i32) { - return iree_uk_mmt4d_tile_i8i8i32_rvv; - } - return 0; -}
diff --git a/vmvx_ukernel/mmt4d_tile.c b/vmvx_ukernel/mmt4d_tile.c new file mode 100644 index 0000000..a591c66 --- /dev/null +++ b/vmvx_ukernel/mmt4d_tile.c
@@ -0,0 +1,120 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "iree/builtins/ukernel/mmt4d_tile.h" + +#include <riscv_vector.h> +#include <string.h> + +// Calculate the dot product of two int8 vectors using RVV +static iree_uk_int32_t dot_product_rvv(const iree_uk_int8_t* u, + const iree_uk_int8_t* w, int n) { + size_t vl; + // auxiliary variables + vint8m4_t vu, vw; + vint16m8_t vx; + vint32m1_t v_sum; + iree_uk_int32_t sum = 0; + for (size_t i = 0; i < n; i += vl) { + vl = vsetvl_e8m4(n - i); + vu = vle8_v_i8m4(u + i, vl); // load + vw = vle8_v_i8m4(w + i, vl); // load + vx = vwmul(vu, vw, vl); // multiply + v_sum = vmv_s(v_sum, 0, vl); // init + v_sum = vwredsum(v_sum, vx, v_sum, vl); // sum + sum += vmv_x(v_sum); + } + return sum; +} + +// RVV implementation of matmul tile, i8*i8->i32 case. +static void iree_uk_mmt4d_tile_i8i8i32_rvv( + void* out_tile_untyped, const void* lhs_panel_untyped, + const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags, + const iree_uk_mmt4d_params_t* params) { + iree_uk_int32_t* out_tile = out_tile_untyped; + const iree_uk_int8_t* lhs_panel = lhs_panel_untyped; + const iree_uk_int8_t* rhs_panel = rhs_panel_untyped; + iree_uk_int16_t M0 = params->M0; + iree_uk_int16_t N0 = params->N0; + iree_uk_int16_t K0 = params->K0; + // Initialize the accumulator tile. + if (!(flags & IREE_UK_FLAG_ACCUMULATE)) { + memset(out_tile, 0, M0 * N0 * sizeof(iree_uk_int32_t)); + } + // Accumulation loop. + for (iree_uk_ssize_t k = 0; k < K; ++k) { + for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) { + for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) { + out_tile[i0 * N0 + j0] += + dot_product_rvv(lhs_panel + i0 * K0, rhs_panel + j0 * K0, K0); + } + } + lhs_panel += M0 * K0; + rhs_panel += N0 * K0; + } +} + +// Generic implementation of matmul tile, f32*f32->f32 case. +static void iree_uk_mmt4d_tile_f32f32f32_generic( + void* out_tile_untyped, const void* lhs_panel_untyped, + const void* rhs_panel_untyped, iree_uk_int32_t K, iree_uk_uint32_t flags, + const iree_uk_mmt4d_params_t* params) { + float* out_tile = out_tile_untyped; + const float* lhs_panel = lhs_panel_untyped; + const float* rhs_panel = rhs_panel_untyped; + iree_uk_int16_t M0 = params->M0; + iree_uk_int16_t N0 = params->N0; + iree_uk_int16_t K0 = params->K0; + // Initialize the local accumulator tile. + float acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; + if (flags & IREE_UK_FLAG_ACCUMULATE) { + for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; + } else { + for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; + } + // Accumulation loop. + for (iree_uk_ssize_t k = 0; k < K; ++k) { + for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) { + for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) { + for (iree_uk_ssize_t k0 = 0; k0 < K0; ++k0) { + float lhs_val = lhs_panel[i0 * K0 + k0]; + float rhs_val = rhs_panel[j0 * K0 + k0]; + acc[i0 * N0 + j0] += lhs_val * rhs_val; + } + } + } + lhs_panel += M0 * K0; + rhs_panel += N0 * K0; + } + // Store the local accumulator tile to the destination. + for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; +} + +iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func( + const iree_uk_mmt4d_params_t* params) { + // TODO(lundong): to be replaced with Kelvin + switch (params->type) { + case iree_uk_mmt4d_type_f32f32f32: + return iree_uk_mmt4d_tile_f32f32f32_generic; + case iree_uk_mmt4d_type_i8i8i32: + return iree_uk_mmt4d_tile_i8i8i32_rvv; + default: + // shouldn't happen, validated earlier. + IREE_UK_ASSUME_UNREACHABLE; + return 0; + } +}
diff --git a/vmvx_ukernel/pack_arch.c b/vmvx_ukernel/pack_arch.c deleted file mode 100644 index beb234e..0000000 --- a/vmvx_ukernel/pack_arch.c +++ /dev/null
@@ -1,23 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "iree/builtins/ukernel/arch/pack_arch.h" - -iree_uk_pack_tile_func_t iree_uk_pack_select_tile_func_arch( - const iree_uk_pack_params_t* params) { - // TODO(lundong): to be replaced with Kelvin - return 0; -}
diff --git a/vmvx_ukernel/query_tile_sizes_arch.c b/vmvx_ukernel/query_tile_sizes_arch.c deleted file mode 100644 index a8d2db7..0000000 --- a/vmvx_ukernel/query_tile_sizes_arch.c +++ /dev/null
@@ -1,24 +0,0 @@ -/* - * Copyright 2023 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "iree/builtins/ukernel/arch/query_tile_sizes_arch.h" - -bool iree_uk_query_matmul_tile_sizes_arch( - const iree_uk_query_tile_sizes_2d_params_t* params, - iree_uk_matmul_tile_sizes_t* out_matmul_tile_sizes) { - // TODO(lundong): to be replaced with Kelvin logic. - return false; -}