Add s8s4s32 dotprod microkernel (#16473)
Adds an `i8 * i4 -> i32` microkernel that uses the `+dotprod` ARM CPU
feature. Supports tile sizes `M0xN0xK0`: `1x8x8`, `2x8x8`, `4x8x8`, and
`8x8x8`. We use `K0=8` since `+dotprod` reduces groups of 4 contiguous
elements and 2 `i4`s are packed into a single byte (`2 x 4 = 8`).
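For reference, a minimal scalar sketch of one `K0=8` step is shown below
(names and nibble order are illustrative assumptions, not part of the diff):
```c
#include <stdint.h>

// Illustrative scalar reference for a single output element, not part of this
// change. The 8 `i4` RHS values of one row occupy exactly 4 bytes; after
// unpacking they form two groups of 4 contiguous `i8` values, which is the
// per-lane granularity of `sdot`. The nibble order (even `k` in the low
// nibble) is an assumption matching the de-interleaving in the kernel.
static int32_t s8s4_dot_k8_reference(const int8_t lhs[8],
                                     const uint8_t rhs_packed[4]) {
  int32_t acc = 0;
  for (int k = 0; k < 8; ++k) {
    uint8_t byte = rhs_packed[k / 2];
    uint8_t nibble = (k & 1) ? (uint8_t)(byte >> 4) : (uint8_t)(byte & 0xF);
    // Sign-extend the 4-bit value into [-8, 7].
    int32_t rhs_val = (nibble & 0x8) ? (int32_t)nibble - 16 : (int32_t)nibble;
    acc += (int32_t)lhs[k] * rhs_val;
  }
  return acc;
}
```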
This ukernel departs significantly from the XNNPACK
[qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c](https://github.com/google/XNNPACK/blob/master/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c)
kernel. The XNNPACK version packs the LHS so that no de-interleaving is
required on load. It also uses different shapes (`1x8 * 8x16 --> 1x16` in
XNNPACK vs `4x2 * 2x8 --> 4x8` here). The int4 shift trick, however, is
reused.
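As a standalone sketch of the trick (illustrative helper, not from the diff):
shifting a packed byte left by 4 moves the low nibble to the top, and masking
with `0xF0` keeps the high nibble in place; both results are the original `i4`
value times 16 with the sign preserved, so the kernel divides the `i32`
accumulators by 16 (arithmetic shift right by 4) before storing.
```c
#include <arm_neon.h>

// Illustrative helper, not part of the diff: unpack a vector of packed int4
// pairs into two int8 vectors that hold the original int4 values multiplied
// by 16, sign intact. The accumulators compensate with a right shift by 4
// (vshrq_n_s32 / vsraq_n_s32 in the kernel).
static inline void unpack_int4_times_16(int8x16_t packed,
                                        int8x16_t* low_nibbles_x16,
                                        int8x16_t* high_nibbles_x16) {
  *low_nibbles_x16 = vshlq_n_s8(packed, 4);
  *high_nibbles_x16 = vandq_s8(packed, vmovq_n_s8(0xF0));
}
```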
Microbenchmark results on a Pixel 8 Pro (the new kernels are the last four rows):
```
-----------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------
BM_mmt4d_s8s4s32_tile_1x16x2/real_time 0.690 us 0.689 us 1048575 items_per_second=23.735G/s
BM_mmt4d_s8s4s32_tile_2x16x2/real_time 0.803 us 0.801 us 1048575 items_per_second=40.8248G/s
BM_mmt4d_s8s4s32_tile_4x16x2/real_time 1.68 us 1.68 us 524287 items_per_second=38.9081G/s
BM_mmt4d_s8s4s32_tile_1x8x8_dotprod/real_time 0.640 us 0.638 us 1048575 items_per_second=51.185G/s
BM_mmt4d_s8s4s32_tile_2x8x8_dotprod/real_time 0.369 us 0.368 us 2097151 items_per_second=177.561G/s
BM_mmt4d_s8s4s32_tile_4x8x8_dotprod/real_time 0.540 us 0.539 us 1048575 items_per_second=242.718G/s
BM_mmt4d_s8s4s32_tile_8x8x8_dotprod/real_time 0.934 us 0.932 us 1048575 items_per_second=280.543G/s
```
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
index 8bf713c..e4c7b00 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
@@ -75,3 +75,136 @@
iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod,
iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod,
iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod)
+
+static inline void iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod(
+ void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+ const void* IREE_UK_RESTRICT rhs_panel,
+ const iree_uk_mmt4d_params_t* params, int M0) {
+ IREE_UK_ASSERT(M0 >= 1 && M0 <= 8 && iree_uk_is_po2_u32(M0));
+ const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+ const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+ iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+
+ const int8x16_t vmask = vmovq_n_s8(0xF0);
+ const int8x8_t vzero = vmov_n_s8(0);
+
+ int32x4_t acc[16];
+ for (int i = 0; i < 16; i++) {
+ // We start with zero accumulators and add the value of *out_ptr later.
+ // This is required for the int4 left shift described later.
+ acc[i] = vdupq_n_s32(0);
+ }
+
+ for (int k = 0; k < params->K; ++k) {
+ int8x16_t rhs[4];
+ for (int i = 0; i < 2; i++) {
+ int8x16_t r = vld1q_s8(rhs_ptr);
+ rhs_ptr += 16;
+ // We unpack int4s into individual int8s. To preserve signedness,
+ // int4s are moved to the upper 4-bits of each byte. This has the effect
+ // of multiplying each int4 by 2^4 = 16. To compensate, we divide the
+ // accumulator values by 16 before storing to memory.
+ // This int4 conversion trick is borrowed from the `qd8-f32-qc4w-gemm*`
+ // kernels in https://github.com/google/XNNPACK.
+ rhs[i + 0] = vshlq_n_s8(r, 4);
+ rhs[i + 2] = vandq_s8(r, vmask);
+ }
+
+ if (M0 >= 4) {
+ int8x8_t lhs[8];
+ if (M0 == 8) {
+ int8x16x2_t lhs_uzp_0 = vld2q_s8(lhs_ptr);
+ lhs_ptr += 32;
+ int8x16x2_t lhs_uzp_1 = vld2q_s8(lhs_ptr);
+ lhs_ptr += 32;
+ lhs[0] = vget_low_s8(lhs_uzp_0.val[0]);
+ lhs[1] = vget_low_s8(lhs_uzp_0.val[1]);
+ lhs[2] = vget_high_s8(lhs_uzp_0.val[0]);
+ lhs[3] = vget_high_s8(lhs_uzp_0.val[1]);
+ lhs[4] = vget_low_s8(lhs_uzp_1.val[0]);
+ lhs[5] = vget_high_s8(lhs_uzp_1.val[0]);
+ lhs[6] = vget_low_s8(lhs_uzp_1.val[1]);
+ lhs[7] = vget_high_s8(lhs_uzp_1.val[1]);
+ } else { // M0 = 4.
+ int8x16x2_t lhs_uzp = vld2q_s8(lhs_ptr);
+ lhs_ptr += 32;
+ lhs[0] = vget_low_s8(lhs_uzp.val[0]);
+ lhs[1] = vget_low_s8(lhs_uzp.val[1]);
+ lhs[2] = vget_high_s8(lhs_uzp.val[0]);
+ lhs[3] = vget_high_s8(lhs_uzp.val[1]);
+ }
+ // 4x8 * 8x8 -> 4x8.
+ for (int i = 0; i < 2; i++) {
+ for (int j = 0; j < 2; j++) {
+ for (int k = 0; k < 2; k++) {
+ acc[4 * k + j] = vdotq_lane_s32(acc[4 * k + j], rhs[2 * i + j],
+ lhs[2 * k + i], 0);
+ acc[4 * k + j + 2] = vdotq_lane_s32(
+ acc[4 * k + j + 2], rhs[2 * i + j], lhs[2 * k + i], 1);
+ }
+ }
+ }
+ if (M0 == 4) continue;
+ // 8x8 * 8x8 -> 8x8.
+ for (int i = 0; i < 2; i++) {
+ for (int j = 0; j < 2; j++) {
+ for (int k = 0; k < 2; k++) {
+ acc[8 + 4 * k + j] = vdotq_lane_s32(
+ acc[8 + 4 * k + j], rhs[2 * i + j], lhs[4 + 2 * i + k], 0);
+ acc[10 + 4 * k + j] = vdotq_lane_s32(
+ acc[10 + 4 * k + j], rhs[2 * i + j], lhs[4 + 2 * i + k], 1);
+ }
+ }
+ }
+ } else {
+ int8x8_t lhs[2];
+ if (M0 == 2) {
+ int8x8x2_t lhs_uzp = vld2_s8(lhs_ptr);
+ lhs_ptr += 16;
+ lhs[0] = lhs_uzp.val[0];
+ lhs[1] = lhs_uzp.val[1];
+ } else { // M0 == 1.
+ int8x8_t r = vld1_s8(lhs_ptr);
+ lhs_ptr += 8;
+ int8x8x2_t lhs_uzp = vuzp_s8(r, vzero);
+ lhs[0] = lhs_uzp.val[0];
+ lhs[1] = lhs_uzp.val[1];
+ }
+ // 1x8 * 8x8 -> 1x8.
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ acc[j] = vdotq_lane_s32(acc[j], rhs[i * 2 + j], lhs[i], 0);
+ }
+ }
+ if (M0 == 1) continue;
+ // 2x8 * 8x8 -> 2x8.
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ acc[2 + j] = vdotq_lane_s32(acc[2 + j], rhs[i * 2 + j], lhs[i], 1);
+ }
+ }
+ }
+ }
+
+ if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+ for (int i = 0; i < 2 * M0; i++) {
+ int32x4_t existing_acc = vld1q_s32(out_ptr);
+ acc[i] = vsraq_n_s32(existing_acc, acc[i], 4);
+ vst1q_s32(out_ptr, acc[i]);
+ out_ptr += 4;
+ }
+ } else {
+ for (int i = 0; i < 2 * M0; i++) {
+ acc[i] = vshrq_n_s32(acc[i], 4);
+ vst1q_s32(out_ptr, acc[i]);
+ out_ptr += 4;
+ }
+ }
+}
+
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+ iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index f7d9384..61ad8f9 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -254,11 +254,34 @@
return 0;
}
+static iree_uk_mmt4d_tile_func_t
+iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x8(
+ const iree_uk_mmt4d_params_t* params) {
+#ifdef IREE_UK_BUILD_ARM_64_DOTPROD
+ if (iree_uk_cpu_supports_dotprod(params->cpu_data)) {
+ switch (params->M0) {
+ case 1:
+ return iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod;
+ case 2:
+ return iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod;
+ case 4:
+ return iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod;
+ case 8:
+ return iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod;
+ }
+ }
+#endif
+ return 0;
+}
+
static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32(
const iree_uk_mmt4d_params_t* params) {
if (params->N0 == 16 && params->K0 == 2) {
return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x16x2(params);
}
+ if (params->N0 == 8 && params->K0 == 8) {
+ return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x8(params);
+ }
return 0;
}
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 9c33bd2..3d3e19d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -52,5 +52,9 @@
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x16x2_arm_64)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x16x2_arm_64)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x16x2_arm_64)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
#endif // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index cf790cb..4e302f7 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -144,6 +144,8 @@
"i8mm");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 16, 2,
"");
+ iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8,
+ "dotprod");
#elif defined(IREE_ARCH_X86_64)
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1,
"avx2_fma");
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index f567844..995582f 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -557,6 +557,7 @@
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 16, 2, "");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
#elif defined(IREE_ARCH_X86_64)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 4, 1, ""); // SSE
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "avx2_fma");