Mmt4d builtin ukernel test/benchmark (#10389)

diff --git a/runtime/src/iree/builtins/ukernel/tools/BUILD b/runtime/src/iree/builtins/ukernel/tools/BUILD
index ac1940a..21c4c87 100644
--- a/runtime/src/iree/builtins/ukernel/tools/BUILD
+++ b/runtime/src/iree/builtins/ukernel/tools/BUILD
@@ -13,11 +13,24 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "mmt4d_test_utils",
+    srcs = ["mmt4d_test_utils.cc"],
+    hdrs = ["mmt4d_test_utils.h"],
+    deps = [
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/builtins/ukernel:types",
+        "//runtime/src/iree/schemas:cpu_data",
+    ],
+)
+
 cc_binary_benchmark(
     name = "mmt4d_benchmark",
     srcs = ["mmt4d_benchmark.c"],
     deps = [
+        ":mmt4d_test_utils",
         "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal:cpu",
         "//runtime/src/iree/base/internal:flags",
         "//runtime/src/iree/builtins/ukernel",
         "//runtime/src/iree/testing:benchmark",
@@ -28,10 +41,11 @@
     name = "mmt4d_test",
     srcs = ["mmt4d_test.cc"],
     deps = [
+        ":mmt4d_test_utils",
         "//runtime/src/iree/base",
+        "//runtime/src/iree/base/internal:cpu",
         "//runtime/src/iree/base/internal:flags",
         "//runtime/src/iree/builtins/ukernel",
         "//runtime/src/iree/testing:gtest",
-        "//runtime/src/iree/testing:gtest_main",
     ],
 )
diff --git a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
index 3e8f6d4..4b4e455 100644
--- a/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/tools/CMakeLists.txt
@@ -10,13 +10,29 @@
 
 iree_add_all_subdirs()
 
+iree_cc_library(
+  NAME
+    mmt4d_test_utils
+  HDRS
+    "mmt4d_test_utils.h"
+  SRCS
+    "mmt4d_test_utils.cc"
+  DEPS
+    iree::base
+    iree::builtins::ukernel::types
+    iree::schemas::cpu_data
+  PUBLIC
+)
+
 iree_cc_binary_benchmark(
   NAME
     mmt4d_benchmark
   SRCS
     "mmt4d_benchmark.c"
   DEPS
+    ::mmt4d_test_utils
     iree::base
+    iree::base::internal::cpu
     iree::base::internal::flags
     iree::builtins::ukernel
     iree::testing::benchmark
@@ -29,11 +45,12 @@
   SRCS
     "mmt4d_test.cc"
   DEPS
+    ::mmt4d_test_utils
     iree::base
+    iree::base::internal::cpu
     iree::base::internal::flags
     iree::builtins::ukernel
     iree::testing::gtest
-    iree::testing::gtest_main
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index ce54a35..60f4826 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -4,40 +4,160 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET.
+// clang-format off
+#include <stdint.h>  // include before ukernel/common.h to keep standard types
+// clang-format on
 
-#include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
 
 #include "iree/base/api.h"
+#include "iree/base/internal/cpu.h"
 #include "iree/base/internal/flags.h"
 #include "iree/builtins/ukernel/mmt4d.h"
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
 #include "iree/testing/benchmark.h"
 
-// Example flag; not really useful:
-IREE_FLAG(int32_t, batch_count, 64, "Ops to run per benchmark iteration.");
+IREE_FLAG(int32_t, batch_count, 1000, "Ops to run per benchmark iteration.");
+IREE_FLAG(int32_t, m_size, 1,
+          "M-dimension of mmt4d ops. The overall number of rows of the "
+          "accumulator is that times the M0 tile size.");
+IREE_FLAG(int32_t, n_size, 1,
+          "N-dimension of mmt4d ops. The overall number of columns of the "
+          "accumulator is that times the N0 tile size.");
+IREE_FLAG(
+    int32_t, k_size, 256,
+    "K-dimension of mmt4d ops. That's the number of iterations of the inner "
+    "loop. The overall accumulation depth is that times the K0 tile size.");
+IREE_FLAG(bool, accumulate, false,
+          "Whether the kernel should accumulate into the existing accumulator "
+          "tile values, or zero the accumulator tile.");
 
-static iree_status_t iree_mmt4d_example_matmul_f32_benchmark(
+struct iree_mmt4d_benchmark_user_data_t {  // per-benchmark configuration
+  iree_ukernel_mmt4d_type_t type;  // copied into params.type
+  int M0;  // tile sizes, copied into params.M0/N0/K0
+  int N0;
+  int K0;
+  uint64_t cpu_data_field_0;  // required CPU feature bits; 0 means none
+};
+
+typedef struct iree_mmt4d_benchmark_user_data_t
+    iree_mmt4d_benchmark_user_data_t;
+
+static iree_status_t iree_mmt4d_benchmark(
     const iree_benchmark_def_t* benchmark_def,
     iree_benchmark_state_t* benchmark_state) {
+  const iree_mmt4d_benchmark_user_data_t* user_data = benchmark_def->user_data;
+  iree_ukernel_mmt4d_params_t params;
+  memset(&params, 0, sizeof params);
+  params.type = user_data->type;
+  params.flags = FLAG_accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0;
+  params.M = FLAG_m_size;
+  params.N = FLAG_n_size;
+  params.K = FLAG_k_size;
+  params.M0 = user_data->M0;
+  params.N0 = user_data->N0;
+  params.K0 = user_data->K0;
+  params.cpu_data_field_0 = user_data->cpu_data_field_0;
+  params.lhs_stride = params.K * params.M0 * params.K0;  // tight strides
+  params.rhs_stride = params.K * params.N0 * params.K0;
+  params.out_stride = params.N * params.M0 * params.N0;
+  iree_ukernel_size_t lhs_buffer_size =
+      iree_ukernel_mmt4d_lhs_buffer_size(&params);
+  iree_ukernel_size_t rhs_buffer_size =
+      iree_ukernel_mmt4d_rhs_buffer_size(&params);
+  iree_ukernel_size_t out_buffer_size =
+      iree_ukernel_mmt4d_out_buffer_size(&params);
+  void* lhs_buffer = malloc(lhs_buffer_size);
+  void* rhs_buffer = malloc(rhs_buffer_size);  // each buffer gets its own size
+  void* out_buffer = malloc(out_buffer_size);
+  iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(&params);
+  iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(&params);
+  iree_mmt4d_scalar_type_t out_type = iree_ukernel_mmt4d_out_type(&params);
+  iree_mmt4d_test_random_engine_t* engine =
+      iree_mmt4d_test_random_engine_create();
+  // It's just about plausible that on some platform, for some number type,
+  // performance might be different on zero buffers vs random buffers. But it
+  // shouldn't matter that we recreate the random engine every time, getting
+  // the same random values again.
+  write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
+  write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
+  write_random_buffer(out_buffer, out_buffer_size, out_type, engine);
+  iree_mmt4d_test_random_engine_destroy(engine);
+  params.lhs_buffer = lhs_buffer;
+  params.rhs_buffer = rhs_buffer;
+  params.out_buffer = out_buffer;
+  int64_t total_iterations = 0;
   while (iree_benchmark_keep_running(benchmark_state,
                                      /*batch_count=*/FLAG_batch_count)) {
     for (int i = 0; i < FLAG_batch_count; ++i) {
-      iree_ukernel_mmt4d_params_t params;
-      memset(&params, 0, sizeof params);
-      params.type = iree_ukernel_mmt4d_type_f32f32f32;
       iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&params);
       if (status != iree_ukernel_mmt4d_status_ok) {
-        fprintf(stderr, "FATAL: iree_ukernel_mmt4d_f32f32f32 failed: %s\n",
+        fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n",
                 iree_ukernel_mmt4d_status_message(status));
         abort();
       }
     }
+    total_iterations += FLAG_batch_count;
   }
+  iree_benchmark_set_items_processed(
+      benchmark_state, total_iterations * 2 * params.M * params.N * params.K *
+                           params.M0 * params.N0 * params.K0);
+  free(lhs_buffer);
+  free(rhs_buffer);
+  free(out_buffer);
   return iree_ok_status();
 }
 
+static void iree_mmt4d_benchmark_register(
+    const iree_mmt4d_benchmark_user_data_t* user_data, const char* name) {
+  // Does this benchmark require an optional CPU feature?
+  if (user_data->cpu_data_field_0) {
+    if ((iree_cpu_data_field(0) & user_data->cpu_data_field_0) !=
+        user_data->cpu_data_field_0) {
+      // The CPU does not meet this benchmark's requirements. The builtin
+      // would fall back on generic code. We don't need more generic benchmark
+      // results.
+      return;
+    }
+  }
+
+  // benchmark_def does not need to be static, it will be cloned.
+  const iree_benchmark_def_t benchmark_def = {
+      .flags = IREE_BENCHMARK_FLAG_USE_REAL_TIME,
+      .time_unit = IREE_BENCHMARK_UNIT_MICROSECOND,
+      .minimum_duration_ns = 0,
+      .iteration_count = 0,
+      .run = iree_mmt4d_benchmark,
+      .user_data = user_data,
+  };
+  iree_benchmark_register(IREE_SV(name), &benchmark_def);
+}
+
+#define IREE_MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, _cpu_data_field_0, \
+                                      _label)                                  \
+  do {                                                                         \
+    static const iree_mmt4d_benchmark_user_data_t user_data = {                \
+        .type = iree_ukernel_mmt4d_type_##_type,                               \
+        .M0 = _m0,                                                             \
+        .N0 = _n0,                                                             \
+        .K0 = _k0,                                                             \
+        .cpu_data_field_0 = _cpu_data_field_0,                                 \
+    };                                                                         \
+    iree_mmt4d_benchmark_register(&user_data,                                  \
+                                  "iree_ukernel_mmt4d_" #_type "_" #_m0        \
+                                  "x" #_n0 "x" #_k0 "_" #_label);              \
+  } while (0)
+
+#define IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(_type, _m0, _n0, _k0) \
+  IREE_MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, 0, GENERIC)
+
+#define IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(_type, _m0, _n0, _k0,             \
+                                             _cpu_feature)                     \
+  IREE_MMT4D_BENCHMARK_REGISTER(                                               \
+      _type, _m0, _n0, _k0, IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##_cpu_feature, \
+      arm_64_##_cpu_feature)
+
 int main(int argc, char** argv) {
   iree_flags_set_usage(
       "mmt4d_benchmark",
@@ -46,22 +166,21 @@
 
   iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK, &argc, &argv);
   iree_benchmark_initialize(&argc, argv);
+  iree_cpu_initialize(iree_allocator_system());
 
-  // TODO: always add _generic variants to have a baseline vs reference?
+  // Generic code paths, not actually used, but interesting to get a sense
+  // of how slow generic code goes vs decent SIMD kernels. Interesting also to
+  // compare generic float vs int arithmetic.
+  IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(f32f32f32, 4, 4, 1);
+  IREE_MMT4D_BENCHMARK_REGISTER_GENERIC(i8i8i32, 4, 4, 1);
 
-  {
-    static const iree_benchmark_def_t benchmark_def = {
-        .flags = IREE_BENCHMARK_FLAG_MEASURE_PROCESS_CPU_TIME |
-                 IREE_BENCHMARK_FLAG_USE_REAL_TIME,
-        .time_unit = IREE_BENCHMARK_UNIT_NANOSECOND,
-        .minimum_duration_ns = 0,
-        .iteration_count = 0,
-        .run = iree_mmt4d_example_matmul_f32_benchmark,
-        .user_data = NULL,
-    };
-    iree_benchmark_register(IREE_SV("iree_mmt4d_example_matmul_f32"),
-                            &benchmark_def);
-  }
+// ARM_64 benchmarks.
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+  IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 4, DOTPROD);
+  IREE_MMT4D_BENCHMARK_REGISTER_ARM_64(i8i8i32, 8, 8, 8, I8MM);
+
+#endif  // defined(IREE_UKERNEL_ARCH_ARM_64)
 
   iree_benchmark_run_specified();
   return 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
index 206b972..0dab85e 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
@@ -4,23 +4,301 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-// THIS IS STILL JUST A PLACEHOLDER - NOT AN ACTUAL TEST YET.
+// Design rationale and code creep warning!
+//
+// Summary:
+//
+//   The goal of this test is to provide 100% coverage across all
+//   internal kernel variants, which is not convenient to do in e2e tests.
+//   Resist the temptation to reimplement here all the niceties of the e2e test.
+//   Stick to guaranteeing that if the test succeeds, then the mmt4d builtin,
+//   with all its asm code path variants, is correct. In case of failure, the
+//   user is expected to be happy to jump into a debugger.
+//
+// Longer story:
+//
+// It is said by an ancient prophecy that all matrix multiplication tests grow
+// to be thousands of lines of code.
+//
+// In fact, we already have one, it's the end-to-end matmul test under
+// iree/tests/e2e/matmul. That one is needed anyway, and needs to be large
+// anyway, being end-to-end and applying to all target backends, including those
+// where device!=host. And so it makes sense for that one to have extra bells
+// and whistles such as fuzzy comparisons, pretty-printing of numerical errors
+// to aid debugging, and yet more special logic to make numerical errors easier
+// to debug.
+//
+// Let's not duplicate all that here! Note also that, tempting as it would
+// be to borrow the matrix-pretty-printing stuff from e2e/matmul, that applies
+// to plain row-major 2D matrices, while here we are dealing with 4D arrays /
+// tiled-layout matrices. Trying to bridge over that difference would bring yet
+// more complexity.
+//
+// Instead, let us keep a sharp focus on why we need this separate micro test.
+// The motivation is not the usual "because micro tests are easier to debug than
+// e2e" but rather because it would be difficult to have 100% code coverage in
+// e2e. There are many variants of mmt4d builtin ukernels for various CPU
+// features and tuned for various CPU models. We have to iterate over all these
+// variants. Trying to do so in e2e tests would require exposing knobs for
+// things that we would otherwise prefer to keep internal in the mmt4d builtin
+// implementation, and would make e2e/matmul tests even more expensive.
 
-#include <stdint.h>
+// clang-format off
+#include <stdint.h>  // include before ukernel/common.h to keep standard types
+// clang-format on
 
-// Include in expected order with stdint and other system headers first.
-// See the note in mmt4d.h about stdint.h. This won't be an issue in most uses
-// but clang-format really likes to put the mmt4d.h above the system headers
-// due to this _test.cc file naming.
+#include "iree/builtins/ukernel/mmt4d.h"
+
+#include <vector>
 
 #include "iree/base/api.h"
-#include "iree/builtins/ukernel/mmt4d.h"
+#include "iree/base/internal/cpu.h"
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
 #include "iree/testing/gtest.h"
 #include "iree/testing/status_matchers.h"
 
-TEST(MMT4DTest, iree_mmt4d_example_matmul_f32) {
+template <typename lhs_t, typename rhs_t, typename out_t>
+static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) {
+  bool accumulate = params.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE;
+  iree_ukernel_size_t lhs_tile_size = params.M0 * params.K0;
+  iree_ukernel_size_t rhs_tile_size = params.N0 * params.K0;
+  iree_ukernel_size_t out_tile_size = params.M0 * params.N0;
+  for (iree_ukernel_size_t i = 0; i < params.M; ++i) {
+    for (iree_ukernel_size_t j = 0; j < params.N; ++j) {
+      out_t* out_tile_ptr = ((out_t*)params.out_buffer) +
+                            i * params.out_stride + j * out_tile_size;
+      const lhs_t* lhs_panel_ptr =
+          ((const lhs_t*)params.lhs_buffer) + i * params.lhs_stride;
+      const rhs_t* rhs_panel_ptr =
+          ((const rhs_t*)params.rhs_buffer) + j * params.rhs_stride;
+      for (iree_ukernel_size_t i0 = 0; i0 < params.M0; ++i0) {
+        for (iree_ukernel_size_t j0 = 0; j0 < params.N0; ++j0) {
+          const lhs_t* lhs_tile_ptr = lhs_panel_ptr;
+          const rhs_t* rhs_tile_ptr = rhs_panel_ptr;
+          out_t* out_ptr = out_tile_ptr + i0 * params.N0 + j0;
+          out_t acc = accumulate ? *out_ptr : out_t(0);  // stay in out_t
+          for (iree_ukernel_size_t k = 0; k < params.K; ++k) {
+            for (iree_ukernel_size_t k0 = 0; k0 < params.K0; ++k0) {
+              out_t lhs_val = lhs_tile_ptr[i0 * params.K0 + k0];
+              out_t rhs_val = rhs_tile_ptr[j0 * params.K0 + k0];
+              acc += lhs_val * rhs_val;
+            }
+            lhs_tile_ptr += lhs_tile_size;
+            rhs_tile_ptr += rhs_tile_size;
+          }
+          *out_ptr = acc;
+        }
+      }
+    }
+  }
+}
+
+static void iree_mmt4d_reference(const iree_ukernel_mmt4d_params_t& params) {
+  switch (params.type) {
+    case iree_ukernel_mmt4d_type_f32f32f32:
+      iree_mmt4d_reference<float, float, float>(params);
+      break;
+    case iree_ukernel_mmt4d_type_i8i8i32:
+      iree_mmt4d_reference<int8_t, int8_t, int32_t>(params);
+      break;
+    default:
+      assert(false && "unknown type");
+  }
+}
+
+static void test_one_matmul_using_given_lhs_rhs(
+    const iree_ukernel_mmt4d_params_t& shared_params,
+    iree_mmt4d_test_random_engine_t* engine) {
+  assert(!shared_params.out_buffer);
+
+  iree_ukernel_mmt4d_params_t reference_params;
+  memcpy(&reference_params, &shared_params, sizeof shared_params);
+  iree_ukernel_size_t out_buffer_size =
+      iree_ukernel_mmt4d_out_buffer_size(&shared_params);
+  reference_params.out_buffer = malloc(out_buffer_size);
+  iree_mmt4d_scalar_type_t out_type =
+      iree_ukernel_mmt4d_out_type(&shared_params);
+  write_random_buffer(reference_params.out_buffer, out_buffer_size, out_type,
+                      engine);
+
+  iree_ukernel_mmt4d_params_t actual_params;
+  memcpy(&actual_params, &shared_params, sizeof shared_params);
+  actual_params.out_buffer = malloc(out_buffer_size);
+  memcpy(actual_params.out_buffer, reference_params.out_buffer,
+         out_buffer_size);
+
+  iree_mmt4d_reference(reference_params);
+  iree_ukernel_mmt4d_status_t status = iree_ukernel_mmt4d(&actual_params);
+  if (status != iree_ukernel_mmt4d_status_ok) {
+    fprintf(stderr, "FATAL: iree_ukernel_mmt4d failed: %s\n",
+            iree_ukernel_mmt4d_status_message(status));
+    abort();
+  }
+
+  // For now we use exact comparisons, even for float, even though the reference
+  // code accumulates in a different order compared to the actual code. This
+  // relies on picking input test matrix elements so that all intermediate
+  // values are exactly representable - i.e. small integer numerators. This will
+  // become problematic when we do float16. See the comment at the top of this
+  // file explaining how we refrain from letting this grow into a 1000-line-long
+  // fully-featured test.
+  if (memcmp(actual_params.out_buffer, reference_params.out_buffer,
+             out_buffer_size)) {
+    const auto& p = actual_params;
+    fprintf(stderr, "mmt4d test failure with the following params:\n");
+    fprintf(stderr, "  type=%s\n", get_mmt4d_type_str(&p));
+    fprintf(stderr, "  flags: accumulate=%d\n",
+            (int)(p.flags & IREE_VMVX_MATMUL_FLAG_ACCUMULATE));
+    fprintf(stderr, "  M=%d, N=%d, K=%d\n", (int)p.M, (int)p.N, (int)p.K);
+    fprintf(stderr, "  M0=%d, N0=%d, K0=%d\n", (int)p.M0, (int)p.N0, (int)p.K0);
+    fprintf(stderr, "  lhs_stride=%zu, rhs_stride=%zu, out_stride=%zu\n",
+            (size_t)p.lhs_stride, (size_t)p.rhs_stride, (size_t)p.out_stride);
+    fprintf(stderr, "  cpu features: %s\n", get_cpu_features_str(&p));
+    // Don't even try to pretty-print matrices. See the comment at the top of
+    // this file. Don't try to use GTest primitives to show expected vs actual
+    // since that would require dispatching to type-specific code paths.
+    // Also, at this point it's easy for the user to rerun this test
+    // in a debugger and manually inspect values.
+    //
+    // We want fatal here - that is what the user running this in a debugger
+    // wants us to do, so they can inspect values while they exist in memory.
+    // What's the GTest-sanctioned fatal error? GTEST_FAIL() has a comment that
+    // says that it's fatal, but that's a lie at least here on Android.
+    abort();
+  }
+
+  free(reference_params.out_buffer);
+  free(actual_params.out_buffer);
+}
+
+static void test_one_matmul_creating_lhs_rhs_for_given_shape(
+    const iree_ukernel_mmt4d_params_t& shared_params,
+    iree_mmt4d_test_random_engine_t* engine) {
+  iree_ukernel_mmt4d_params_t params;
+  memcpy(&params, &shared_params, sizeof params);
+  assert(!params.lhs_buffer);
+  assert(!params.rhs_buffer);
+  assert(!params.out_buffer);
+  assert(!params.lhs_stride);
+  assert(!params.rhs_stride);
+  assert(!params.out_stride);
+  // Populate strides first - they are read by the get_*_buffer_size helper.
+  // Randomly make strides either tight or not to exercise all cases.
+  params.lhs_stride = params.K * params.M0 * params.K0 +
+                      iree_mmt4d_test_random_engine_get_0_or_1(engine);
+  params.rhs_stride = params.K * params.N0 * params.K0 +
+                      iree_mmt4d_test_random_engine_get_0_or_1(engine);
+  params.out_stride = params.N * params.M0 * params.N0 +
+                      iree_mmt4d_test_random_engine_get_0_or_1(engine);
+  iree_ukernel_size_t lhs_buffer_size =
+      iree_ukernel_mmt4d_lhs_buffer_size(&params);
+  iree_ukernel_size_t rhs_buffer_size =
+      iree_ukernel_mmt4d_rhs_buffer_size(&params);
+  iree_mmt4d_scalar_type_t lhs_type = iree_ukernel_mmt4d_lhs_type(&params);
+  iree_mmt4d_scalar_type_t rhs_type = iree_ukernel_mmt4d_rhs_type(&params);
+  void* lhs_buffer = malloc(lhs_buffer_size);
+  void* rhs_buffer = malloc(rhs_buffer_size);
+  write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine);
+  write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine);
+  params.lhs_buffer = lhs_buffer;
+  params.rhs_buffer = rhs_buffer;
+  test_one_matmul_using_given_lhs_rhs(params, engine);
+  free(lhs_buffer);
+  free(rhs_buffer);
+}
+
+static void test_matmuls_for_various_MNK_shapes_and_flags(
+    const iree_ukernel_mmt4d_params_t& shared_params,
+    iree_mmt4d_test_random_engine_t* engine) {
+  iree_ukernel_mmt4d_params_t params;
+  memcpy(&params, &shared_params, sizeof params);
+  assert(params.M == 0);
+  assert(params.N == 0);
+  assert(params.K == 0);
+  assert(params.flags == 0);
+  struct shape_mnk_t {
+    int m, n, k;
+  };
+  std::vector<shape_mnk_t> shapes{
+      {1, 1, 1}, {1, 1, 2}, {1, 1, 10}, {1, 1, 1000},
+      {2, 1, 1}, {1, 2, 1}, {2, 2, 2},  {5, 7, 13},
+  };
+  for (shape_mnk_t shape : shapes) {
+    params.M = shape.m;
+    params.N = shape.n;
+    params.K = shape.k;
+    for (bool accumulate : {false, true}) {
+      params.flags = accumulate ? IREE_VMVX_MATMUL_FLAG_ACCUMULATE : 0;
+      test_one_matmul_creating_lhs_rhs_for_given_shape(params, engine);
+    }
+  }
+}
+
+// Tests mmt4d with the specific data type and specific M0xN0xK0 tile format.
+// If cpu_data_field_0_bit is nonzero, it must then be a single bit (power of 2)
+// and if the CPU supports the corresponding feature, the mmt4d tests are run a
+// second time with that CPU feature enabled.
+static void mmt4d_test(iree_ukernel_mmt4d_type_t type, int M0, int N0, int K0,
+                       uint64_t cpu_data_field_0_bit) {
+  // Letting each test create its own engine makes them independent: a testcase
+  // succeeds or fails the same way if we isolate it or reorder it. The
+  // potential downside of repeating the same pseudorandom sequence is OK
+  // because any pseudorandom sequence should be equally good at coverage, and
+  // different testcases tend to use different tile shapes anyway.
+  iree_mmt4d_test_random_engine_t* engine =
+      iree_mmt4d_test_random_engine_create();
   iree_ukernel_mmt4d_params_t params;
   memset(&params, 0, sizeof params);
-  params.type = iree_ukernel_mmt4d_type_f32f32f32;
-  EXPECT_EQ(0, iree_ukernel_mmt4d(&params));
+  params.type = type;
+  params.M0 = M0;
+  params.N0 = N0;
+  params.K0 = K0;
+  // First try without any optional CPU feature. This matters even when the
+  // feature is supported by the CPU because we want to test the fallback to
+  // architecture-default or generic code.
+  test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
+  // If this is nonzero, we are asked to test again with this CPU feature.
+  if (cpu_data_field_0_bit) {
+    // Check if the CPU supports the feature (otherwise, we crash).
+    params.cpu_data_field_0 = cpu_data_field_0_bit;
+    bool supported = iree_cpu_data_field(0) & params.cpu_data_field_0;
+    if (supported) {
+      // Run with the optional CPU feature.
+      fprintf(stderr, "Device supports CPU feature: %s\n",
+              get_cpu_features_str(&params));
+      test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
+    } else {
+      fprintf(stderr, "Skipped: device does not support CPU feature: %s\n",
+              get_cpu_features_str(&params));
+    }
+  }
+  iree_mmt4d_test_random_engine_destroy(engine);
+}
+
+#define MMT4D_TEST(type, M0, N0, K0, test_suffix, feature_bit)           \
+  TEST(Mmt4dTest, type##_tile_##M0##x##N0##x##K0##_##test_suffix) {      \
+    mmt4d_test(iree_ukernel_mmt4d_type_##type, M0, N0, K0, feature_bit); \
+  }
+
+// Generic tests, not matching any particular CPU feature. This is the place to
+// test weird M0, N0, K0 to ensure e.g. that we haven't unwittingly baked in a
+// power-of-two assumption
+MMT4D_TEST(f32f32f32, 3, 5, 7, generic, 0)
+MMT4D_TEST(i8i8i32, 9, 6, 3, generic, 0)
+
+// ARM_64 tests.
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+
+#define MMT4D_ARM_64_TEST(type, M0, N0, K0, FEATURE) \
+  MMT4D_TEST(type, M0, N0, K0, arm_64_##FEATURE,     \
+             IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_##FEATURE)
+
+MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 4, DOTPROD)
+MMT4D_ARM_64_TEST(i8i8i32, 8, 8, 8, I8MM)
+#endif  // defined(IREE_UKERNEL_ARCH_ARM_64)
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  iree_cpu_initialize(iree_allocator_system());
+  return RUN_ALL_TESTS();
 }
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
new file mode 100644
index 0000000..0a9f970
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
@@ -0,0 +1,162 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/tools/mmt4d_test_utils.h"
+
+#include <cassert>
+#include <random>
+
+#include "iree/schemas/cpu_data.h"
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type(
+    const iree_ukernel_mmt4d_params_t* params) {
+  switch (params->type) {
+    case iree_ukernel_mmt4d_type_f32f32f32:
+      return iree_mmt4d_scalar_type_f32;
+    case iree_ukernel_mmt4d_type_i8i8i32:
+      return iree_mmt4d_scalar_type_i8;
+    default:
+      assert(false && "unknown type");
+      return iree_mmt4d_scalar_type_unknown;
+  }
+}
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type(
+    const iree_ukernel_mmt4d_params_t* params) {
+  // same for now
+  return iree_ukernel_mmt4d_lhs_type(params);
+}
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type(
+    const iree_ukernel_mmt4d_params_t* params) {
+  switch (params->type) {
+    case iree_ukernel_mmt4d_type_f32f32f32:
+      return iree_mmt4d_scalar_type_f32;
+    case iree_ukernel_mmt4d_type_i8i8i32:
+      return iree_mmt4d_scalar_type_i32;
+    default:
+      assert(false && "unknown type");
+      return iree_mmt4d_scalar_type_unknown;
+  }
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_lhs_buffer_size(
+    const iree_ukernel_mmt4d_params_t* params) {
+  return params->M * params->lhs_stride *
+         iree_ukernel_mmt4d_lhs_elem_size(params->type);
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_rhs_buffer_size(
+    const iree_ukernel_mmt4d_params_t* params) {
+  return params->N * params->rhs_stride *
+         iree_ukernel_mmt4d_rhs_elem_size(params->type);
+}
+
+iree_ukernel_size_t iree_ukernel_mmt4d_out_buffer_size(
+    const iree_ukernel_mmt4d_params_t* params) {
+  return params->M * params->out_stride *
+         iree_ukernel_mmt4d_out_elem_size(params->type);
+}
+
+struct iree_mmt4d_test_random_engine_t {
+  std::minstd_rand cpp_random_engine;
+};
+
+iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create() {
+  return new iree_mmt4d_test_random_engine_t;
+}
+
+void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e) {
+  delete e;
+}
+
+static int iree_mmt4d_test_random_engine_get_in_uint16_range(
+    iree_mmt4d_test_random_engine_t* e) {
+  uint32_t v = e->cpp_random_engine();
+  // Return the middle two of the 4 bytes of state (bits 8..23). It avoids
+  // some mild issues with the least-significant and most-significant bytes.
+  return (v >> 8) & 0xffff;
+}
+
+int iree_mmt4d_test_random_engine_get_0_or_1(
+    iree_mmt4d_test_random_engine_t* e) {
+  int v = iree_mmt4d_test_random_engine_get_in_uint16_range(e);
+  return v & 1;
+}
+
+int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(
+    iree_mmt4d_test_random_engine_t* e) {
+  int v = iree_mmt4d_test_random_engine_get_in_uint16_range(e);
+  return (v % 32) - 16;
+}
+
+template <typename T>
+static void write_random_buffer(T* buffer, iree_ukernel_size_t size_in_bytes,
+                                iree_mmt4d_test_random_engine_t* engine) {
+  iree_ukernel_size_t size_in_elems = size_in_bytes / sizeof(T);
+  assert(size_in_elems * sizeof(T) == size_in_bytes && "bad size");
+  for (iree_ukernel_size_t i = 0; i < size_in_elems; ++i) {
+    // Small integers, should work for now for all the types we currently have
+    // and enable exact float arithmetic, allowing to keep tests simpler for
+    // now. Watch out for when we'll do float16!
+    T random_val =
+        iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(engine);
+    buffer[i] = random_val;
+  }
+}
+
+void write_random_buffer(void* buffer, iree_ukernel_size_t size_in_bytes,
+                         iree_mmt4d_scalar_type_t type,
+                         iree_mmt4d_test_random_engine_t* engine) {
+  switch (type) {
+    case iree_mmt4d_scalar_type_f32:
+      write_random_buffer(static_cast<float*>(buffer), size_in_bytes, engine);
+      return;
+    case iree_mmt4d_scalar_type_i32:
+      write_random_buffer(static_cast<int32_t*>(buffer), size_in_bytes, engine);
+      return;
+    case iree_mmt4d_scalar_type_i8:
+      write_random_buffer(static_cast<int8_t*>(buffer), size_in_bytes, engine);
+      return;
+    default:
+      assert(false && "unknown type");
+  }
+}
+
+const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params) {
+  switch (params->type) {
+#define GET_MMT4D_TYPE_STR_CASE(x) \
+  case x:                          \
+    return #x;
+    GET_MMT4D_TYPE_STR_CASE(iree_ukernel_mmt4d_type_f32f32f32);
+    GET_MMT4D_TYPE_STR_CASE(iree_ukernel_mmt4d_type_i8i8i32);
+    default:
+      assert(false && "unknown type");
+      return "unknown type";
+  }
+}
+
+const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params) {
+  // We set only one feature bit at a time in this test --- not an actual
+  // detected cpu data field. This might have to change in the future if some
+  // code path relies on the combination of two features.
+  // For now, asserting only one bit set, and taking advantage of that to work
+  // with plain string literals.
+  assert(0 == (params->cpu_data_field_0 & (params->cpu_data_field_0 - 1)));
+  if (params->cpu_data_field_0 == 0) {
+    return "(none)";
+  }
+#if defined(IREE_UKERNEL_ARCH_ARM_64)
+  if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_I8MM) {
+    return "i8mm";
+  }
+  if (params->cpu_data_field_0 & IREE_CPU_DATA_FIELD_0_AARCH64_HAVE_DOTPROD) {
+    return "dotprod";
+  }
+#endif  // defined(IREE_UKERNEL_ARCH_ARM_64)
+  assert(false && "unknown CPU feature");
+  return "unknown CPU feature";
+}
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h
new file mode 100644
index 0000000..0f45f0e
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.h
@@ -0,0 +1,63 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
+#define IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_
+
+// clang-format off
+#include <stdint.h>  // include before ukernel/common.h to keep standard types
+// clang-format on
+
+#include "iree/builtins/ukernel/mmt4d_types.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+enum iree_mmt4d_scalar_type_t {
+  iree_mmt4d_scalar_type_unknown,
+  iree_mmt4d_scalar_type_i8,
+  iree_mmt4d_scalar_type_i32,
+  iree_mmt4d_scalar_type_f32,
+};
+
+typedef enum iree_mmt4d_scalar_type_t iree_mmt4d_scalar_type_t;
+
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_lhs_type(
+    const iree_ukernel_mmt4d_params_t* params);
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_rhs_type(
+    const iree_ukernel_mmt4d_params_t* params);
+iree_mmt4d_scalar_type_t iree_ukernel_mmt4d_out_type(
+    const iree_ukernel_mmt4d_params_t* params);
+
+iree_ukernel_size_t iree_ukernel_mmt4d_lhs_buffer_size(
+    const iree_ukernel_mmt4d_params_t* params);
+iree_ukernel_size_t iree_ukernel_mmt4d_rhs_buffer_size(
+    const iree_ukernel_mmt4d_params_t* params);
+iree_ukernel_size_t iree_ukernel_mmt4d_out_buffer_size(
+    const iree_ukernel_mmt4d_params_t* params);
+
+struct iree_mmt4d_test_random_engine_t;  // opaque handle
+typedef struct iree_mmt4d_test_random_engine_t iree_mmt4d_test_random_engine_t;
+iree_mmt4d_test_random_engine_t* iree_mmt4d_test_random_engine_create(void);
+void iree_mmt4d_test_random_engine_destroy(iree_mmt4d_test_random_engine_t* e);
+int iree_mmt4d_test_random_engine_get_0_or_1(
+    iree_mmt4d_test_random_engine_t* e);
+int iree_mmt4d_test_random_engine_get_between_minus16_and_plus15(
+    iree_mmt4d_test_random_engine_t* e);
+
+void write_random_buffer(void* buffer, iree_ukernel_size_t size_in_bytes,
+                         iree_mmt4d_scalar_type_t type,
+                         iree_mmt4d_test_random_engine_t* engine);
+
+const char* get_mmt4d_type_str(const iree_ukernel_mmt4d_params_t* params);
+const char* get_cpu_features_str(const iree_ukernel_mmt4d_params_t* params);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // IREE_BUILTINS_UKERNEL_TOOLS_MMT4D_TEST_UTILS_H_