Fix enablement of mmt4d ukernel test cases based on ISA code paths built (#16637)
This fixes `build_test_all_windows` CI and adds CI configuration to
trigger it on all ukernel code changes going forward.
There is an unresolved mystery as to what exactly went wrong on
`build_test_all_windows` CI job, as its MSVC compiler failed the
`check_cxx_compile_flags` for `/arch:AVX2` while the exact same MSVC
version succeeded in the `build_test_runtime_windows` CI job.
But regardless, there was something not well thought out in how I did
the testcase enablement. I had added a global constant bool indicating
whether we had linked architecture-specific code, but that was only
per-architecture, not accounting for the fact that, depending on
`check_cxx_compile_flags`, some sub-architecture code paths could be
individually disabled.
This new PR redoes that: the global constant bools are dropped, and
instead, the problem is tackled differently in `mmt4d_test` vs
`mmt4d_benchmark`:
* In `mmt4d_test`, as we really want to avoid testcases silently testing
nothing, we keep running *without fallback*, but we now condition
testcases on `IREE_UK_BUILD_*` cmake-defined variables.
* In `mmt4d_benchmark`, we just enable the fallback, so in the worst
case if a code path is disabled, the outcome is a poor benchmark result.
`mmt4d_benchmark` doesn't need to catch enablement, that's already done
by `mmt4d_test`.
diff --git a/build_tools/github_actions/configure_ci.py b/build_tools/github_actions/configure_ci.py
index 9959b15..215b679 100755
--- a/build_tools/github_actions/configure_ci.py
+++ b/build_tools/github_actions/configure_ci.py
@@ -138,7 +138,10 @@
# The file paths should be specified using Unix shell-style wildcards.
PRESUBMIT_TOUCH_ONLY_JOBS = [
("build_test_all_macos_arm64", ["runtime/src/iree/hal/drivers/metal/*"]),
- ("build_test_all_windows", ["*win32*", "*windows*", "*msvc*"]),
+ (
+ "build_test_all_windows",
+ ["*win32*", "*windows*", "*msvc*", "runtime/src/iree/builtins/ukernel/*"],
+ ),
]
# Default presets enabled in CI.
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
index 198e8b0..61ad8f9 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c
@@ -306,5 +306,3 @@
return 0;
}
}
-
-const bool iree_uk_mmt4d_linked_arch_code = true;
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
index 90d96bc..153b96e 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
@@ -415,5 +415,3 @@
return 0;
}
}
-
-const bool iree_uk_mmt4d_linked_arch_code = true;
diff --git a/runtime/src/iree/builtins/ukernel/fallback.c b/runtime/src/iree/builtins/ukernel/fallback.c
index 6eb2ad5..eaee5b2 100644
--- a/runtime/src/iree/builtins/ukernel/fallback.c
+++ b/runtime/src/iree/builtins/ukernel/fallback.c
@@ -14,8 +14,6 @@
return 0;
}
-const bool iree_uk_mmt4d_linked_arch_code = false;
-
iree_uk_pack_tile_func_t iree_uk_pack_select_tile_func_arch(
const iree_uk_pack_params_t* params) {
return 0;
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
index 450e44d..8841728 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h
@@ -154,9 +154,6 @@
iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
const iree_uk_mmt4d_params_t* params);
-// Indicator of architecture-specific implementation.
-extern const bool iree_uk_mmt4d_linked_arch_code;
-
// Generic fallback.
iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
const iree_uk_mmt4d_params_t* params);
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 6dcd7e6..bfb932e 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -9,6 +9,8 @@
#include "iree/base/api.h"
#include "iree/base/internal/flags.h"
#include "iree/builtins/ukernel/api.h"
+#include "iree/builtins/ukernel/exported_bits.h"
+#include "iree/builtins/ukernel/mmt4d.h"
#include "iree/builtins/ukernel/mmt4d_internal.h"
#include "iree/builtins/ukernel/tools/benchmark.h"
#include "iree/builtins/ukernel/tools/util.h"
@@ -95,7 +97,8 @@
snprintf(name, sizeof name, "mmt4d_%s_tile_%dx%dx%d%s", type_str, M0, N0, K0,
code_path_suffix);
iree_uk_mmt4d_params_t params = {
- .flags = flags | IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS,
+ .flags = flags | IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
+ IREE_UK_FLAG_MMT4D_ALLOW_GENERIC_FALLBACK_TILE_FUNCTION,
.M0 = M0,
.N0 = N0,
.K0 = K0};
@@ -106,15 +109,6 @@
static void iree_uk_benchmark_register_mmt4d(iree_uk_uint32_t flags, int M0,
int N0, int K0,
const char* cpu_features) {
- // For non-fallback benchmarks, i.e. benchmarks for architecture-specific
- // cases, stop here if we haven't linked architecture-specific code, which is
- // the case in Bazel builds.
- if (!(flags & IREE_UK_FLAG_MMT4D_ALLOW_GENERIC_FALLBACK_TILE_FUNCTION)) {
- if (!iree_uk_mmt4d_linked_arch_code) {
- return;
- }
- }
-
// Test narrowed, power-of-two values of M0, as mmt4d kernels tend to have
// narrow variants for handling these cases.
for (int narrowM0 = 1; narrowM0 < M0; narrowM0 *= 2) {
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 974afb6..370930a 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -525,15 +525,6 @@
static void iree_uk_test_mmt4d(iree_uk_uint32_t flags, int M0, int N0, int K0,
const char* cpu_features) {
- // For non-fallback tests, i.e. tests for architecture-specific cases, stop
- // here if we haven't linked architecture-specific code, which is the case in
- // Bazel builds.
- if (!(flags & IREE_UK_FLAG_MMT4D_ALLOW_GENERIC_FALLBACK_TILE_FUNCTION)) {
- if (!iree_uk_mmt4d_linked_arch_code) {
- return;
- }
- }
-
// Test narrowed, power-of-two values of M0, as mmt4d kernels tend to have
// narrow variants for handling these cases.
for (int narrowM0 = 1; narrowM0 < M0; narrowM0 *= 2) {
@@ -578,49 +569,71 @@
2, 9, 3, "");
#if defined(IREE_ARCH_ARM_64)
+
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "fp16fml");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
IREE_UK_FLAG_MMT4D_TYPE_F16F16F16,
8, 8, 1, "");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 1, "");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 16, 2, "");
+#if defined(IREE_UK_BUILD_ARM_64_FP16FML)
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "fp16fml");
+#endif // defined(IREE_UK_BUILD_ARM_64_FP16FML)
+#if defined(IREE_UK_BUILD_ARM_64_FULLFP16)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
IREE_UK_FLAG_MMT4D_TYPE_F16F16F16,
8, 8, 1, "fullfp16");
+#endif // defined(IREE_UK_BUILD_ARM_64_FULLFP16)
+#if defined(IREE_UK_BUILD_ARM_64_BF16)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16, 8, 8, 4, "bf16");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 8, 8, 4, "bf16");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 1, "");
+#endif // defined(IREE_UK_BUILD_ARM_64_BF16)
+#if defined(IREE_UK_BUILD_ARM_64_DOTPROD)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 16, 2, "");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
+#endif // defined(IREE_UK_BUILD_ARM_64_DOTPROD)
+#if defined(IREE_UK_BUILD_ARM_64_I8MM)
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
+#endif // defined(IREE_UK_BUILD_ARM_64_I8MM)
+
#elif defined(IREE_ARCH_X86_64)
+
+#if defined(IREE_UK_BUILD_X86_64_AVX2_FMA)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "avx2_fma");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "avx2_fma");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
+ IREE_UK_FLAG_MMT4D_TYPE_F16F16F16,
+ 8, 8, 1, "avx2_fma");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 2, "avx2_fma");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 8, 8, 2, "avx2_fma");
+#endif // defined(IREE_UK_BUILD_X86_64_AVX2_FMA)
+#if defined(IREE_UK_BUILD_X86_64_AVX512_BASE)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 16, 16, 1,
"avx512_base");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, "avx2_fma");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 16, 16, 1,
"avx512_base");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
IREE_UK_FLAG_MMT4D_TYPE_F16F16F16,
- 8, 8, 1, "avx2_fma");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
- IREE_UK_FLAG_MMT4D_TYPE_F16F16F16,
16, 16, 1, "avx512_base");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_base");
+ iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
+ "avx512_base");
+#endif // defined(IREE_UK_BUILD_X86_64_AVX512_BASE)
+#if defined(IREE_UK_BUILD_X86_64_AVX512_BF16)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 16, 16, 2,
"avx512_bf16");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS |
IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16,
16, 16, 2, "avx512_bf16");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 2, "avx2_fma");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_base");
+#endif // defined(IREE_UK_BUILD_X86_64_AVX512_BF16)
+#if defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_vnni");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 8, 8, 2, "avx2_fma");
- iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
- "avx512_base");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8, "avx512_vnni");
+#endif // defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
+
#endif // defined(IREE_ARCH_ARM_64)
return iree_uk_test_exit_status();