[CPU] Remove 8x8x16 i8mm microkernel
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 6c71c4d..76d140d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -286,8 +286,6 @@
return iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm;
case 4:
return iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm;
- case 8:
- return iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm;
}
}
#endif
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
index 218612a..8582e07 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
@@ -171,11 +171,12 @@
}
IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
-iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
+iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_4x8x16_arm_64_i8mm(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params, int M0) {
- IREE_UK_ASSERT(M0 >= 2 && M0 <= 8 && iree_uk_is_po2_u32(M0));
+ // We support M0 up to 4 in order to fit within the register budget.
+ IREE_UK_ASSERT(M0 >= 2 && M0 <= 4 && iree_uk_is_po2_u32(M0));
IREE_UK_ASSERT(!(params->K0 % 16));
const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
@@ -184,7 +185,7 @@
const int8x16_t vmask = vmovq_n_s8(0xF0);
const int mtiles = M0 / 2;
- int32x4_t acc[4][4];
+ int32x4_t acc[2][4];
IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
// We start with zero accumulators and add the value of *out_ptr later.
@@ -202,7 +203,7 @@
}
rhs_ptr += 64;
- int8x16_t lhs[2][4];
+ int8x16_t lhs[2][2];
if (M0 == 2) {
int8x8x2_t lhs_uzp[2];
IREE_UK_UNROLL for (int i = 0; i < 2; i++) {
@@ -210,15 +211,16 @@
}
lhs[0][0] = vcombine_s8(lhs_uzp[0].val[0], lhs_uzp[1].val[0]);
lhs[1][0] = vcombine_s8(lhs_uzp[0].val[1], lhs_uzp[1].val[1]);
+ lhs_ptr += 32;
} else {
- IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+ IREE_UK_UNROLL for (int i = 0; i < 2; i++) {
int8x8x2_t lhs_0 = vld2_s8(lhs_ptr + 16 * 2 * i);
int8x8x2_t lhs_1 = vld2_s8(lhs_ptr + 16 * (2 * i + 1));
lhs[0][i] = vcombine_s8(lhs_0.val[0], lhs_1.val[0]);
lhs[1][i] = vcombine_s8(lhs_0.val[1], lhs_1.val[1]);
}
+ lhs_ptr += 64;
}
- lhs_ptr += 32 * mtiles;
IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
@@ -255,11 +257,8 @@
}
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
- iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_4x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm, 2)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
- iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_4x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm, 4)
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
- iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
- iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 8ecc43b..f21a0f1 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -59,6 +59,5 @@
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm)
-IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm)
#endif // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index a3c0e9c..bafde05 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -149,7 +149,7 @@
"");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8,
"dotprod");
- iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16,
+ iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 8, 16,
"i8mm");
#elif defined(IREE_ARCH_X86_64)
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1,
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 517f0ae..eb17204 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -571,7 +571,7 @@
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16, "i8mm");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 8, 16, "i8mm");
#elif defined(IREE_ARCH_X86_64)