[CPU] Add s8s4s32 i8mm ukernel (#16678)
Adds a ukernel for `i8 * i4 -> i32` using the `smmla` instruction
available with `+i8mm`. Two separate tile functions are used: one
dedicated to M0=1 and a shared one covering M0=2, M0=4, and M0=8.
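
For reference, a minimal standalone sketch of the int4 handling these kernels rely on (this is illustrative only, not the ukernel code; the function names, packing layout, and scalar reference check here are assumptions): each int4 is moved into the upper nibble of a byte, which preserves its sign but scales it by 16, the scaled bytes are fed to `smmla` via `vmmlaq_s32`, and the accumulator is shifted right by 4 afterwards to undo the scaling. The compensation is exact because every accumulated term is an int32 multiple of 16.

```c
// Minimal sketch: dot product of 16 int8 LHS values with 16 int4 RHS values
// packed two-per-byte, computed with SMMLA.
// Build with e.g.: clang -O2 -march=armv8.6-a+i8mm int4_smmla_sketch.c
#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

// b_packed[j] holds b[2j] in its low nibble and b[2j+1] in its high nibble.
static int32_t dot_s8s4_smmla(const int8_t a[16], const uint8_t b_packed[8]) {
  // Deinterleave LHS into even/odd k indices, matching the RHS nibble order.
  int8x8x2_t a_uzp = vld2_s8(a);
  int8x16_t lhs_even = vcombine_s8(a_uzp.val[0], vdup_n_s8(0));
  int8x16_t lhs_odd = vcombine_s8(a_uzp.val[1], vdup_n_s8(0));

  // Unpack int4 -> int8 by moving each nibble to the upper 4 bits of a byte.
  // This keeps the sign but scales every value by 16.
  int8x16_t r = vcombine_s8(vld1_s8((const int8_t*)b_packed), vdup_n_s8(0));
  int8x16_t rhs_even_x16 = vshlq_n_s8(r, 4);               // low nibbles << 4
  int8x16_t rhs_odd_x16 = vandq_s8(r, vmovq_n_s8(0xF0));   // high nibbles

  // SMMLA treats each operand as a 2x8 int8 matrix; only row 0 is used here,
  // so the dot product accumulates into lane 0.
  int32x4_t acc = vdupq_n_s32(0);
  acc = vmmlaq_s32(acc, lhs_even, rhs_even_x16);
  acc = vmmlaq_s32(acc, lhs_odd, rhs_odd_x16);

  // Compensate for the x16 scaling of the RHS values.
  acc = vshrq_n_s32(acc, 4);
  return vgetq_lane_s32(acc, 0);
}

int main(void) {
  int8_t a[16];
  int8_t b[16];  // int4 values in [-8, 7]
  uint8_t b_packed[8];
  int32_t expected = 0;
  for (int k = 0; k < 16; ++k) {
    a[k] = (int8_t)(k * 7 - 50);
    b[k] = (int8_t)((k % 16) - 8);
    expected += a[k] * b[k];
  }
  for (int j = 0; j < 8; ++j) {
    b_packed[j] = (uint8_t)((b[2 * j] & 0xF) | ((b[2 * j + 1] & 0xF) << 4));
  }
  printf("smmla: %d, reference: %d\n", dot_s8s4_smmla(a, b_packed), expected);
  return 0;
}
```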
```
-----------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------
BM_mmt4d_s8s4s32_tile_1x16x2/real_time 0.690 us 0.689 us 1048575 items_per_second=23.735G/s
BM_mmt4d_s8s4s32_tile_2x16x2/real_time 0.803 us 0.801 us 1048575 items_per_second=40.8248G/s
BM_mmt4d_s8s4s32_tile_4x16x2/real_time 1.68 us 1.68 us 524287 items_per_second=38.9081G/s
BM_mmt4d_s8s4s32_tile_1x8x8_dotprod/real_time 0.640 us 0.638 us 1048575 items_per_second=51.185G/s
BM_mmt4d_s8s4s32_tile_2x8x8_dotprod/real_time 0.369 us 0.368 us 2097151 items_per_second=177.561G/s
BM_mmt4d_s8s4s32_tile_4x8x8_dotprod/real_time 0.540 us 0.539 us 1048575 items_per_second=242.718G/s
BM_mmt4d_s8s4s32_tile_8x8x8_dotprod/real_time 0.934 us 0.932 us 1048575 items_per_second=280.543G/s
BM_mmt4d_s8s4s32_tile_1x8x16_i8mm/real_time 0.451 us 0.450 us 2097151 items_per_second=145.412G/s
BM_mmt4d_s8s4s32_tile_2x8x16_i8mm/real_time 0.509 us 0.509 us 1048575 items_per_second=257.577G/s
BM_mmt4d_s8s4s32_tile_4x8x16_i8mm/real_time 0.806 us 0.805 us 1048575 items_per_second=325.353G/s
BM_mmt4d_s8s4s32_tile_8x8x16_i8mm/real_time 1.76 us 1.76 us 524287 items_per_second=297.277G/s
```
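
The M0 >= 2 kernels accumulate into 2x2 `smmla` register tiles and swizzle pairs of adjacent tiles back to row-major rows before storing. Below is a hedged sketch of that swizzle using plain ACLE intrinsics in place of the ukernel's `iree_uk_neon_uzp1_s32_as_s64` / `iree_uk_neon_uzp2_s32_as_s64` helpers; the helper names come from the diff, but their exact definitions are assumed here.

```c
// Sketch: two SMMLA accumulators hold a 2x4 int32 block as two 2x2 tiles,
//   acc0 = [r0c0, r0c1, r1c0, r1c1], acc1 = [r0c2, r0c3, r1c2, r1c3],
// and are rearranged into two row-major rows of 4 with 64-bit unzips.
#include <arm_neon.h>

static inline int32x4_t uzp1_s32_as_s64(int32x4_t a, int32x4_t b) {
  // Takes the low 64 bits (first int32 pair) of each input.
  return vreinterpretq_s32_s64(
      vuzp1q_s64(vreinterpretq_s64_s32(a), vreinterpretq_s64_s32(b)));
}

static inline int32x4_t uzp2_s32_as_s64(int32x4_t a, int32x4_t b) {
  // Takes the high 64 bits (second int32 pair) of each input.
  return vreinterpretq_s32_s64(
      vuzp2q_s64(vreinterpretq_s64_s32(a), vreinterpretq_s64_s32(b)));
}

// row0 = [r0c0, r0c1, r0c2, r0c3], row1 = [r1c0, r1c1, r1c2, r1c3].
static inline void swizzle_2x2_tiles(int32x4_t acc0, int32x4_t acc1,
                                     int32x4_t* row0, int32x4_t* row1) {
  *row0 = uzp1_s32_as_s64(acc0, acc1);
  *row1 = uzp2_s32_as_s64(acc0, acc1);
}
```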
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 61ad8f9..6c71c4d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -274,6 +274,26 @@
return 0;
}
+static iree_uk_mmt4d_tile_func_t
+iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x16(
+ const iree_uk_mmt4d_params_t* params) {
+#ifdef IREE_UK_BUILD_ARM_64_I8MM
+ if (iree_uk_cpu_supports_i8mm(params->cpu_data)) {
+ switch (params->M0) {
+ case 1:
+ return iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm;
+ case 2:
+ return iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm;
+ case 4:
+ return iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm;
+ case 8:
+ return iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm;
+ }
+ }
+#endif
+ return 0;
+}
+
static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32(
const iree_uk_mmt4d_params_t* params) {
if (params->N0 == 16 && params->K0 == 2) {
@@ -282,6 +302,9 @@
if (params->N0 == 8 && params->K0 == 8) {
return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x8(params);
}
+ if (params->N0 == 8 && params->K0 == 16) {
+ return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x16(params);
+ }
return 0;
}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
index a62e80c..218612a 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
@@ -107,3 +107,159 @@
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm, 8)
+
+// In the s8s4s32 kernels below, we unpack int4s into individual int8s.
+// To preserve signedness, int4s are moved to the upper 4-bits of each byte.
+// This has the effect of multiplying each int4 by 2^4 = 16. To compensate,
+// we divide the accumulator values by 16 before storing to memory.
+// This int4 conversion trick is borrowed from the `qd8-f32-qc4w-gemm*`
+// kernels in https://github.com/google/XNNPACK.
+
+IREE_UK_ATTRIBUTE_ALWAYS_INLINE inline void
+iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm(
+ void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+ const void* IREE_UK_RESTRICT rhs_panel,
+ const iree_uk_mmt4d_params_t* params) {
+ IREE_UK_ASSERT(!(params->K0 % 16));
+ const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+ const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+ iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+
+ const int8x16_t vmask = vmovq_n_s8(0xF0);
+ const int8x8_t vzero = vmov_n_s8(0);
+
+ int32x4_t acc[4];
+ IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+ // We start with zero accumulators and add the value of *out_ptr later.
+ // This is required for the int4 left shift described above.
+ acc[i] = vdupq_n_s32(0);
+ }
+
+ for (int k = 0; k < params->K; ++k) {
+ int8x16_t rhs[2][4];
+ IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+ int8x16_t r = vld1q_s8(rhs_ptr + 16 * i);
+ rhs[0][i] = vshlq_n_s8(r, 4);
+ rhs[1][i] = vandq_s8(r, vmask);
+ }
+ rhs_ptr += 64;
+
+ int8x16_t lhs[2];
+ int8x8x2_t lhs_uzp = vld2_s8(lhs_ptr);
+ lhs_ptr += 16;
+ lhs[0] = vcombine_s8(lhs_uzp.val[0], vzero);
+ lhs[1] = vcombine_s8(lhs_uzp.val[1], vzero);
+
+ IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+ acc[i] = vmmlaq_s32(acc[i], lhs[0], rhs[0][i]);
+ acc[i] = vmmlaq_s32(acc[i], lhs[1], rhs[1][i]);
+ }
+ }
+
+ IREE_UK_UNROLL for (int j = 0; j < 2; j++) {
+ acc[2 * j + 0] = vshrq_n_s32(acc[2 * j + 0], 4);
+ acc[2 * j + 1] = vshrq_n_s32(acc[2 * j + 1], 4);
+
+ int32x4_t acc_1x4_0 =
+ iree_uk_neon_uzp1_s32_as_s64(acc[2 * j + 0], acc[2 * j + 1]);
+ if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+ int32x4_t existing_acc = vld1q_s32(out_ptr + 4 * j);
+ acc_1x4_0 = vaddq_s32(acc_1x4_0, existing_acc);
+ }
+ vst1q_s32(out_ptr + 4 * j, acc_1x4_0);
+ }
+}
+
+IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
+iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
+ void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+ const void* IREE_UK_RESTRICT rhs_panel,
+ const iree_uk_mmt4d_params_t* params, int M0) {
+ IREE_UK_ASSERT(M0 >= 2 && M0 <= 8 && iree_uk_is_po2_u32(M0));
+ IREE_UK_ASSERT(!(params->K0 % 16));
+ const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+ const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+ iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+
+ const int8x16_t vmask = vmovq_n_s8(0xF0);
+ const int mtiles = M0 / 2;
+
+ int32x4_t acc[4][4];
+ IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+ IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
+ // We start with zero accumulators and add the value of *out_ptr later.
+ // This is required for the int4 left shift described above.
+ acc[i][j] = vdupq_n_s32(0);
+ }
+ }
+
+ for (int k = 0; k < params->K; ++k) {
+ int8x16_t rhs[2][4];
+ IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+ int8x16_t r = vld1q_s8(rhs_ptr + 16 * i);
+ rhs[0][i] = vshlq_n_s8(r, 4);
+ rhs[1][i] = vandq_s8(r, vmask);
+ }
+ rhs_ptr += 64;
+
+ int8x16_t lhs[2][4];
+ if (M0 == 2) {
+ int8x8x2_t lhs_uzp[2];
+ IREE_UK_UNROLL for (int i = 0; i < 2; i++) {
+ lhs_uzp[i] = vld2_s8(lhs_ptr + 16 * i);
+ }
+ lhs[0][0] = vcombine_s8(lhs_uzp[0].val[0], lhs_uzp[1].val[0]);
+ lhs[1][0] = vcombine_s8(lhs_uzp[0].val[1], lhs_uzp[1].val[1]);
+ } else {
+ IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+ int8x8x2_t lhs_0 = vld2_s8(lhs_ptr + 16 * 2 * i);
+ int8x8x2_t lhs_1 = vld2_s8(lhs_ptr + 16 * (2 * i + 1));
+ lhs[0][i] = vcombine_s8(lhs_0.val[0], lhs_1.val[0]);
+ lhs[1][i] = vcombine_s8(lhs_0.val[1], lhs_1.val[1]);
+ }
+ }
+ lhs_ptr += 32 * mtiles;
+
+ IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+ IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
+ IREE_UK_UNROLL for (int m = 0; m < 2; m++) {
+ acc[i][j] = vmmlaq_s32(acc[i][j], lhs[m][i], rhs[m][j]);
+ }
+ }
+ }
+ }
+
+ // Swizzle accumulator 2x2 register tiles back to row-major and store.
+ IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+ IREE_UK_UNROLL for (int j = 0; j < 2; j++) {
+ acc[i][2 * j + 0] = vshrq_n_s32(acc[i][2 * j + 0], 4);
+ acc[i][2 * j + 1] = vshrq_n_s32(acc[i][2 * j + 1], 4);
+
+ int32x4_t acc_1x4_0 =
+ iree_uk_neon_uzp1_s32_as_s64(acc[i][2 * j + 0], acc[i][2 * j + 1]);
+ if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+ int32x4_t existing_acc = vld1q_s32(out_ptr + 8 * (2 * i + 0) + 4 * j);
+ acc_1x4_0 = vaddq_s32(acc_1x4_0, existing_acc);
+ }
+ vst1q_s32(out_ptr + 8 * (2 * i + 0) + 4 * j, acc_1x4_0);
+
+ int32x4_t acc_1x4_1 =
+ iree_uk_neon_uzp2_s32_as_s64(acc[i][2 * j + 0], acc[i][2 * j + 1]);
+ if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+ int32x4_t existing_acc = vld1q_s32(out_ptr + 8 * (2 * i + 1) + 4 * j);
+ acc_1x4_1 = vaddq_s32(acc_1x4_1, existing_acc);
+ }
+ vst1q_s32(out_ptr + 8 * (2 * i + 1) + 4 * j, acc_1x4_1);
+ }
+ }
+}
+
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 3d3e19d..8ecc43b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -56,5 +56,9 @@
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm)
#endif // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index bfb932e..a3c0e9c 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -149,6 +149,8 @@
"");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8,
"dotprod");
+ iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16,
+ "i8mm");
#elif defined(IREE_ARCH_X86_64)
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1,
"avx2_fma");
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index b7e9593..517f0ae 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -569,8 +569,9 @@
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16, 8, 8, 4, "bf16");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 8, 8, 4, "bf16");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16, "i8mm");
#elif defined(IREE_ARCH_X86_64)