mmt4d ukernel: use fewer magic macros to generate the M0 variants of tile functions (#16645)
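Motivation: some of the M0==1 variants need more special-casing anyway
to be truly efficient, so we are headed towards a place where we don't
necessarily use the same generic implementation for all M0 values.
Decoupling the variants is a first step: each tile function is now
instantiated by its own IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0 invocation,
and the combined IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4,
_FOR_M0_1_2_4_8 and _FOR_M0_1_2_4_8_16 macros are removed.
This change also replaces the hand-unrolled FMA steps in the x86_64
AVX-512 f16 kernel with an IREE_UK_UNROLL loop over M0.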
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_base.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_base.c
index c04ff3e..5789b9c 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_base.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_base.c
@@ -73,12 +73,18 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_arm_64,
- iree_uk_mmt4d_tile_f32f32f32_1x8x1_arm_64,
- iree_uk_mmt4d_tile_f32f32f32_2x8x1_arm_64,
- iree_uk_mmt4d_tile_f32f32f32_4x8x1_arm_64,
- iree_uk_mmt4d_tile_f32f32f32_8x8x1_arm_64)
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_arm_64, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f32f32f32_2x8x1_arm_64, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f32f32f32_4x8x1_arm_64, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f32f32f32_8x8x1_arm_64, 8)
// Shared implementation for f16f16f16 and f16f16f32.
// In the f16f16f16 case, intermediate roundings are skipped. This function
@@ -184,19 +190,31 @@
out_tile, lhs_panel, rhs_panel, params, IREE_UK_TYPE_FLOAT_32, M0);
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f32_1x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f32_2x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f32_4x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64)
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_arm_64, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f16f16f32_2x8x1_arm_64, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f16f16f32_4x8x1_arm_64, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64, 8)
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64,
- iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64)
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64, 8)
IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64(
@@ -257,12 +275,18 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64,
- iree_uk_mmt4d_tile_s8s8s32_1x8x1_arm_64,
- iree_uk_mmt4d_tile_s8s8s32_2x8x1_arm_64,
- iree_uk_mmt4d_tile_s8s8s32_4x8x1_arm_64,
- iree_uk_mmt4d_tile_s8s8s32_8x8x1_arm_64)
+ iree_uk_mmt4d_tile_s8s8s32_1x8x1_arm_64, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_s8s8s32_2x8x1_arm_64, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_s8s8s32_4x8x1_arm_64, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64,
+ iree_uk_mmt4d_tile_s8s8s32_8x8x1_arm_64, 8)
// This kernel is an adaptation of the kernel
// `qd8-f32-qc4w-gemm-1x16-minmax-neon-mlal-lane.c` in
@@ -398,8 +422,12 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s4s32_1x16x2_to_4x16x2_arm_64,
- iree_uk_mmt4d_tile_s8s4s32_1x16x2_arm_64,
- iree_uk_mmt4d_tile_s8s4s32_2x16x2_arm_64,
- iree_uk_mmt4d_tile_s8s4s32_4x16x2_arm_64)
+ iree_uk_mmt4d_tile_s8s4s32_1x16x2_arm_64, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_1x16x2_to_4x16x2_arm_64,
+ iree_uk_mmt4d_tile_s8s4s32_2x16x2_arm_64, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_1x16x2_to_4x16x2_arm_64,
+ iree_uk_mmt4d_tile_s8s4s32_4x16x2_arm_64, 4)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_bf16.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_bf16.c
index e66e4ec..91d87b3 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_bf16.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_bf16.c
@@ -140,16 +140,28 @@
out_tile, lhs_panel, rhs_panel, params, IREE_UK_TYPE_BFLOAT_16, M0);
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_to_8x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_2x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_4x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_8x8x4_arm_64_bf16)
+ iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_arm_64_bf16, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_to_8x8x4_arm_64_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_2x8x4_arm_64_bf16, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_to_8x8x4_arm_64_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_4x8x4_arm_64_bf16, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_to_8x8x4_arm_64_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_8x8x4_arm_64_bf16, 8)
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_bf16bf16bf16_1x8x4_to_8x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_1x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_2x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_4x8x4_arm_64_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_8x8x4_arm_64_bf16)
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x8x4_arm_64_bf16, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x8x4_to_8x8x4_arm_64_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_2x8x4_arm_64_bf16, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x8x4_to_8x8x4_arm_64_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_4x8x4_arm_64_bf16, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x8x4_to_8x8x4_arm_64_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_8x8x4_arm_64_bf16, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
index 1686669..2430cda 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
@@ -67,12 +67,18 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x4_to_8x8x4_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s8s32_1x8x4_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod)
+ iree_uk_mmt4d_tile_s8s8s32_1x8x4_arm_64_dotprod, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x4_to_8x8x4_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x4_to_8x8x4_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x4_to_8x8x4_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod, 8)
static inline void iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
@@ -200,9 +206,15 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod,
- iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
+ iree_uk_mmt4d_tile_s8s4s32_1x8x8_arm_64_dotprod, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_2x8x8_arm_64_dotprod, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_4x8x8_arm_64_dotprod, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s4s32_1x8x8_to_8x8x8_arm_64_dotprod,
+ iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
index f14052b..7fa6c5e 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16fml.c
@@ -90,9 +90,15 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64_fp16fml,
- iree_uk_mmt4d_tile_f16f16f32_1x8x1_arm_64_fp16fml,
- iree_uk_mmt4d_tile_f16f16f32_2x8x1_arm_64_fp16fml,
- iree_uk_mmt4d_tile_f16f16f32_4x8x1_arm_64_fp16fml,
- iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml)
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_arm_64_fp16fml, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64_fp16fml,
+ iree_uk_mmt4d_tile_f16f16f32_2x8x1_arm_64_fp16fml, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64_fp16fml,
+ iree_uk_mmt4d_tile_f16f16f32_4x8x1_arm_64_fp16fml, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_arm_64_fp16fml,
+ iree_uk_mmt4d_tile_f16f16f32_8x8x1_arm_64_fp16fml, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fullfp16.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fullfp16.c
index dc3fc20..cface65 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fullfp16.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fullfp16.c
@@ -47,9 +47,15 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64_fullfp16,
- iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64_fullfp16,
- iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64_fullfp16,
- iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64_fullfp16,
- iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fullfp16)
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64_fullfp16, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64_fullfp16,
+ iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64_fullfp16, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64_fullfp16,
+ iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64_fullfp16, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_arm_64_fullfp16,
+ iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fullfp16, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
index 7cd4735..a62e80c 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
@@ -95,9 +95,15 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm,
- iree_uk_mmt4d_tile_s8s8s32_1x8x8_arm_64_i8mm,
- iree_uk_mmt4d_tile_s8s8s32_2x8x8_arm_64_i8mm,
- iree_uk_mmt4d_tile_s8s8s32_4x8x8_arm_64_i8mm,
- iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm)
+ iree_uk_mmt4d_tile_s8s8s32_1x8x8_arm_64_i8mm, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s8s32_2x8x8_arm_64_i8mm, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s8s32_4x8x8_arm_64_i8mm, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm,
+ iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c
index b6424d2..3665304 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c
@@ -39,12 +39,18 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f32f32f32_1x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f32f32f32_2x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f32f32f32_4x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f32f32f32_8x8x1_x86_64_avx2_fma)
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_x86_64_avx2_fma, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f32f32f32_2x8x1_x86_64_avx2_fma, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f32f32f32_4x8x1_x86_64_avx2_fma, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f32f32f32_8x8x1_x86_64_avx2_fma, 8)
// Shared implementation for f16f16f16 and f16f16f32.
// In the f16f16f16 case, intermediate roundings are skipped. This function
@@ -119,19 +125,31 @@
out_tile, lhs_panel, rhs_panel, params, IREE_UK_TYPE_FLOAT_16, M0);
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f32_1x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f32_2x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f32_4x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f32_8x8x1_x86_64_avx2_fma)
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_x86_64_avx2_fma, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f16f16f32_2x8x1_x86_64_avx2_fma, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f16f16f32_4x8x1_x86_64_avx2_fma, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f16f16f32_8x8x1_x86_64_avx2_fma, 8)
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f16_1x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f16_2x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f16_4x8x1_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_f16f16f16_8x8x1_x86_64_avx2_fma)
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_x86_64_avx2_fma, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f16f16f16_2x8x1_x86_64_avx2_fma, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f16f16f16_4x8x1_x86_64_avx2_fma, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x8x1_to_8x8x1_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_f16f16f16_8x8x1_x86_64_avx2_fma, 8)
IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma(
@@ -222,12 +240,18 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma)
+ iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma, 8)
IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s16s16s32_1x8x2_to_8x8x2_x86_64_avx2_fma(
@@ -316,9 +340,15 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s16s16s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s16s16s32_1x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s16s16s32_2x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s16s16s32_4x8x2_x86_64_avx2_fma,
- iree_uk_mmt4d_tile_s16s16s32_8x8x2_x86_64_avx2_fma)
+ iree_uk_mmt4d_tile_s16s16s32_1x8x2_x86_64_avx2_fma, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_s16s16s32_2x8x2_x86_64_avx2_fma, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_s16s16s32_4x8x2_x86_64_avx2_fma, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
+ iree_uk_mmt4d_tile_s16s16s32_8x8x2_x86_64_avx2_fma, 8)
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c
index d9291d6..ed4dcff 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c
@@ -52,13 +52,21 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f32f32f32_1x16x1_to_16x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f32f32f32_1x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f32f32f32_2x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f32f32f32_4x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f32f32f32_8x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f32f32f32_16x16x1_x86_64_avx512_base)
+ iree_uk_mmt4d_tile_f32f32f32_1x16x1_x86_64_avx512_base, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f32f32f32_2x16x1_x86_64_avx512_base, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f32f32f32_4x16x1_x86_64_avx512_base, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f32f32f32_8x16x1_x86_64_avx512_base, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f32f32f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f32f32f32_16x16x1_x86_64_avx512_base, 16)
// Shared implementation for f16f16f16 and f16f16f32.
// In the f16f16f16 case, intermediate roundings are skipped. This function
@@ -101,34 +109,10 @@
__m512 rhs = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)rhs_ptr));
_mm_prefetch((const char*)(rhs_ptr + 128), _MM_HINT_T0);
rhs_ptr += 16;
- // Unrolling needed to avoid 20% perf regression on Clang 15 on AMD Zen4.
-#define IREE_UK_F16F16F32_FMA_STEP(i) \
- acc[i] = _mm512_fmadd_ps(_mm512_cvtph_ps(_mm256_set1_epi16(lhs_ptr[i])), \
- rhs, acc[i])
- do {
- IREE_UK_F16F16F32_FMA_STEP(0);
- if (M0 == 1) continue;
- IREE_UK_F16F16F32_FMA_STEP(1);
- if (M0 == 2) continue;
- IREE_UK_F16F16F32_FMA_STEP(2);
- IREE_UK_F16F16F32_FMA_STEP(3);
- if (M0 == 4) continue;
- IREE_UK_F16F16F32_FMA_STEP(4);
- IREE_UK_F16F16F32_FMA_STEP(5);
- IREE_UK_F16F16F32_FMA_STEP(6);
- IREE_UK_F16F16F32_FMA_STEP(7);
- if (M0 == 8) continue;
- IREE_UK_F16F16F32_FMA_STEP(8);
- IREE_UK_F16F16F32_FMA_STEP(9);
- IREE_UK_F16F16F32_FMA_STEP(10);
- IREE_UK_F16F16F32_FMA_STEP(11);
- IREE_UK_F16F16F32_FMA_STEP(12);
- IREE_UK_F16F16F32_FMA_STEP(13);
- IREE_UK_F16F16F32_FMA_STEP(14);
- IREE_UK_F16F16F32_FMA_STEP(15);
- } while (false);
-#undef IREE_UK_F16F16F32_FMA_STEP
-
+ IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
+ acc[i] = _mm512_fmadd_ps(_mm512_cvtph_ps(_mm256_set1_epi16(lhs_ptr[i])),
+ rhs, acc[i]);
+ }
_mm_prefetch((const char*)(lhs_ptr + 128), _MM_HINT_T0);
lhs_ptr += M0;
}
@@ -166,21 +150,37 @@
out_tile, lhs_panel, rhs_panel, params, IREE_UK_TYPE_FLOAT_16, M0);
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f32_1x16x1_to_16x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f32_1x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f32_2x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f32_4x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f32_8x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f32_16x16x1_x86_64_avx512_base)
+ iree_uk_mmt4d_tile_f16f16f32_1x16x1_x86_64_avx512_base, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f32_2x16x1_x86_64_avx512_base, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f32_4x16x1_x86_64_avx512_base, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f32_8x16x1_x86_64_avx512_base, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f32_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f32_16x16x1_x86_64_avx512_base, 16)
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f16_1x16x1_to_16x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f16_1x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f16_2x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f16_4x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f16_8x16x1_x86_64_avx512_base,
- iree_uk_mmt4d_tile_f16f16f16_16x16x1_x86_64_avx512_base)
+ iree_uk_mmt4d_tile_f16f16f16_1x16x1_x86_64_avx512_base, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f16_2x16x1_x86_64_avx512_base, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f16_4x16x1_x86_64_avx512_base, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f16_8x16x1_x86_64_avx512_base, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_f16f16f16_1x16x1_to_16x16x1_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_f16f16f16_16x16x1_x86_64_avx512_base, 16)
IREE_UK_ATTRIBUTE_ALWAYS_INLINE
static inline void
@@ -304,13 +304,21 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_base)
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_base, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_base, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_base, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_base, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_base, 16)
IREE_UK_ATTRIBUTE_ALWAYS_INLINE
static inline void
@@ -425,10 +433,18 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s16s16s32_1x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s16s16s32_2x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_base,
- iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_base)
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_x86_64_avx512_base, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s16s16s32_2x16x2_x86_64_avx512_base, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_base, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_base, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_base,
+ iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_base, 16)
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c
index 1d3d200..97eae89 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c
@@ -114,18 +114,34 @@
out_tile, lhs_panel, rhs_panel, params, IREE_UK_TYPE_BFLOAT_16, M0);
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_to_16x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_2x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_4x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_8x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16f32_16x16x2_x86_64_avx512_bf16)
+ iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_x86_64_avx512_bf16, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_2x16x2_x86_64_avx512_bf16, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_4x16x2_x86_64_avx512_bf16, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_8x16x2_x86_64_avx512_bf16, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16f32_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16f32_16x16x2_x86_64_avx512_bf16, 16)
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_to_16x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_2x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_4x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_8x16x2_x86_64_avx512_bf16,
- iree_uk_mmt4d_tile_bf16bf16bf16_16x16x2_x86_64_avx512_bf16)
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_x86_64_avx512_bf16, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_2x16x2_x86_64_avx512_bf16, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_4x16x2_x86_64_avx512_bf16, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_8x16x2_x86_64_avx512_bf16, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_bf16bf16bf16_1x16x2_to_16x16x2_x86_64_avx512_bf16,
+ iree_uk_mmt4d_tile_bf16bf16bf16_16x16x2_x86_64_avx512_bf16, 16)
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
index 331d2eb..2828e6b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
@@ -128,13 +128,21 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_vnni)
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_vnni, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_vnni, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_vnni, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_vnni, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_vnni, 16)
static inline void
iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_vnni(
@@ -247,13 +255,21 @@
}
}
-IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s16s16s32_1x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s16s16s32_2x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni,
- iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_x86_64_avx512_vnni, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s16s16s32_2x16x2_x86_64_avx512_vnni, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_vnni, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+ iree_uk_mmt4d_tile_s16s16s32_1x16x2_to_16x16x2_x86_64_avx512_vnni,
+ iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni, 16)
// The idea of this kernel is to split the LHS s16 values into high and low
// 8-bit components to be able to use _mm512_dpbusd_epi32.
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
index 8841728..f76d1c0 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
@@ -135,21 +135,6 @@
GENERIC_FUNC(out_tile, lhs_panel, rhs_panel, params, M0); \
}
-#define IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4(G, F1, F2, F4) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F1, 1) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F2, 2) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F4, 4)
-
-#define IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(G, F1, F2, F4, F8) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F1, 1) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F2, 2) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F4, 4) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F8, 8)
-
-#define IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(G, F1, F2, F4, F8, F16) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8(G, F1, F2, F4, F8) \
- IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(G, F16, 16)
-
// Architecture-specific implementation, or generic fallback returning null.
iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
const iree_uk_mmt4d_params_t* params);