[CPU] Add s8s4s32 i8mm ukernel (#16678)

Adds ukernel for `i8 * i4 -> i32` using the `smmla` instruction
available in +i8mm. Uses 2 separate functions: one for M0=1 and another
for M0=2, M0=4, M0=8.

```
-----------------------------------------------------------------------------------------------------------
Benchmark                                                 Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------
BM_mmt4d_s8s4s32_tile_1x16x2/real_time                0.690 us        0.689 us      1048575 items_per_second=23.735G/s
BM_mmt4d_s8s4s32_tile_2x16x2/real_time                0.803 us        0.801 us      1048575 items_per_second=40.8248G/s
BM_mmt4d_s8s4s32_tile_4x16x2/real_time                 1.68 us         1.68 us       524287 items_per_second=38.9081G/s
BM_mmt4d_s8s4s32_tile_1x8x8_dotprod/real_time      0.640 us        0.638 us      1048575 items_per_second=51.185G/s
BM_mmt4d_s8s4s32_tile_2x8x8_dotprod/real_time      0.369 us        0.368 us      2097151 items_per_second=177.561G/s
BM_mmt4d_s8s4s32_tile_4x8x8_dotprod/real_time      0.540 us        0.539 us      1048575 items_per_second=242.718G/s
BM_mmt4d_s8s4s32_tile_8x8x8_dotprod/real_time      0.934 us        0.932 us      1048575 items_per_second=280.543G/s
BM_mmt4d_s8s4s32_tile_1x8x16_i8mm/real_time      0.451 us        0.450 us      2097151 items_per_second=145.412G/s
BM_mmt4d_s8s4s32_tile_2x8x16_i8mm/real_time      0.509 us        0.509 us      1048575 items_per_second=257.577G/s
BM_mmt4d_s8s4s32_tile_4x8x16_i8mm/real_time      0.806 us        0.805 us      1048575 items_per_second=325.353G/s
BM_mmt4d_s8s4s32_tile_8x8x16_i8mm/real_time       1.76 us         1.76 us       524287 items_per_second=297.277G/s
```
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 61ad8f9..6c71c4d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -274,6 +274,26 @@
   return 0;
 }
 
+static iree_uk_mmt4d_tile_func_t
+iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x16(
+    const iree_uk_mmt4d_params_t* params) {
+#ifdef IREE_UK_BUILD_ARM_64_I8MM
+  if (iree_uk_cpu_supports_i8mm(params->cpu_data)) {
+    switch (params->M0) {
+      case 1:
+        return iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm;
+      case 2:
+        return iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm;
+      case 4:
+        return iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm;
+      case 8:
+        return iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm;
+    }
+  }
+#endif
+  return 0;
+}
+
 static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32(
     const iree_uk_mmt4d_params_t* params) {
   if (params->N0 == 16 && params->K0 == 2) {
@@ -282,6 +302,9 @@
   if (params->N0 == 8 && params->K0 == 8) {
     return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x8(params);
   }
+  if (params->N0 == 8 && params->K0 == 16) {
+    return iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x16(params);
+  }
   return 0;
 }
 
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
index a62e80c..218612a 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
@@ -107,3 +107,159 @@
 IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
     iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm,
     iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm, 8)
+
+// In the s8s4s32 kernels below, we unpack int4s into individual int8s.
+// To preserve signedness, int4s are moved to the upper 4-bits of each byte.
+// This has the effect of multiplying each int4 by 2^4 = 16. To compensate,
+// we divide the accumulator values by 16 before storing to memory.
+// This int4 conversion trick is borrowed from the `qd8-f32-qc4w-gemm*`
+// kernels in https://github.com/google/XNNPACK.
+
+IREE_UK_ATTRIBUTE_ALWAYS_INLINE inline void
+iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params) {
+  IREE_UK_ASSERT(!(params->K0 % 16));
+  const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+
+  const int8x16_t vmask = vmovq_n_s8(0xF0);
+  const int8x8_t vzero = vmov_n_s8(0);
+
+  int32x4_t acc[4];
+  IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+    // We start with zero accumulators and add the value of *out_ptr later.
+    // This is required for the int4 left shift described above.
+    acc[i] = vdupq_n_s32(0);
+  }
+
+  for (int k = 0; k < params->K; ++k) {
+    int8x16_t rhs[2][4];
+    IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+      int8x16_t r = vld1q_s8(rhs_ptr + 16 * i);
+      rhs[0][i] = vshlq_n_s8(r, 4);
+      rhs[1][i] = vandq_s8(r, vmask);
+    }
+    rhs_ptr += 64;
+
+    int8x16_t lhs[2];
+    int8x8x2_t lhs_uzp = vld2_s8(lhs_ptr);
+    lhs_ptr += 16;
+    lhs[0] = vcombine_s8(lhs_uzp.val[0], vzero);
+    lhs[1] = vcombine_s8(lhs_uzp.val[1], vzero);
+
+    IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+      acc[i] = vmmlaq_s32(acc[i], lhs[0], rhs[0][i]);
+      acc[i] = vmmlaq_s32(acc[i], lhs[1], rhs[1][i]);
+    }
+  }
+
+  IREE_UK_UNROLL for (int j = 0; j < 2; j++) {
+    acc[2 * j + 0] = vshrq_n_s32(acc[2 * j + 0], 4);
+    acc[2 * j + 1] = vshrq_n_s32(acc[2 * j + 1], 4);
+
+    int32x4_t acc_1x4_0 =
+        iree_uk_neon_uzp1_s32_as_s64(acc[2 * j + 0], acc[2 * j + 1]);
+    if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+      int32x4_t existing_acc = vld1q_s32(out_ptr + 4 * j);
+      acc_1x4_0 = vaddq_s32(acc_1x4_0, existing_acc);
+    }
+    vst1q_s32(out_ptr + 4 * j, acc_1x4_0);
+  }
+}
+
+IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
+iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params, int M0) {
+  IREE_UK_ASSERT(M0 >= 2 && M0 <= 8 && iree_uk_is_po2_u32(M0));
+  IREE_UK_ASSERT(!(params->K0 % 16));
+  const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+
+  const int8x16_t vmask = vmovq_n_s8(0xF0);
+  const int mtiles = M0 / 2;
+
+  int32x4_t acc[4][4];
+  IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+    IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
+      // We start with zero accumulators and add the value of *out_ptr later.
+      // This is required for the int4 left shift described above.
+      acc[i][j] = vdupq_n_s32(0);
+    }
+  }
+
+  for (int k = 0; k < params->K; ++k) {
+    int8x16_t rhs[2][4];
+    IREE_UK_UNROLL for (int i = 0; i < 4; i++) {
+      int8x16_t r = vld1q_s8(rhs_ptr + 16 * i);
+      rhs[0][i] = vshlq_n_s8(r, 4);
+      rhs[1][i] = vandq_s8(r, vmask);
+    }
+    rhs_ptr += 64;
+
+    int8x16_t lhs[2][4];
+    if (M0 == 2) {
+      int8x8x2_t lhs_uzp[2];
+      IREE_UK_UNROLL for (int i = 0; i < 2; i++) {
+        lhs_uzp[i] = vld2_s8(lhs_ptr + 16 * i);
+      }
+      lhs[0][0] = vcombine_s8(lhs_uzp[0].val[0], lhs_uzp[1].val[0]);
+      lhs[1][0] = vcombine_s8(lhs_uzp[0].val[1], lhs_uzp[1].val[1]);
+    } else {
+      IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+        int8x8x2_t lhs_0 = vld2_s8(lhs_ptr + 16 * 2 * i);
+        int8x8x2_t lhs_1 = vld2_s8(lhs_ptr + 16 * (2 * i + 1));
+        lhs[0][i] = vcombine_s8(lhs_0.val[0], lhs_1.val[0]);
+        lhs[1][i] = vcombine_s8(lhs_0.val[1], lhs_1.val[1]);
+      }
+    }
+    lhs_ptr += 32 * mtiles;
+
+    IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+      IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
+        IREE_UK_UNROLL for (int m = 0; m < 2; m++) {
+          acc[i][j] = vmmlaq_s32(acc[i][j], lhs[m][i], rhs[m][j]);
+        }
+      }
+    }
+  }
+
+  // Swizzle accumulator 2x2 register tiles back to row-major and store.
+  IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
+    IREE_UK_UNROLL for (int j = 0; j < 2; j++) {
+      acc[i][2 * j + 0] = vshrq_n_s32(acc[i][2 * j + 0], 4);
+      acc[i][2 * j + 1] = vshrq_n_s32(acc[i][2 * j + 1], 4);
+
+      int32x4_t acc_1x4_0 =
+          iree_uk_neon_uzp1_s32_as_s64(acc[i][2 * j + 0], acc[i][2 * j + 1]);
+      if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+        int32x4_t existing_acc = vld1q_s32(out_ptr + 8 * (2 * i + 0) + 4 * j);
+        acc_1x4_0 = vaddq_s32(acc_1x4_0, existing_acc);
+      }
+      vst1q_s32(out_ptr + 8 * (2 * i + 0) + 4 * j, acc_1x4_0);
+
+      int32x4_t acc_1x4_1 =
+          iree_uk_neon_uzp2_s32_as_s64(acc[i][2 * j + 0], acc[i][2 * j + 1]);
+      if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+        int32x4_t existing_acc = vld1q_s32(out_ptr + 8 * (2 * i + 1) + 4 * j);
+        acc_1x4_1 = vaddq_s32(acc_1x4_1, existing_acc);
+      }
+      vst1q_s32(out_ptr + 8 * (2 * i + 1) + 4 * j, acc_1x4_1);
+    }
+  }
+}
+
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+    iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+    iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
+    iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
index 3d3e19d..8ecc43b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h
@@ -56,5 +56,9 @@
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod)
 IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm)
+IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm)
 
 #endif  // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index bfb932e..a3c0e9c 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -149,6 +149,8 @@
                                    "");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8,
                                    "dotprod");
+  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16,
+                                   "i8mm");
 #elif defined(IREE_ARCH_X86_64)
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1,
                                    "avx2_fma");
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index b7e9593..517f0ae 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -569,8 +569,9 @@
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16, 8, 8, 4, "bf16");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 8, 8, 4, "bf16");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
-  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16, "i8mm");
 
 #elif defined(IREE_ARCH_X86_64)