ukernel: unroll the s16u4 VNNI ukernel, and drop the unused N0=16 variant (#16047)
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
index 2912022..5389c82 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
@@ -255,14 +255,6 @@
iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni,
iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
-// This kernel is parametrized in N0, allowing N0==16 and N0==32.
-// Performance on AMD Ryzen 9 7950X3D:
-// - with N0=16: 180 Gop/s
-// - with N0=32: 240 Gop/s
-// So there's a nice reward for going extra large, but that's also a liability
-// for vecmat shapes whose N dimension isn't a multiple of 32. Maybe we can
-// keep both for now.
-//
// The idea of this kernel is to split the LHS s16 values into high and low
// 8-bit components to be able to use _mm512_dpbusd_epi32.
//
@@ -285,33 +277,31 @@
// of the combinations of operands that we have to feed _mm512_dpbusd_epi32,
we manage to find an operand order that accommodates the instruction's
// requirements on signednesses.
-static inline void
-iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
+void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
- const iree_uk_mmt4d_params_t* params, int N0) {
- IREE_UK_ASSERT(N0 >= 16 && N0 <= 32 && iree_uk_is_po2_u32(N0));
+ const iree_uk_mmt4d_params_t* params) {
iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
const iree_uk_int16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_uint8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
- // acc[4 * i] is the actual accumulator.
- // The other acc[4 * i + j] are only used internally in the accumulation loop.
- __m512i acc[8];
+ // Accumulator shape: 1x32xs32, in 2 registers, each 1x16xs32.
+ __m512i acc0, acc1;
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
- for (int i = 0; i < N0 / 16; ++i) {
- acc[4 * i] = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * i));
- }
+ acc0 = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * 0));
+ acc1 = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * 1));
} else {
- for (int i = 0; i < N0 / 16; ++i) {
- acc[4 * i] = _mm512_setzero_si512();
- }
+ acc0 = _mm512_setzero_si512();
+ acc1 = _mm512_setzero_si512();
}
- for (int i = 0; i < N0 / 16; ++i) {
- for (int j = 1; j < 4; ++j) {
- acc[4 * i + j] = _mm512_setzero_si512();
- }
- }
-
+ // Additional internal accumulators - acc{i}{j} will be folded into acc{i} at
+ // the end of the loop.
+ __m512i acc01 = _mm512_setzero_si512();
+ __m512i acc02 = _mm512_setzero_si512();
+ __m512i acc03 = _mm512_setzero_si512();
+ __m512i acc11 = _mm512_setzero_si512();
+ __m512i acc12 = _mm512_setzero_si512();
+ __m512i acc13 = _mm512_setzero_si512();
+ // Shuffle indices.
const __m128i idx_0_mod_4 = _mm_set1_epi32(0x0c080400);
const __m128i idx_1_mod_4 = _mm_set1_epi32(0x0d090501);
const __m128i idx_2_mod_4 = _mm_set1_epi32(0x0e0a0602);
@@ -332,65 +322,44 @@
_mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_2_mod_4));
__m512i lhs_odd_s16_high_s8 =
_mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_3_mod_4));
- // Load 8x16xu4 RHS data.
- __m512i rhs[2];
- for (int i = 0; i < N0 / 16; ++i) {
- rhs[i] = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * i));
- }
- rhs_ptr += N0 * 4;
+ // Load 8x32xu4 RHS data, in 2 registers, each 8x16xu4.
+ __m512i rhs0 = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * 0));
+ __m512i rhs1 = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * 1));
+ rhs_ptr += 128;
// Extract the even/odd u4 lanes.
- __m512i rhs_even_u4[2];
- __m512i rhs_odd_u4[2];
- for (int i = 0; i < N0 / 16; ++i) {
- rhs_even_u4[i] = _mm512_and_si512(mask_0f, rhs[i]);
- rhs_odd_u4[i] = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs[i], 4));
- }
+ __m512i rhs0_even_u4 = _mm512_and_si512(mask_0f, rhs0);
+ __m512i rhs1_even_u4 = _mm512_and_si512(mask_0f, rhs1);
+ __m512i rhs0_odd_u4 = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs0, 4));
+ __m512i rhs1_odd_u4 = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs1, 4));
// Arithmetic. See the comment at the top of this kernel for an explanation.
// _mm512_dpbusd_epi32 takes an unsigned LHS and a signed RHS. The parameter
// order in each call is adapted to that constraint.
- for (int i = 0; i < N0 / 16; ++i) {
- acc[4 * i + 0] = _mm512_dpbusd_epi32(acc[4 * i + 0], lhs_even_s16_low_u8,
- rhs_even_u4[i]);
- acc[4 * i + 1] = _mm512_dpbusd_epi32(acc[4 * i + 1], rhs_even_u4[i],
- lhs_even_s16_high_s8);
- acc[4 * i + 2] = _mm512_dpbusd_epi32(acc[4 * i + 2], lhs_odd_s16_low_u8,
- rhs_odd_u4[i]);
- acc[4 * i + 3] = _mm512_dpbusd_epi32(acc[4 * i + 3], rhs_odd_u4[i],
- lhs_odd_s16_high_s8);
- }
+ acc0 = _mm512_dpbusd_epi32(acc0, lhs_even_s16_low_u8, rhs0_even_u4);
+ acc01 = _mm512_dpbusd_epi32(acc01, rhs0_even_u4, lhs_even_s16_high_s8);
+ acc02 = _mm512_dpbusd_epi32(acc02, lhs_odd_s16_low_u8, rhs0_odd_u4);
+ acc03 = _mm512_dpbusd_epi32(acc03, rhs0_odd_u4, lhs_odd_s16_high_s8);
+ acc1 = _mm512_dpbusd_epi32(acc1, lhs_even_s16_low_u8, rhs1_even_u4);
+ acc11 = _mm512_dpbusd_epi32(acc11, rhs1_even_u4, lhs_even_s16_high_s8);
+ acc12 = _mm512_dpbusd_epi32(acc12, lhs_odd_s16_low_u8, rhs1_odd_u4);
+ acc13 = _mm512_dpbusd_epi32(acc13, rhs1_odd_u4, lhs_odd_s16_high_s8);
}
// The accumulators that contain products against high 8bit parts of s16 LHS
// values need to be left-shifted by 8 bits to account for that.
- for (int i = 0; i < N0 / 16; ++i) {
- acc[4 * i + 1] = _mm512_slli_epi32(acc[4 * i + 1], 8);
- acc[4 * i + 3] = _mm512_slli_epi32(acc[4 * i + 3], 8);
- }
+ acc01 = _mm512_slli_epi32(acc01, 8);
+ acc03 = _mm512_slli_epi32(acc03, 8);
+ acc11 = _mm512_slli_epi32(acc11, 8);
+ acc13 = _mm512_slli_epi32(acc13, 8);
// Add accumulators together.
- for (int i = 0; i < N0 / 16; ++i) {
- for (int j = 1; j <= 3; ++j) {
- acc[4 * i + 0] = _mm512_add_epi32(acc[4 * i + 0], acc[4 * i + j]);
- }
- }
+ acc0 = _mm512_add_epi32(acc0, acc01);
+ acc1 = _mm512_add_epi32(acc1, acc11);
+ acc0 = _mm512_add_epi32(acc0, acc02);
+ acc1 = _mm512_add_epi32(acc1, acc12);
+ acc0 = _mm512_add_epi32(acc0, acc03);
+ acc1 = _mm512_add_epi32(acc1, acc13);
- for (int i = 0; i < N0 / 16; ++i) {
- _mm512_storeu_si512((__m512i*)(out_ptr + 16 * i), acc[4 * i]);
- }
-}
-
-void iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni(
- void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
- const void* IREE_UK_RESTRICT rhs_panel,
- const iree_uk_mmt4d_params_t* params) {
- iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
- out_tile, lhs_panel, rhs_panel, params, 16);
-}
-
-void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
- void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
- const void* IREE_UK_RESTRICT rhs_panel,
- const iree_uk_mmt4d_params_t* params) {
- iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
- out_tile, lhs_panel, rhs_panel, params, 32);
+ // Store.
+ _mm512_storeu_si512((__m512i*)(out_ptr + 16 * 0), acc0);
+ _mm512_storeu_si512((__m512i*)(out_ptr + 16 * 1), acc1);
}
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
index 9c4c17c..6e367d7 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
@@ -374,16 +374,11 @@
}
static iree_uk_mmt4d_tile_func_t
-iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(
+iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1x32x8(
const iree_uk_mmt4d_params_t* params) {
#if defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
if (iree_uk_cpu_supports_avx512_vnni(params->cpu_data)) {
- switch (params->N0) {
- case 16:
- return iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni;
- case 32:
- return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
- }
+ return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
}
#endif
return 0;
@@ -391,8 +386,8 @@
static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(
const iree_uk_mmt4d_params_t* params) {
- if (params->M0 == 1 && params->K0 == 8) {
- return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(params);
+ if (params->M0 == 1 && params->N0 == 32 && params->K0 == 8) {
+ return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1x32x8(params);
}
return 0;
}
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
index b2eb9ef..419cfb6 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
@@ -120,8 +120,6 @@
IREE_UK_MMT4D_TILE_FUNC_DECL(
iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
- iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni)
-IREE_UK_MMT4D_TILE_FUNC_DECL(
iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni)
#endif  // IREE_BUILTINS_UKERNEL_ARCH_X86_64_MMT4D_X86_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 315eb0c..b1b8663 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -171,8 +171,6 @@
"avx512_base");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");
- iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8,
- "avx512_vnni");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8,
"avx512_vnni");
#else // defined(IREE_ARCH_ARM_64)
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 07cb4c9..e81aa42 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -544,7 +544,6 @@
"avx512_base");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8, "avx512_vnni");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8, "avx512_vnni");
#endif // defined(IREE_ARCH_ARM_64)