ukernel: unroll the s16u4 VNNI ukernel, and drop the unused N0=16 variant (#16047)

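The unrolled kernel computes the same arithmetic as before: each s16 LHS value is split into an unsigned low byte and a signed high byte so that both halves can be fed to _mm512_dpbusd_epi32, and the partial sums against the high bytes are left-shifted by 8 before being folded in. As a reference only (not part of the patch; ref_dot_s16u4 is a hypothetical helper name), a minimal scalar sketch of one K0=8 step for a single output lane:

    #include <stdint.h>

    // Illustrative scalar equivalent of one K0=8 step of the s16 x u4
    // product feeding a single s32 accumulator.
    static inline int32_t ref_dot_s16u4(const int16_t lhs[8],
                                        const uint8_t rhs_u4[8],
                                        int32_t acc) {
      for (int k = 0; k < 8; ++k) {
        // Split the s16 LHS value so that lhs[k] == 256 * hi_s8 + lo_u8,
        // with an unsigned low byte and a signed high byte.
        uint8_t lo_u8 = (uint8_t)(lhs[k] & 0xFF);
        int8_t hi_s8 = (int8_t)((uint16_t)lhs[k] >> 8);
        // Low-byte products accumulate directly; high-byte products carry an
        // extra factor of 256, which the SIMD kernel applies once after the
        // loop with _mm512_slli_epi32(acc, 8).
        acc += (int32_t)lo_u8 * rhs_u4[k];
        acc += (int32_t)hi_s8 * rhs_u4[k] * 256;
      }
      return acc;
    }

Because the u4 RHS values are nonnegative, they are valid as either the unsigned or the signed operand of _mm512_dpbusd_epi32; the kernel uses that freedom to satisfy the instruction's unsigned/signed operand requirement for both halves of the split, which is why the operand order flips between the low-byte and high-byte calls.
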
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
index 2912022..5389c82 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
@@ -255,14 +255,6 @@
     iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni,
     iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
 
-// This kernel is parametrized in N0, allowing N0==16 and N0==32.
-// Performance on AMD Ryzen 9 7950X3D:
-//   - with N0=16:  180 Gop/s
-//   - with N0=32:  240 Gop/s
-// So there's a nice reward for going extra large, but that's also a liability
-// for vecmat shapes whose N dimension isn't a multiple of 32. Maybe we can
-// keep both for now.
-//
 // The idea of this kernel is to split the LHS s16 values into high and low
 // 8-bit components to be able to use _mm512_dpbusd_epi32.
 //
@@ -285,33 +277,31 @@
 // of the combinations of operands that we have to feed _mm512_dpbusd_epi32,
 // we manage to find an operand order that accommodates the instruction's
 // requirements on signednesses.
-static inline void
-iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
+void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
     void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
     const void* IREE_UK_RESTRICT rhs_panel,
-    const iree_uk_mmt4d_params_t* params, int N0) {
-  IREE_UK_ASSERT(N0 >= 16 && N0 <= 32 && iree_uk_is_po2_u32(N0));
+    const iree_uk_mmt4d_params_t* params) {
   iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
   const iree_uk_int16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
   const iree_uk_uint8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
-  // acc[4 * i] is the actual accumulator.
-  // The other acc[4 * i + j] are only used internally in the accumulation loop.
-  __m512i acc[8];
+  // Accumulator shape: 1x32xs32, in 2 registers, each 1x16xs32.
+  __m512i acc0, acc1;
   if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
-    for (int i = 0; i < N0 / 16; ++i) {
-      acc[4 * i] = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * i));
-    }
+    acc0 = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * 0));
+    acc1 = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * 1));
   } else {
-    for (int i = 0; i < N0 / 16; ++i) {
-      acc[4 * i] = _mm512_setzero_si512();
-    }
+    acc0 = _mm512_setzero_si512();
+    acc1 = _mm512_setzero_si512();
   }
-  for (int i = 0; i < N0 / 16; ++i) {
-    for (int j = 1; j < 4; ++j) {
-      acc[4 * i + j] = _mm512_setzero_si512();
-    }
-  }
-
+  // Additional internal accumulators - acc{i}{j} will be folded into acc{i} at
+  // the end of the loop.
+  __m512i acc01 = _mm512_setzero_si512();
+  __m512i acc02 = _mm512_setzero_si512();
+  __m512i acc03 = _mm512_setzero_si512();
+  __m512i acc11 = _mm512_setzero_si512();
+  __m512i acc12 = _mm512_setzero_si512();
+  __m512i acc13 = _mm512_setzero_si512();
+  // Shuffle indices.
   const __m128i idx_0_mod_4 = _mm_set1_epi32(0x0c080400);
   const __m128i idx_1_mod_4 = _mm_set1_epi32(0x0d090501);
   const __m128i idx_2_mod_4 = _mm_set1_epi32(0x0e0a0602);
@@ -332,65 +322,44 @@
         _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_2_mod_4));
     __m512i lhs_odd_s16_high_s8 =
         _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_3_mod_4));
-    // Load 8x16xu4 RHS data.
-    __m512i rhs[2];
-    for (int i = 0; i < N0 / 16; ++i) {
-      rhs[i] = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * i));
-    }
-    rhs_ptr += N0 * 4;
+    // Load 8x32xu4 RHS data, in 2 registers, each 8x16xu4.
+    __m512i rhs0 = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * 0));
+    __m512i rhs1 = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * 1));
+    rhs_ptr += 128;
     // Extract the even/odd u4 lanes.
-    __m512i rhs_even_u4[2];
-    __m512i rhs_odd_u4[2];
-    for (int i = 0; i < N0 / 16; ++i) {
-      rhs_even_u4[i] = _mm512_and_si512(mask_0f, rhs[i]);
-      rhs_odd_u4[i] = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs[i], 4));
-    }
+    __m512i rhs0_even_u4 = _mm512_and_si512(mask_0f, rhs0);
+    __m512i rhs1_even_u4 = _mm512_and_si512(mask_0f, rhs1);
+    __m512i rhs0_odd_u4 = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs0, 4));
+    __m512i rhs1_odd_u4 = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs1, 4));
     // Arithmetic. See the comment at the top of this kernel for an explanation.
     // _mm512_dpbusd_epi32 takes an unsigned LHS and a signed RHS. The parameter
     // order in each call is adapted to that constraint.
-    for (int i = 0; i < N0 / 16; ++i) {
-      acc[4 * i + 0] = _mm512_dpbusd_epi32(acc[4 * i + 0], lhs_even_s16_low_u8,
-                                           rhs_even_u4[i]);
-      acc[4 * i + 1] = _mm512_dpbusd_epi32(acc[4 * i + 1], rhs_even_u4[i],
-                                           lhs_even_s16_high_s8);
-      acc[4 * i + 2] = _mm512_dpbusd_epi32(acc[4 * i + 2], lhs_odd_s16_low_u8,
-                                           rhs_odd_u4[i]);
-      acc[4 * i + 3] = _mm512_dpbusd_epi32(acc[4 * i + 3], rhs_odd_u4[i],
-                                           lhs_odd_s16_high_s8);
-    }
+    acc0 = _mm512_dpbusd_epi32(acc0, lhs_even_s16_low_u8, rhs0_even_u4);
+    acc01 = _mm512_dpbusd_epi32(acc01, rhs0_even_u4, lhs_even_s16_high_s8);
+    acc02 = _mm512_dpbusd_epi32(acc02, lhs_odd_s16_low_u8, rhs0_odd_u4);
+    acc03 = _mm512_dpbusd_epi32(acc03, rhs0_odd_u4, lhs_odd_s16_high_s8);
+    acc1 = _mm512_dpbusd_epi32(acc1, lhs_even_s16_low_u8, rhs1_even_u4);
+    acc11 = _mm512_dpbusd_epi32(acc11, rhs1_even_u4, lhs_even_s16_high_s8);
+    acc12 = _mm512_dpbusd_epi32(acc12, lhs_odd_s16_low_u8, rhs1_odd_u4);
+    acc13 = _mm512_dpbusd_epi32(acc13, rhs1_odd_u4, lhs_odd_s16_high_s8);
   }
 
   // The accumulators that contain products against high 8bit parts of s16 LHS
   // values need to be left-shifted by 8 bits to account for that.
-  for (int i = 0; i < N0 / 16; ++i) {
-    acc[4 * i + 1] = _mm512_slli_epi32(acc[4 * i + 1], 8);
-    acc[4 * i + 3] = _mm512_slli_epi32(acc[4 * i + 3], 8);
-  }
+  acc01 = _mm512_slli_epi32(acc01, 8);
+  acc03 = _mm512_slli_epi32(acc03, 8);
+  acc11 = _mm512_slli_epi32(acc11, 8);
+  acc13 = _mm512_slli_epi32(acc13, 8);
 
   // Add accumulators together.
-  for (int i = 0; i < N0 / 16; ++i) {
-    for (int j = 1; j <= 3; ++j) {
-      acc[4 * i + 0] = _mm512_add_epi32(acc[4 * i + 0], acc[4 * i + j]);
-    }
-  }
+  acc0 = _mm512_add_epi32(acc0, acc01);
+  acc1 = _mm512_add_epi32(acc1, acc11);
+  acc0 = _mm512_add_epi32(acc0, acc02);
+  acc1 = _mm512_add_epi32(acc1, acc12);
+  acc0 = _mm512_add_epi32(acc0, acc03);
+  acc1 = _mm512_add_epi32(acc1, acc13);
 
-  for (int i = 0; i < N0 / 16; ++i) {
-    _mm512_storeu_si512((__m512i*)(out_ptr + 16 * i), acc[4 * i]);
-  }
-}
-
-void iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni(
-    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
-    const void* IREE_UK_RESTRICT rhs_panel,
-    const iree_uk_mmt4d_params_t* params) {
-  iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
-      out_tile, lhs_panel, rhs_panel, params, 16);
-}
-
-void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
-    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
-    const void* IREE_UK_RESTRICT rhs_panel,
-    const iree_uk_mmt4d_params_t* params) {
-  iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
-      out_tile, lhs_panel, rhs_panel, params, 32);
+  // Store.
+  _mm512_storeu_si512((__m512i*)(out_ptr + 16 * 0), acc0);
+  _mm512_storeu_si512((__m512i*)(out_ptr + 16 * 1), acc1);
 }
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
index 9c4c17c..6e367d7 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
@@ -374,16 +374,11 @@
 }
 
 static iree_uk_mmt4d_tile_func_t
-iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(
+iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1x32x8(
     const iree_uk_mmt4d_params_t* params) {
 #if defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
   if (iree_uk_cpu_supports_avx512_vnni(params->cpu_data)) {
-    switch (params->N0) {
-      case 16:
-        return iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni;
-      case 32:
-        return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
-    }
+    return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
   }
 #endif
   return 0;
@@ -391,8 +386,8 @@
 
 static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(
     const iree_uk_mmt4d_params_t* params) {
-  if (params->M0 == 1 && params->K0 == 8) {
-    return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(params);
+  if (params->M0 == 1 && params->N0 == 32 && params->K0 == 8) {
+    return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1x32x8(params);
   }
   return 0;
 }
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
index b2eb9ef..419cfb6 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
@@ -120,8 +120,6 @@
 IREE_UK_MMT4D_TILE_FUNC_DECL(
     iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
 IREE_UK_MMT4D_TILE_FUNC_DECL(
-    iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni)
-IREE_UK_MMT4D_TILE_FUNC_DECL(
     iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni)
 
 #endif  // IREE_BUILTINS_UKERNEL_ARCH_X86_64_MMT4D_X86_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 315eb0c..b1b8663 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -171,8 +171,6 @@
                                    "avx512_base");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
                                    "avx512_vnni");
-  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8,
-                                   "avx512_vnni");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8,
                                    "avx512_vnni");
 #else   // defined(IREE_ARCH_ARM_64)
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 07cb4c9..e81aa42 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -544,7 +544,6 @@
                      "avx512_base");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
                      "avx512_vnni");
-  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8, "avx512_vnni");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8, "avx512_vnni");
 #endif  // defined(IREE_ARCH_ARM_64)