sw/vec_iree: Fix ukernel breakage due to latest IREE ukernel refactoring
Upstream IREE recently made a major refactoring of its ukernel library.
This is the minimal change needed to fix the breakage caused by that refactoring.
Change-Id: If322e74256a0f48436cd401e91104cfaa1519280
diff --git a/ukernel/importer.c b/ukernel/importer.c
index cf0215d..facd654 100644
--- a/ukernel/importer.c
+++ b/ukernel/importer.c
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include "iree/builtins/ukernel/api.h"
+#include "iree/builtins/ukernel/mmt4d_internal.h"
#include "iree/hal/local/executable_loader.h"
#include "iree/hal/local/executable_plugin.h"
// Importer entry point wrapping the actual ukernel.
static int iree_uk_importer_mmt4d(void* params_ptr, void* context,
void* reserved) {
- iree_uk_mmt4d((const iree_uk_mmt4d_params_t*)params_ptr);
+ iree_uk_mmt4d_p((const iree_uk_mmt4d_params_t*)params_ptr);
return 0;
}
diff --git a/ukernel/mmt4d_tile.c b/ukernel/mmt4d_tile.c
index 93c5ee6..a8f830c 100644
--- a/ukernel/mmt4d_tile.c
+++ b/ukernel/mmt4d_tile.c
@@ -490,32 +490,24 @@
iree_uk_int16_t M0 = params->M0;
iree_uk_int16_t N0 = params->N0;
iree_uk_int16_t K0 = params->K0;
- // Initialize the local accumulator tile.
- float acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
- if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
- for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
- } else {
- for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
- }
- // Accumulation loop.
- for (iree_uk_index_t k = 0; k < params->K; ++k) {
- for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
- for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+ float acc = (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE)
+ ? out_tile[i0 * N0 + j0]
+ : 0.f;
+ for (iree_uk_index_t k = 0; k < params->K; ++k) {
for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
- float lhs_f32 = lhs_panel[i0 * K0 + k0];
- float rhs_f32 = rhs_panel[j0 * K0 + k0];
- acc[i0 * N0 + j0] += lhs_f32 * rhs_f32;
+ float lhs_f32 = lhs_panel[k * M0 * K0 + i0 * K0 + k0];
+ float rhs_f32 = rhs_panel[k * N0 * K0 + j0 * K0 + k0];
+ acc += lhs_f32 * rhs_f32;
}
}
+ out_tile[i0 * N0 + j0] = acc;
}
- lhs_panel += M0 * K0;
- rhs_panel += N0 * K0;
}
- // Store the local accumulator tile to the destination.
- for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
}
-iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func(
+iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
const iree_uk_mmt4d_params_t* params) {
switch (iree_uk_mmt4d_type(params->flags)) {
case iree_uk_mmt4d_type_s8s8s32: