Pass the cpu_data pointer to the ukernel, not just the value of field 0. (#10485)
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c
index e015f0d..5e3473a 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_select_tile_arm_64.c
@@ -35,7 +35,7 @@
iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x8(
const iree_ukernel_mmt4d_params_t* params) {
#ifdef IREE_UKERNEL_BUILD_ARM_64_I8MM
- if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
+ if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
return iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm;
}
#else
@@ -48,7 +48,7 @@
iree_ukernel_mmt4d_select_tile_func_arm_64_i8i8i32_8x8x4(
const iree_ukernel_mmt4d_params_t* params) {
#ifdef IREE_UKERNEL_BUILD_ARM_64_DOTPROD
- if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
+ if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
return iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod;
}
#else
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_types.h b/runtime/src/iree/builtins/ukernel/mmt4d_types.h
index 36e49d5..4ebb779 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_types.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_types.h
@@ -34,7 +34,7 @@
int32_t M0;
int32_t N0;
int32_t K0;
- uint64_t cpu_data_field_0;
+ const uint64_t* cpu_data;
};
typedef struct iree_ukernel_mmt4d_params_t iree_ukernel_mmt4d_params_t;
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 26d3b40..b9b6d4e 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -38,7 +38,7 @@
int M0;
int N0;
int K0;
- uint64_t cpu_data_field_0;
+ const uint64_t* cpu_data;
};
typedef struct iree_mmt4d_benchmark_user_data_t
@@ -58,7 +58,7 @@
params.M0 = user_data->M0;
params.N0 = user_data->N0;
params.K0 = user_data->K0;
- params.cpu_data_field_0 = user_data->cpu_data_field_0;
+ params.cpu_data = user_data->cpu_data;
params.lhs_stride = params.K * params.M0 * params.K0;
params.rhs_stride = params.K * params.N0 * params.K0;
params.out_stride = params.N * params.M0 * params.N0;
@@ -112,9 +112,9 @@
static void iree_mmt4d_benchmark_register(
const iree_mmt4d_benchmark_user_data_t* user_data, const char* name) {
// Does this benchmark require an optional CPU feature?
- if (user_data->cpu_data_field_0) {
- if ((iree_cpu_data_field(0) & user_data->cpu_data_field_0) !=
- user_data->cpu_data_field_0) {
+ if (user_data->cpu_data[0]) {
+ if ((iree_cpu_data_field(0) & user_data->cpu_data[0]) !=
+ user_data->cpu_data[0]) {
// The CPU does not meet this benchmark's requirements. The builtin
// would crash.
return;
@@ -136,12 +136,14 @@
#define MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, _cpu_data_field_0, \
_label) \
do { \
+ static const uint64_t local_cpu_data[IREE_CPU_DATA_FIELD_COUNT] = { \
+ _cpu_data_field_0}; \
static const iree_mmt4d_benchmark_user_data_t user_data = { \
.type = iree_ukernel_mmt4d_type_##_type, \
.M0 = _m0, \
.N0 = _n0, \
.K0 = _k0, \
- .cpu_data_field_0 = _cpu_data_field_0, \
+ .cpu_data = local_cpu_data, \
}; \
iree_mmt4d_benchmark_register(&user_data, \
"iree_ukernel_mmt4d_" #_type "_" #_m0 \
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
index 1d2bf8c..fb4601b 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
@@ -253,6 +253,9 @@
params.M0 = M0;
params.N0 = N0;
params.K0 = K0;
+ const uint64_t local_cpu_data[IREE_CPU_DATA_FIELD_COUNT] = {
+ cpu_data_field_0_bit};
+ params.cpu_data = local_cpu_data;
// First try without any optional CPU feature. This matters even when the
// feature is supported by the CPU because we want to test the fallback to
// architecture-default or generic code.
@@ -260,8 +263,7 @@
// If this is nonzero, we are asked to test again with this CPU feature.
if (cpu_data_field_0_bit) {
// Check if the CPU supports the feature (otherwise, we crash).
- params.cpu_data_field_0 = cpu_data_field_0_bit;
- bool supported = iree_cpu_data_field(0) & params.cpu_data_field_0;
+ bool supported = iree_cpu_data_field(0) & params.cpu_data[0];
if (supported) {
// Run with the optional CPU feature.
fprintf(stderr, "Device supports CPU feature: %s\n",
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
index 916a7f3..878b996 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
@@ -145,15 +145,15 @@
// code path relies on the combination of two features.
// For now, asserting only one bit set, and taking advantage of that to work
// with plain string literals.
- assert(0 == (params->cpu_data_field_0 & (params->cpu_data_field_0 - 1)));
- if (params->cpu_data_field_0 == 0) {
+ assert(0 == (params->cpu_data[0] & (params->cpu_data[0] - 1)));
+ if (params->cpu_data[0] == 0) {
return "(none)";
}
#if defined(IREE_UKERNEL_ARCH_ARM_64)
- if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
+ if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
return "i8mm";
}
- if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
+ if (params->cpu_data[0] & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
return "dotprod";
}
#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c
index 9310d34..da1db6b 100644
--- a/runtime/src/iree/modules/vmvx/module.c
+++ b/runtime/src/iree/modules/vmvx/module.c
@@ -693,7 +693,7 @@
.M0 = M0,
.N0 = N0,
.K0 = K0,
- .cpu_data_field_0 = iree_cpu_data_field(0),
+ .cpu_data = iree_cpu_data_fields(),
};
iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&ukernel_params);
IREE_TRACE_ZONE_END(z0);