Mmt4d builtin ukernel test/benchmark (#10389)
diff --git a/runtime/src/iree/builtins/ukernel/tools/BUILD b/runtime/src/iree/builtins/ukernel/tools/BUILD
index ac1940a..21c4c87 100644
--- a/runtime/src/iree/builtins/ukernel/tools/BUILD
+++ b/runtime/src/iree/builtins/ukernel/tools/BUILD
@@ -13,11 +13,24 @@
licenses = ["notice"], # Apache 2.0
)
+cc_library(
+ name = "mmt4d_test_utils",
+ srcs = ["mmt4d_test_utils.cc"],
+ hdrs = ["mmt4d_test_utils.h"],
+ deps = [
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/builtins/ukernel:types",
+ "//runtime/src/iree/schemas:cpu_data",
+ ],
+)
+
cc_binary_benchmark(
name = "mmt4d_benchmark",
srcs = ["mmt4d_benchmark.c"],
deps = [
+ ":mmt4d_test_utils",
"//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:cpu",
"//runtime/src/iree/base/internal:flags",
"//runtime/src/iree/builtins/ukernel",
"//runtime/src/iree/testing:benchmark",
@@ -28,10 +41,11 @@
name = "mmt4d_test",
srcs = ["mmt4d_test.cc"],
deps = [
+ ":mmt4d_test_utils",
"//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:cpu",
"//runtime/src/iree/base/internal:flags",
"//runtime/src/iree/builtins/ukernel",
"//runtime/src/iree/testing:gtest",
- "//runtime/src/iree/testing:gtest_main",
],
)
diff --git a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
index 3e8f6d4..4b4e455 100644
--- a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
@@ -10,13 +10,29 @@
iree_add_all_subdirs()
+iree_cc_library(
+ NAME
+ mmt4d_test_utils
+ HDRS
+ "mmt4d_test_utils.h"
+ SRCS
+ "mmt4d_test_utils.cc"
+ DEPS
+ iree::base
+ iree::builtins::ukernel::types
+ iree::schemas::cpu_data
+ PUBLIC
+)
+
iree_cc_binary_benchmark(
NAME
mmt4d_benchmark
SRCS
"mmt4d_benchmark.c"
DEPS
+ ::mmt4d_test_utils
iree::base
+ iree::base::internal::cpu
iree::base::internal::flags
iree::builtins::ukernel
iree::testing::benchmark
@@ -29,11 +45,12 @@
SRCS
"mmt4d_test.cc"
DEPS
+ ::mmt4d_test_utils
iree::base
+ iree::base::internal::cpu
iree::base::internal::flags
iree::builtins::ukernel
iree::testing::gtest
- iree::testing::gtest_main
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index ce54a35..60f4826 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -4,40 +4,160 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET.
+// clang-format off
+#include <stdint.h> // include before ukernel/common.h to keep standard types
+// clang-format on
-#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "iree/base/api.h"
+#include "iree/base/internal/cpu.h"
#include "iree/base/internal/flags.h"
#include "iree/builtins/ukernel/mmt4d.h"
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
#include "iree/testing/benchmark.h"
-// Example flag; not really useful:
-IREE_FLAG(int32_t, batch_count, 64, "Ops to run per benchmark iteration.");
+IREE_FLAG(int32_t, batch_count, 1000, "Ops to run per benchmark iteration.");
+IREE_FLAG(int32_t, m_size, 1,
+ "M-dimension of mmt4d ops. The overall number of rows of the "
+ "accumulator is that times the M0 tile size.");
+IREE_FLAG(int32_t, n_size, 1,
+ "N-dimension of mmt4d ops. The overall number of columns of the "
+ "accumulator is that times the N0 tile size.");
+IREE_FLAG(
+ int32_t, k_size, 256,
+ "K-dimension of mmt4d ops. That's the number of iterations of the inner "
+ "loop. The overall accumulation depth is that times the K0 tile size.");
+IREE_FLAG(bool, accumulate, false,
+ "Whether the kernel should accumulate into the existing accumulator "
+ "tile values, or zero the accumulator tile.");
-static iree_status_t iree_mmt4d_example_matmul_f32_benchmark(
+struct iree_mmt4d_benchmark_user_data_t {
+ iree_ukernel_mmt4d_type_t type;  // Copied into params.type by the benchmark.
+ int M0;  // M0, N0, K0 are copied into params.M0, params.N0, params.K0.
+ int N0;
+ int K0;
+ uint64_t cpu_data_field_0;  // CPU feature bits required; 0 means none.
+};
+
+typedef struct iree_mmt4d_benchmark_user_data_t
+ iree_mmt4d_benchmark_user_data_t;
+
+static iree_status_t iree_mmt4d_benchmark(
const iree_benchmark_def_t* benchmark_def,
iree_benchmark_state_t* benchmark_state) {
+ const iree_mmt4d_benchmark_user_data_t* user_data = benchmark_def->user_data;
+ iree_ukernel_mmt4d_params_t params;  // Filled from user_data and FLAG_*.
+ memset(&params, 0, sizeof params);
+ params.type = user_data->type;
+ params.flags = FLAG_accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0;
+ params.M = FLAG_m_size;
+ params.N = FLAG_n_size;
+ params.K = FLAG_k_size;
+ params.M0 = user_data->M0;
+ params.N0 = user_data->N0;
+ params.K0 = user_data->K0;
+ params.cpu_data_field_0 = user_data->cpu_data_field_0;
+ params.lhs_stride = params.K * params.M0 * params.K0;  // No padding.
+ params.rhs_stride = params.K * params.N0 * params.K0;
+ params.out_stride = params.N * params.M0 * params.N0;
+ iree_ukernel_size_t lhs_buffer_size =
+ iree_ukernel_mmt4d_lhs_buffer_size(&params);
+ iree_ukernel_size_t rhs_buffer_size =
+ iree_ukernel_mmt4d_rhs_buffer_size(&params);
+ iree_ukernel_size_t out_buffer_size =
+ iree_ukernel_mmt4d_out_buffer_size(&params);
+ void* lhs_buffer = malloc(lhs_buffer_size);
+ void* rhs_buffer = malloc(rhs_buffer_size);  // Own size, not the LHS size.
+ void* out_buffer = malloc(out_buffer_size);  // Own size, not the LHS size.
+ iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(&params);
+ iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(&params);
+ iree_mmt4d_scalar_type_t out_type = iree_ukernel_mmt4d_out_type(&params);
+ iree_mmt4d_test_random_engine_t* engine =
+ iree_mmt4d_test_random_engine_create();
+ // It's just about plausible that on some platform, for some number type,
+ // performance might be different on zero buffers vs random buffers. But it
+ // shouldn't matter that we recreate the random engine every time, getting
+ // the same random values again.
+ write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
+ write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
+ write_random_buffer(out_buffer, out_buffer_size, out_type, engine);
+ iree_mmt4d_test_random_engine_destroy(engine);
+ params.lhs_buffer = lhs_buffer;
+ params.rhs_buffer = rhs_buffer;
+ params.out_buffer = out_buffer;
+ int64_t total_iterations = 0;  // Kernel calls made across all batches.
while (iree_benchmark_keep_running(benchmark_state,
/*batch_count=*/FLAG_batch_count)) {
for (int i = 0; i < FLAG_batch_count; ++i) {
- iree_ukernel_mmt4d_params_t params;
- memset(&params, 0, sizeof params);
- params.type = iree_ukernel_mmt4d_type_f32f32f32;
iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&params);
if (status != iree_ukernel_mmt4d_status_ok) {
- fprintf(stderr, "FATAL: iree_ukernel_mmt4d_f32f32f32 failed: %s\n",
+ fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n",
iree_ukernel_mmt4d_status_message(status));
abort();
}
}
+ total_iterations += FLAG_batch_count;
}
+ iree_benchmark_set_items_processed(
+ benchmark_state, total_iterations * 2 * params.M * params.N * params.K *
+ params.M0 * params.N0 * params.K0);
+ free(lhs_buffer);
+ free(rhs_buffer);
+ free(out_buffer);
return iree_ok_status();
}
+// Registers `name` as a benchmark running iree_mmt4d_benchmark on `user_data`,
+// unless `user_data` requires CPU feature bits that the CPU does not report.
+static void iree_mmt4d_benchmark_register(
+ const iree_mmt4d_benchmark_user_data_t* user_data, const char* name) {
+ const uint64_t required_bits = user_data->cpu_data_field_0;
+ const bool cpu_lacks_required_bits =
+ required_bits && (iree_cpu_data_field(0) & required_bits) != required_bits;
+ if (cpu_lacks_required_bits) {
+ // The builtin would fall back on generic code here, and we don't need more
+ // generic benchmark results, so skip registration.
+ return;
+ }
+
+ // A local (non-static) benchmark_def is fine: it will be cloned.
+ const iree_benchmark_def_t benchmark_def = {
+ .flags = IREE_BENCHMARK_FLAG_USE_REAL_TIME,
+ .time_unit = IREE_BENCHMARK_UNIT_MICROSECOND,
+ .minimum_duration_ns = 0,
+ .iteration_count = 0,
+ .run = iree_mmt4d_benchmark,
+ .user_data = user_data,
+ };
+ iree_benchmark_register(IREE_SV(name), &benchmark_def);
+}
+
+#define IREE_MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, _cpu_data_field_0, \
+ _label) \
+ do { \
+ static const iree_mmt4d_benchmark_user_data_t user_data = { \
+ .type = iree_ukernel_mmt4d_type_##_type, \
+ .M0 = _m0, \
+ .N0 = _n0, \
+ .K0 = _k0, \
+ .cpu_data_field_0 = _cpu_data_field_0, \
+ }; \
+ iree_mmt4d_benchmark_register(&user_data, \
+ "iree_ukernel_mmt4d_" #_type "_" #_m0 \
+ "x" #_n0 "x" #_k0 "_" #_label); \
+ } while (0)
+
+#define IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(_type, _m0, _n0, _k0) \
+ IREE_MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, 0, GENERIC)
+
+#define IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(_type, _m0, _n0, _k0, \
+ _cpu_feature) \
+ IREE_MMT4D_BENCHMARK_REGISTER( \
+ _type, _m0, _n0, _k0, IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##_cpu_feature, \
+ arm_64_##_cpu_feature)
+
int main(int argc, char** argv) {
iree_flags_set_usage(
"mmt4d_benchmark",
@@ -46,22 +166,21 @@
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK, &argc, &argv);
iree_benchmark_initialize(&argc, argv);
+ iree_cpu_initialize(iree_allocator_system());
- // TODO: always add _generic variants to have a baseline vs reference?
+ // Generic code paths, not actually used, but interesting to get a sense
+ // of how slow generic code goes vs decent SIMD kernels. Interesting also to
+ // compare generic float vs int arithmetic.
+ IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(f32f32f32, 4, 4, 1);
+ IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(i8i8i32, 4, 4, 1);
- {
- static const iree_benchmark_def_t benchmark_def = {
- .flags = IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME |
- IREE_BENCHMARK_FLAG_USE_REAL_TIME,
- .time_unit = IREE_BENCHMARK_UNIT_NANOSECOND,
- .minimum_duration_ns = 0,
- .iteration_count = 0,
- .run = iree_mmt4d_example_matmul_f32_benchmark,
- .user_data = NULL,
- };
- iree_benchmark_register(IREE_SV("iree_mmt4d_example_matmul_f32"),
- &benchmark_def);
- }
+// ARM_64 benchmarks.
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+ IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 4, DOTPROD);
+ IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 8, I8MM);
+
+#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
iree_benchmark_run_specified();
return 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
index 206b972..0dab85e 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
@@ -4,23 +4,301 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET.
+// Design rationale and code creep warning!
+//
+// Summary:
+//
+// The goal of this test is to provide 100% coverage across all
+// internal kernel variants, which is not convenient to do in e2e tests.
+// Resist the temptation to reimplement here all the niceties of the e2e test.
+// Stick to guaranteeing that if the test succeeds, then the mmt4d builtin,
+// with all its asm code path variants, is correct. In case of failure, the
+// user is expected to be happy to jump into a debugger.
+//
+// Longer story:
+//
+// It is said by an ancient prophecy that all matrix multiplication tests grow
+// to be thousands of lines of code.
+//
+// In fact, we already have one, it's the end-to-end matmul test under
+// iree/tests/e2e/matmul. That one is needed anyway, and needs to be large
+// anyway, being end-to-end and applying to all target backends, including those
+// where device!=host. And so it makes sense for that one to have extra bells
+// and whistles such as fuzzy comparisons, pretty-printing of numerical errors
+// to aid debugging, and yet more special logic to make numerical errors easier
+// to debug.
+//
+// Let's not duplicate all that here! Note also that, tempting as it would
+// be to borrow the matrix-pretty-printing stuff from e2e/matmul, that applies
+// to plain row-major 2D matrices, while here we are dealing with 4D arrays /
+// tiled-layout matrices. Trying to bridge over that difference would bring yet
+// more complexity.
+//
+// Instead, let us keep a sharp focus on why we need this separate micro test.
+// The motivation is not the usual "because micro tests are easier to debug than
+// e2e" but rather because it would be difficult to have 100% code coverage in
+// e2e. There are many variants of mmt4d builtin ukernels for various CPU
+// features and tuned for various CPU models. We have to iterate over all these
+// variants. Trying to do so in e2e tests would require exposing knobs for
+// things that we would otherwise prefer to keep internal in the mmt4d builtin
+// implementation, and would make e2e/matmul tests even more expensive.
-#include <stdint.h>
+// clang-format off
+#include <stdint.h> // include before ukernel/common.h to keep standard types
+// clang-format on
-// Include in expected order with stdint and other system headers first.
-// See the note in mmt4d.h about stdint.h. This won't be an issue in most uses
-// but clang-format really likes to put the mmt4d.h above the system headers
-// due to this _test.cc file naming.
+#include "iree/builtins/ukernel/mmt4d.h"
+
+#include <vector>
#include "iree/base/api.h"
-#include "iree/builtins/ukernel/mmt4d.h"
+#include "iree/base/internal/cpu.h"
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-TEST(MMT4DTest, iree_mmt4d_example_matmul_f32) {
+template <typename lhs_t, typename rhs_t, typename out_t>
+static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) {
+ bool accumulate = params.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
+ iree_ukernel_size_t lhs_tile_size = params.M0 * params.K0;
+ iree_ukernel_size_t rhs_tile_size = params.N0 * params.K0;
+ iree_ukernel_size_t out_tile_size = params.M0 * params.N0;
+ for (iree_ukernel_size_t i = 0; i < params.M; ++i) {
+ for (iree_ukernel_size_t j = 0; j < params.N; ++j) {
+ out_t* out_tile_ptr = ((out_t*)params.out_buffer) +
+ i * params.out_stride + j * out_tile_size;
+ const lhs_t* lhs_panel_ptr =
+ ((const lhs_t*)params.lhs_buffer) + i * params.lhs_stride;
+ const rhs_t* rhs_panel_ptr =
+ ((const rhs_t*)params.rhs_buffer) + j * params.rhs_stride;
+ for (iree_ukernel_size_t i0 = 0; i0 < params.M0; ++i0) {
+ for (iree_ukernel_size_t j0 = 0; j0 < params.N0; ++j0) {
+ const lhs_t* lhs_tile_ptr = lhs_panel_ptr;  // Restart at tile k=0.
+ const rhs_t* rhs_tile_ptr = rhs_panel_ptr;
+ out_t* out_ptr = out_tile_ptr + i0 * params.N0 + j0;
+ out_t acc = accumulate ? *out_ptr : out_t(0);  // Stay in out_t.
+ for (iree_ukernel_size_t k = 0; k < params.K; ++k) {
+ for (iree_ukernel_size_t k0 = 0; k0 < params.K0; ++k0) {
+ out_t lhs_val = lhs_tile_ptr[i0 * params.K0 + k0];  // Widen first.
+ out_t rhs_val = rhs_tile_ptr[j0 * params.K0 + k0];
+ acc += lhs_val * rhs_val;
+ }
+ lhs_tile_ptr += lhs_tile_size;
+ rhs_tile_ptr += rhs_tile_size;
+ }
+ *out_ptr = acc;
+ }
+ }
+ }
+ }
+}
+
+static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) {
+ // Pick the typed reference implementation matching params.type.
+ if (params.type == iree_ukernel_mmt4d_type_f32f32f32) {
+ iree_mmt4d_reference<float, float, float>(params);
+ return;
+ }
+ if (params.type == iree_ukernel_mmt4d_type_i8i8i32) {
+ iree_mmt4d_reference<int8_t, int8_t, int32_t>(params);
+ return;
+ }
+ assert(false && "unknown type");
+}
+
+static void test_one_matmul_using_given_lhs_rhs(
+ const iree_ukernel_mmt4d_params_t& shared_params,
+ iree_mmt4d_test_random_engine_t* engine) {
+ assert(!shared_params.out_buffer);  // Out buffers are allocated below.
+
+ iree_ukernel_mmt4d_params_t reference_params;
+ memcpy(&reference_params, &shared_params, sizeof shared_params);
+ iree_ukernel_size_t out_buffer_size =
+ iree_ukernel_mmt4d_out_buffer_size(&shared_params);
+ reference_params.out_buffer = malloc(out_buffer_size);
+ iree_mmt4d_scalar_type_t out_type =
+ iree_ukernel_mmt4d_out_type(&shared_params);
+ write_random_buffer(reference_params.out_buffer, out_buffer_size, out_type,
+ engine);
+
+ iree_ukernel_mmt4d_params_t actual_params;
+ memcpy(&actual_params, &shared_params, sizeof shared_params);
+ actual_params.out_buffer = malloc(out_buffer_size);
+ memcpy(actual_params.out_buffer, reference_params.out_buffer,
+ out_buffer_size);  // Both runs start from identical output contents.
+
+ iree_mmt4d_reference(reference_params);
+ iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&actual_params);
+ if (status != iree_ukernel_mmt4d_status_ok) {
+ fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n",
+ iree_ukernel_mmt4d_status_message(status));
+ abort();
+ }
+
+ // For now we use exact comparisons, even for float, even though the reference
+ // code accumulates in a different order compared to the actual code. This
+ // relies on picking input test matrix elements so that all intermediate
+ // values are exactly representable - i.e. small integer numerators. This
+ // will become problematic when we do float16. See the comment at the top of
+ // this file explaining how we refrain from letting this grow into a
+ // 1000-line-long fully-featured test.
+ if (memcmp(actual_params.out_buffer, reference_params.out_buffer,
+ out_buffer_size)) {
+ const auto& p = actual_params;
+ fprintf(stderr, "mmt4d test failure with the following params:\n");
+ fprintf(stderr, " type=%s\n", get_mmt4d_type_str(&p));
+ fprintf(stderr, " flags: accumulate=%d\n",
+ (int)(p.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE));
+ fprintf(stderr, " M=%d, N=%d, K=%d\n", (int)p.M, (int)p.N, (int)p.K);
+ fprintf(stderr, " M0=%d, N0=%d, K0=%d\n", (int)p.M0, (int)p.N0, (int)p.K0);
+ fprintf(stderr, " lhs_stride=%zu, rhs_stride=%zu, out_stride=%zu\n",
+ (size_t)p.lhs_stride, (size_t)p.rhs_stride, (size_t)p.out_stride);
+ fprintf(stderr, " cpu features: %s\n", get_cpu_features_str(&p));
+ // Don't even try to pretty-print matrices. See the comment at the top of
+ // this file. Don't try to use GTest primitives to show expected vs actual
+ // since that would require dispatching to type-specific code paths.
+ // Also, at this point it's easy for the user to rerun this test
+ // in a debugger and manually inspect values.
+ //
+ // We want fatal here - that is what the user running this in a debugger
+ // wants us to do, so they can inspect values while they exist in memory.
+ // What's the GTest-sanctioned fatal error? GTEST_FAIL() has a comment that
+ // says that it's fatal, but that's a lie at least here on Android.
+ abort();
+ }
+
+ free(reference_params.out_buffer);
+ free(actual_params.out_buffer);
+}
+
+static void test_one_matmul_creating_lhs_rhs_for_given_shape(
+ const iree_ukernel_mmt4d_params_t& shared_params,
+ iree_mmt4d_test_random_engine_t* engine) {
+ iree_ukernel_mmt4d_params_t params;
+ memcpy(&params, &shared_params, sizeof params);
+ assert(!params.lhs_buffer);  // Buffers and strides are chosen below.
+ assert(!params.rhs_buffer);
+ assert(!params.out_buffer);
+ assert(!params.lhs_stride);
+ assert(!params.rhs_stride);
+ assert(!params.out_stride);
+ // Populate strides first - they are read by the get_*_buffer_size helper.
+ // Randomly make strides either tight or not to exercise all cases.
+ params.lhs_stride = params.K * params.M0 * params.K0 +
+ iree_mmt4d_test_random_engine_get_0_or_1(engine);
+ params.rhs_stride = params.K * params.N0 * params.K0 +
+ iree_mmt4d_test_random_engine_get_0_or_1(engine);
+ params.out_stride = params.N * params.M0 * params.N0 +
+ iree_mmt4d_test_random_engine_get_0_or_1(engine);
+ iree_ukernel_size_t lhs_buffer_size =
+ iree_ukernel_mmt4d_lhs_buffer_size(&params);
+ iree_ukernel_size_t rhs_buffer_size =
+ iree_ukernel_mmt4d_rhs_buffer_size(&params);
+ iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(&params);
+ iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(&params);
+ void* lhs_buffer = malloc(lhs_buffer_size);
+ void* rhs_buffer = malloc(rhs_buffer_size);
+ write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
+ write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
+ params.lhs_buffer = lhs_buffer;
+ params.rhs_buffer = rhs_buffer;
+ test_one_matmul_using_given_lhs_rhs(params, engine);  // Owns out_buffer.
+ free(lhs_buffer);
+ free(rhs_buffer);
+}
+
+static void test_matmuls_for_various_MNK_shapes_and_flags(
+ const iree_ukernel_mmt4d_params_t& shared_params,
+ iree_mmt4d_test_random_engine_t* engine) {
+ iree_ukernel_mmt4d_params_t params;
+ memcpy(&params, &shared_params, sizeof params);
+ assert(params.M == 0);  // M, N, K and flags are set by the loops below.
+ assert(params.N == 0);
+ assert(params.K == 0);
+ assert(params.flags == 0);
+ struct shape_mnk_t {
+ int m, n, k;
+ };
+ std::vector<shape_mnk_t> shapes{
+ {1, 1, 1}, {1, 1, 2}, {1, 1, 10}, {1, 1, 1000},
+ {2, 1, 1}, {1, 2, 1}, {2, 2, 2}, {5, 7, 13},
+ };
+ for (shape_mnk_t shape : shapes) {
+ params.M = shape.m;
+ params.N = shape.n;
+ params.K = shape.k;
+ for (bool accumulate : {false, true}) {  // Each shape runs both ways.
+ params.flags = accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0;
+ test_one_matmul_creating_lhs_rhs_for_given_shape(params, engine);
+ }
+ }
+}
+
+// Tests mmt4d with the specific data type and specific M0xN0xK0 tile format.
+// If cpu_data_field_0_bit is nonzero, it must then be a single bit (power of 2)
+// and if the CPU supports the corresponding feature, the mmt4d tests are run a
+// second time with that CPU feature enabled.
+static void mmt4d_test(iree_ukernel_mmt4d_type_t type, int M0, int N0, int K0,
+ uint64_t cpu_data_field_0_bit) {
+ // Letting each test create its own engine makes them independent: a testcase
+ // succeeds or fails the same way if we isolate it or reorder it. The
+ // potential downside of repeating the same pseudorandom sequence is OK
+ // because any pseudorandom sequence should be equally good at coverage, and
+ // different testcases tend to use different tile shapes anyway.
+ iree_mmt4d_test_random_engine_t* engine =
+ iree_mmt4d_test_random_engine_create();
iree_ukernel_mmt4d_params_t params;
memset(&params, 0, sizeof params);
- params.type = iree_ukernel_mmt4d_type_f32f32f32;
- EXPECT_EQ(0, iree_ukernel_mmt4d(&params));
+ params.type = type;
+ params.M0 = M0;
+ params.N0 = N0;
+ params.K0 = K0;
+ // First try without any optional CPU feature. This matters even when the
+ // feature is supported by the CPU because we want to test the fallback to
+ // architecture-default or generic code.
+ test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
+ // If this is nonzero, we are asked to test again with this CPU feature.
+ if (cpu_data_field_0_bit) {
+ // Check if the CPU supports the feature (otherwise, we crash).
+ params.cpu_data_field_0 = cpu_data_field_0_bit;
+ bool supported = iree_cpu_data_field(0) & params.cpu_data_field_0;
+ if (supported) {
+ // Run with the optional CPU feature.
+ fprintf(stderr, "Device supports CPU feature: %s\n",
+ get_cpu_features_str(&params));
+ test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
+ } else {
+ fprintf(stderr, "Skipped: device does not support CPU feature: %s\n",
+ get_cpu_features_str(&params));
+ }
+ }
+ iree_mmt4d_test_random_engine_destroy(engine);
+}
+
+#define MMT4D_TEST(type, M0, N0, K0, test_suffix, feature_bit) \
+ TEST(Mmt4dTest, type##_tile_##M0##x##N0##x##K0##_##test_suffix) { \
+ mmt4d_test(iree_ukernel_mmt4d_type_##type, M0, N0, K0, feature_bit); \
+ }
+
+// Generic tests, not matching any particular CPU feature. This is the place to
+// test weird M0, N0, K0 to ensure e.g. that we haven't unwittingly baked in a
+// power-of-two assumption.
+MMT4D_TEST(f32f32f32, 3, 5, 7, generic, 0)
+MMT4D_TEST(i8i8i32, 9, 6, 3, generic, 0)
+
+// ARM_64 tests.
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+#define MMT4D_ARM_64_TEST(type, M0, N0, K0, FEATURE) \
+ MMT4D_TEST(type, M0, N0, K0, arm_64_##FEATURE, \
+ IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##FEATURE)
+
+MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 4, DOTPROD)
+MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 8, I8MM)
+#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ iree_cpu_initialize(iree_allocator_system());  // Tests read CPU data later.
+ return RUN_ALL_TESTS();
}
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
new file mode 100644
index 0000000..0a9f970
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
@@ -0,0 +1,162 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
+
+#include <cassert>
+#include <random>
+
+#include "iree/schemas/cpu_data.h"
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type(
+ const iree_ukernel_mmt4d_params_t* params) {
+ // Maps params->type to the scalar type of the LHS elements.
+ if (params->type == iree_ukernel_mmt4d_type_f32f32f32) {
+ return iree_mmt4d_scalar_type_f32;
+ }
+ if (params->type == iree_ukernel_mmt4d_type_i8i8i32) {
+ return iree_mmt4d_scalar_type_i8;
+ }
+ assert(false && "unknown type");
+ return iree_mmt4d_scalar_type_unknown;
+}
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type(
+ const iree_ukernel_mmt4d_params_t* params) {
+ // For every type handled so far, RHS elements have the LHS element type.
+ return iree_ukernel_mmt4d_lhs_type(params);
+}
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type(
+ const iree_ukernel_mmt4d_params_t* params) {
+ // Maps params->type to the scalar type of the output elements.
+ if (params->type == iree_ukernel_mmt4d_type_f32f32f32) {
+ return iree_mmt4d_scalar_type_f32;
+ }
+ if (params->type == iree_ukernel_mmt4d_type_i8i8i32) {
+ return iree_mmt4d_scalar_type_i32;
+ }
+ assert(false && "unknown type");
+ return iree_mmt4d_scalar_type_unknown;
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_lhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params) {  // Strides must be set first.
+ return params->M * params->lhs_stride *  // M * lhs_stride elements...
+ iree_ukernel_mmt4d_lhs_elem_size(params->type);  // ...times elem size.
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_rhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params) {  // Strides must be set first.
+ return params->N * params->rhs_stride *  // N * rhs_stride elements...
+ iree_ukernel_mmt4d_rhs_elem_size(params->type);  // ...times elem size.
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_out_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params) {  // Strides must be set first.
+ return params->M * params->out_stride *  // M * out_stride elements...
+ iree_ukernel_mmt4d_out_elem_size(params->type);  // ...times elem size.
+}
+
+struct iree_mmt4d_test_random_engine_t {
+ std::minstd_rand cpp_random_engine;  // Default-seeded by ..._create().
+};
+
+iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create() {
+ return new iree_mmt4d_test_random_engine_t;  // Release with ..._destroy().
+}
+
+void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e) {
+ delete e;  // Pairs with the `new` in iree_mmt4d_test_random_engine_create().
+}
+
+static int iree_mmt4d_test_random_engine_get_in_uint16_range(
+ iree_mmt4d_test_random_engine_t* e) {
+ uint32_t v = e->cpp_random_engine();
+ // Return the two middle bytes (bits 8..23) of the 4-byte value. It avoids
+ // some mild issues with the least-significant and most-significant bytes.
+ return (v >> 8) & 0xffff;
+}
+
+int iree_mmt4d_test_random_engine_get_0_or_1(
+ iree_mmt4d_test_random_engine_t* e) {
+ // Keep only the lowest bit of the 16-bit draw.
+ return iree_mmt4d_test_random_engine_get_in_uint16_range(e) & 1;
+}
+
+int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(
+ iree_mmt4d_test_random_engine_t* e) {
+ // Reduce the 16-bit draw to [0, 31], then shift it down to [-16, 15].
+ return (iree_mmt4d_test_random_engine_get_in_uint16_range(e) % 32) - 16;
+}
+
+template <typename T>
+static void write_random_buffer(T* buffer, iree_ukernel_size_t size_in_bytes,
+ iree_mmt4d_test_random_engine_t* engine) {
+ // The byte count must describe a whole number of T elements.
+ assert(size_in_bytes % sizeof(T) == 0 && "bad size");
+ T* const end = buffer + size_in_bytes / sizeof(T);
+ for (T* dst = buffer; dst != end; ++dst) {
+ // Small integers should work for all the element types we currently have
+ // and keep float arithmetic exact, which keeps the tests simpler for now.
+ // Watch out for when we'll do float16!
+ *dst = static_cast<T>(
+ iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(engine));
+ }
+}
+
+// Fills `buffer`, which is `size_in_bytes` bytes long, with pseudorandom
+// values drawn from `engine` and stored as elements of scalar type `type`.
+// An unrecognized `type` trips an assert and writes nothing.
+void write_random_buffer(void* buffer, iree_ukernel_size_t size_in_bytes,
+ iree_mmt4d_scalar_type_t type,
+ iree_mmt4d_test_random_engine_t* engine) {
+ // Forward to the typed filler matching the runtime scalar type tag.
+ if (type == iree_mmt4d_scalar_type_f32) {
+ write_random_buffer(static_cast<float*>(buffer), size_in_bytes, engine);
+ } else if (type == iree_mmt4d_scalar_type_i32) {
+ write_random_buffer(static_cast<int32_t*>(buffer), size_in_bytes, engine);
+ } else if (type == iree_mmt4d_scalar_type_i8) {
+ write_random_buffer(static_cast<int8_t*>(buffer), size_in_bytes, engine);
+ } else {
+ assert(false && "unknown type");
+ }
+}
+
+const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params) {
+ // Each known type maps to the spelling of its own enumerator.
+ switch (params->type) {
+ case iree_ukernel_mmt4d_type_f32f32f32:
+ return "iree_ukernel_mmt4d_type_f32f32f32";
+ case iree_ukernel_mmt4d_type_i8i8i32:
+ return "iree_ukernel_mmt4d_type_i8i8i32";
+ default:
+ assert(false && "unknown type");
+ return "unknown type";
+ }
+}
+
+const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params) {
+ // We set only one feature bit at a time in this test --- not an actual
+ // detected cpu data field. This might have to change in the future if some
+ // code path relies on the combination of two features.
+ // For now, asserting at most one bit set, and taking advantage of that to
+ // work with plain string literals.
+ assert(0 == (params->cpu_data_field_0 & (params->cpu_data_field_0 - 1)));
+ if (params->cpu_data_field_0 == 0) {
+ return "(none)";
+ }
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+ if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
+ return "i8mm";
+ }
+ if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
+ return "dotprod";
+ }
+#endif // defined(IREE_UKERNEL_ARCH_ARM_64)
+ assert(false && "unknown CPU feature");  // Nonzero but unrecognized bit.
+ return "unknown CPU feature";
+}
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h
new file mode 100644
index 0000000..0f45f0e
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h
@@ -0,0 +1,63 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
+#define IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
+
+// clang-format off
+#include <stdint.h> // include before ukernel/common.h to keep standard types
+// clang-format on
+
+#include "iree/builtins/ukernel/mmt4d_types.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+enum iree_mmt4d_scalar_type_t {
+ iree_mmt4d_scalar_type_unknown,  // Fallback for unrecognized mmt4d types.
+ iree_mmt4d_scalar_type_i8,
+ iree_mmt4d_scalar_type_i32,
+ iree_mmt4d_scalar_type_f32,
+};
+
+typedef enum iree_mmt4d_scalar_type_t iree_mmt4d_scalar_type_t;
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type(
+ const iree_ukernel_mmt4d_params_t* params);
+
+iree_ukernel_size_t iree_ukernel_mmt4d_lhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_ukernel_size_t iree_ukernel_mmt4d_rhs_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params);
+iree_ukernel_size_t iree_ukernel_mmt4d_out_buffer_size(
+ const iree_ukernel_mmt4d_params_t* params);
+
+struct iree_mmt4d_test_random_engine_t;  // Opaque to users of this header.
+typedef struct iree_mmt4d_test_random_engine_t iree_mmt4d_test_random_engine_t;
+iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create(void);
+void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e);
+int iree_mmt4d_test_random_engine_get_0_or_1(
+ iree_mmt4d_test_random_engine_t* e);
+int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(
+ iree_mmt4d_test_random_engine_t* e);
+
+void write_random_buffer(void* buffer, iree_ukernel_size_t size_in_bytes,
+ iree_mmt4d_scalar_type_t type,
+ iree_mmt4d_test_random_engine_t* engine);
+
+const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params);
+const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_