Add s8s4s32 dotprod microkernel (#16473)

Adds an `i8 * i4 -> i32` microkernel that uses the `+dotprod` ARM CPU
feature. Supports tile sizes `M0xN0xK0`: `1x8x8`, `2x8x8`, `4x8x8`, and
`8x8x8`. We use `K0=8` because `+dotprod` consumes groups of 4 contiguous
byte elements and each byte packs 2 `i4`s, so one 4-byte group covers
`4 x 2 = 8` `i4` elements.
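
For reference, here is a minimal scalar sketch of what a single `1x8x8` tile
computes, assuming the packing implied by the kernel below (an LHS row of 8
contiguous `i8`s; each RHS row is 8 `i4`s packed two per byte, element `2*b`
in the low nibble of byte `b`). The function name is hypothetical and the
nibbles are sign-extended directly rather than using the shift trick the real
kernel applies:

```c
#include <stdint.h>

// Hypothetical scalar reference for one M0xN0xK0 = 1x8x8 tile.
// out[n] accumulates sum over k of lhs[k] * rhs(n, k), for n = 0..7, k = 0..7.
// rhs_packed[n] holds row n as 4 bytes: byte b packs elements k = 2*b (low
// nibble) and k = 2*b + 1 (high nibble) -- exactly the 4-byte granule that one
// `sdot` lane consumes after unpacking.
static void ref_tile_s8s4s32_1x8x8(int32_t out[8], const int8_t lhs[8],
                                   const uint8_t rhs_packed[8][4]) {
  for (int n = 0; n < 8; ++n) {
    int32_t acc = 0;
    for (int b = 0; b < 4; ++b) {
      // Sign-extend each nibble to int8.
      int8_t lo = (int8_t)(rhs_packed[n][b] << 4) >> 4;  // element k = 2*b
      int8_t hi = (int8_t)rhs_packed[n][b] >> 4;         // element k = 2*b + 1
      acc += (int32_t)lhs[2 * b] * lo + (int32_t)lhs[2 * b + 1] * hi;
    }
    out[n] += acc;
  }
}
```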

This ukernel departs significantly from the XNNPACK
[qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c](https://github.com/google/XNNPACK/blob/master/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-neondot.c)
kernel. The XNNPACK version packs the LHS so that no de-interleaving is
required on load, and it uses different shapes (`1x8 * 8x16 --> 1x16`
vs `4x2 * 2x8 --> 4x8`). The int4 shift trick, however, is reused here.
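
The trick, in a nutshell: a signed int4 can be widened to int8 by placing it
in the upper nibble, which multiplies it by 16; the dot products are computed
on these scaled values and the final `i32` accumulator is arithmetically
shifted right by 4 to undo the scaling. A minimal standalone illustration
(not part of the kernel):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
  // One packed byte holding two int4s: low nibble = -3 (0xD), high nibble = 5.
  int8_t packed = (int8_t)((5 << 4) | 0x0D);
  // Low nibble: shift left by 4 so the int4 sign bit becomes the int8 sign
  // bit. The resulting value is -3 * 16.
  int8_t lo = (int8_t)(packed << 4);
  // High nibble: mask off the low nibble; the resulting value is 5 * 16.
  int8_t hi = (int8_t)(packed & 0xF0);
  // Accumulate as a dot product would, then undo the * 16 with one arithmetic
  // shift right by 4 (what vshrq_n_s32 / vsraq_n_s32 do in the kernel below).
  int32_t acc = (int32_t)lo * 2 + (int32_t)hi * 7;  // == (-3 * 2 + 5 * 7) * 16
  printf("%d\n", acc >> 4);                         // prints 29
  return 0;
}
```

The payoff is that each nibble plane is unpacked with a single shift or mask
instead of a full sign-extension, with the `/16` compensation folded into one
shift per output vector at the end.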

Microbenchmark results on a Pixel 8 Pro (the new kernel is the last four rows;
the first three rows are the existing `Mx16x2` s8s4s32 tiles for comparison):
```
-----------------------------------------------------------------------------------------------------------
Benchmark                                                 Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------
BM_mmt4d_s8s4s32_tile_1x16x2/real_time                0.690 us        0.689 us      1048575 items_per_second=23.735G/s
BM_mmt4d_s8s4s32_tile_2x16x2/real_time                0.803 us        0.801 us      1048575 items_per_second=40.8248G/s
BM_mmt4d_s8s4s32_tile_4x16x2/real_time                 1.68 us         1.68 us       524287 items_per_second=38.9081G/s
BM_mmt4d_s8s4s32_tile_1x8x8_dotprod/real_time         0.640 us        0.638 us      1048575 items_per_second=51.185G/s
BM_mmt4d_s8s4s32_tile_2x8x8_dotprod/real_time         0.369 us        0.368 us      2097151 items_per_second=177.561G/s
BM_mmt4d_s8s4s32_tile_4x8x8_dotprod/real_time         0.540 us        0.539 us      1048575 items_per_second=242.718G/s
BM_mmt4d_s8s4s32_tile_8x8x8_dotprod/real_time         0.934 us        0.932 us      1048575 items_per_second=280.543G/s
```
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
index 8bf713c..e4c7b00 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
@@ -75,3 +75,136 @@
     iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod,
     iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod,
     iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod)
+
+static inline void iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params, int M0) {
+  IREE_UK_ASSERT(M0 >= 1 && M0 <= 8 && iree_uk_is_po2_u32(M0));
+  const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+
+  const int8x16_t vmask = vmovq_n_s8(0xF0);
+  const int8x8_t vzero = vmov_n_s8(0);
+
+  int32x4_t acc[16];
+  for (int i = 0; i < 16; i++) {
+    // We start with zero accumulators and add the value of *out_ptr later.
+    // This is required for the int4 left shift described later.
+    acc[i] = vdupq_n_s32(0);
+  }
+
+  for (int k = 0; k < params->K; ++k) {
+    int8x16_t rhs[4];
+    for (int i = 0; i < 2; i++) {
+      int8x16_t r = vld1q_s8(rhs_ptr);
+      rhs_ptr += 16;
+      // We unpack int4s into individual int8s. To preserve signedness,
+      // int4s are moved to the upper 4-bits of each byte. This has the effect
+      // of multiplying each int4 by 2^4 = 16. To compensate, we divide the
+      // accumulator values by 16 before storing to memory.
+      // This int4 conversion trick is borrowed from the `qd8-f32-qc4w-gemm*`
+      // kernels in https://github.com/google/XNNPACK.
+      rhs[i + 0] = vshlq_n_s8(r, 4);
+      rhs[i + 2] = vandq_s8(r, vmask);
+    }
+
+    if (M0 >= 4) {
+      int8x8_t lhs[8];
+      if (M0 == 8) {
+        int8x16x2_t lhs_uzp_0 = vld2q_s8(lhs_ptr);
+        lhs_ptr += 32;
+        int8x16x2_t lhs_uzp_1 = vld2q_s8(lhs_ptr);
+        lhs_ptr += 32;
+        lhs[0] = vget_low_s8(lhs_uzp_0.val[0]);
+        lhs[1] = vget_low_s8(lhs_uzp_0.val[1]);
+        lhs[2] = vget_high_s8(lhs_uzp_0.val[0]);
+        lhs[3] = vget_high_s8(lhs_uzp_0.val[1]);
+        lhs[4] = vget_low_s8(lhs_uzp_1.val[0]);
+        lhs[5] = vget_high_s8(lhs_uzp_1.val[0]);
+        lhs[6] = vget_low_s8(lhs_uzp_1.val[1]);
+        lhs[7] = vget_high_s8(lhs_uzp_1.val[1]);
+      } else {  // M0 = 4.
+        int8x16x2_t lhs_uzp = vld2q_s8(lhs_ptr);
+        lhs_ptr += 32;
+        lhs[0] = vget_low_s8(lhs_uzp.val[0]);
+        lhs[1] = vget_low_s8(lhs_uzp.val[1]);
+        lhs[2] = vget_high_s8(lhs_uzp.val[0]);
+        lhs[3] = vget_high_s8(lhs_uzp.val[1]);
+      }
+      // 4x8 * 8x8 -> 4x8.
+      for (int i = 0; i < 2; i++) {
+        for (int j = 0; j < 2; j++) {
+          for (int k = 0; k < 2; k++) {
+            acc[4 * k + j] = vdotq_lane_s32(acc[4 * k + j], rhs[2 * i + j],
+                                            lhs[2 * k + i], 0);
+            acc[4 * k + j + 2] = vdotq_lane_s32(
+                acc[4 * k + j + 2], rhs[2 * i + j], lhs[2 * k + i], 1);
+          }
+        }
+      }
+      if (M0 == 4) continue;
+      // 8x8 * 8x8 -> 8x8.
+      for (int i = 0; i < 2; i++) {
+        for (int j = 0; j < 2; j++) {
+          for (int k = 0; k < 2; k++) {
+            acc[8 + 4 * k + j] = vdotq_lane_s32(
+                acc[8 + 4 * k + j], rhs[2 * i + j], lhs[4 + 2 * i + k], 0);
+            acc[10 + 4 * k + j] = vdotq_lane_s32(
+                acc[10 + 4 * k + j], rhs[2 * i + j], lhs[4 + 2 * i + k], 1);
+          }
+        }
+      }
+    } else {
+      int8x8_t lhs[2];
+      if (M0 == 2) {
+        int8x8x2_t lhs_uzp = vld2_s8(lhs_ptr);
+        lhs_ptr += 16;
+        lhs[0] = lhs_uzp.val[0];
+        lhs[1] = lhs_uzp.val[1];
+      } else {  // M0 == 1.
+        int8x8_t r = vld1_s8(lhs_ptr);
+        lhs_ptr += 8;
+        int8x8x2_t lhs_uzp = vuzp_s8(r, vzero);
+        lhs[0] = lhs_uzp.val[0];
+        lhs[1] = lhs_uzp.val[1];
+      }
+      // 1x8 * 8x8 -> 1x8.
+      for (int i = 0; i < 2; ++i) {
+        for (int j = 0; j < 2; ++j) {
+          acc[j] = vdotq_lane_s32(acc[j], rhs[i * 2 + j], lhs[i], 0);
+        }
+      }
+      if (M0 == 1) continue;
+      // 2x8 * 8x8 -> 2x8.
+      for (int i = 0; i < 2; ++i) {
+        for (int j = 0; j < 2; ++j) {
+          acc[2 + j] = vdotq_lane_s32(acc[2 + j], rhs[i * 2 + j], lhs[i], 1);
+        }
+      }
+    }
+  }
+
+  if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+    for (int i = 0; i < 2 * M0; i++) {
+      int32x4_t existing_acc = vld1q_s32(out_ptr);
+      acc[i] = vsraq_n_s32(existing_acc, acc[i], 4);
+      vst1q_s32(out_ptr, acc[i]);
+      out_ptr += 4;
+    }
+  } else {
+    for (int i = 0; i < 2 * M0; i++) {
+      acc[i] = vshrq_n_s32(acc[i], 4);
+      vst1q_s32(out_ptr, acc[i]);
+      out_ptr += 4;
+    }
+  }
+}
+
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+    iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod,
+    iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod,
+    iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod,
+    iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod,
+    iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index f7d9384..61ad8f9 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -254,11 +254,34 @@
   return 0;
 }
 
+static iree_uk_mmt4d_tile_func_t
+iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x8(
+    const iree_uk_mmt4d_params_t* params) {
+#ifdef IREE_UK_BUILD_ARM_64_DOTPROD
+  if (iree_uk_cpu_supports_dotprod(params->cpu_data)) {
+    switch (params->M0) {
+      case 1:
+        return iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod;
+      case 2:
+        return iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod;
+      case 4:
+        return iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod;
+      case 8:
+        return iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod;
+    }
+  }
+#endif
+  return 0;
+}
+
 static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32(
     const iree_uk_mmt4d_params_t* params) {
   if (params->N0 == 16 && params->K0 == 2) {
     return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x16x2(params);
   }
+  if (params->N0 == 8 && params->K0 == 8) {
+    return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x8(params);
+  }
   return 0;
 }
 
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 9c33bd2..3d3e19d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -52,5 +52,9 @@
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x16x2_arm_64)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x16x2_arm_64)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x16x2_arm_64)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
 
 #endif  // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index cf790cb..4e302f7 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -144,6 +144,8 @@
                                    "i8mm");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 16, 2,
                                    "");
+  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8,
+                                   "dotprod");
 #elif defined(IREE_ARCH_X86_64)
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1,
                                    "avx2_fma");
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index f567844..995582f 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -557,6 +557,7 @@
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 16, 2, "");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
 #elif defined(IREE_ARCH_X86_64)
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 4, 1, "");  // SSE
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "avx2_fma");