Improvements to e2e matmul tests (#15243)

* Test f16 on CPU, not just on GPU.
* Test bf16
* Test many SIMD variants that we now have dedicated ukernel code paths
for. These code paths were so far tested only in isolated unit tests,
not in integration.
* Rationalize/reorganize the collection of e2e matmul tests:
  - Drop no longer relevant cases.
  - Rename: `mmt4d` -> `dt` (for data-tiling), `direct` -> `nondt`, etc.
    The `nondt` tests explicitly disable data-tiling, so they will continue
    testing that through the merging of #15215.
  - Group by target backend.
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index b7d6a24..73fc57a 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -8,7 +8,8 @@
 
 function(iree_is_bytecode_module_test_excluded_by_labels _DST_IS_EXCLUDED_VAR _SRC_LABELS)
   string(TOLOWER "${CMAKE_BUILD_TYPE}" _LOWERCASE_BUILD_TYPE)
-  if(((IREE_ARCH MARCHES "^riscv_") AND ("noriscv" IN_LIST _SRC_LABELS)) OR
+  if(((IREE_ARCH MATCHES "^riscv_") AND ("noriscv" IN_LIST _SRC_LABELS)) OR
+     (EMSCRIPTEN AND ("nowasm" IN_LIST _SRC_LABELS)) OR
      (IREE_ENABLE_ASAN AND ("noasan" IN_LIST _SRC_LABELS)) OR
      (IREE_ENABLE_TSAN AND ("notsan" IN_LIST _SRC_LABELS)) OR
      (CMAKE_CROSSCOMPILING AND "hostonly" IN_LIST _RULE_LABELS) OR
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index c4b9f31..6f2c849 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -349,6 +349,9 @@
     case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
       *(uint16_t*)dst = iree_math_f32_to_f16((float)value);
       break;
+    case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+      *(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
+      break;
     IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_32, float)
     IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_64, double)
     // clang-format on
@@ -402,6 +405,7 @@
     case IREE_HAL_ELEMENT_TYPE_INT_16:
     case IREE_HAL_ELEMENT_TYPE_SINT_16:
     case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+    case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
       *min = -4;
       *max = +4;
       break;
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 6c76c35..d3ca5b1 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -18,156 +18,50 @@
     srcs = ["generate_e2e_matmul_tests.py"],
 )
 
-[iree_generated_trace_runner_test(
-    name = "e2e_matmul_direct_%s_small" % lhs_rhs_type,
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=%s" % lhs_rhs_type,
-        "--shapes=small",
-    ],
-    target_backends_and_drivers = [
-        ("llvm-cpu", "local-task"),
-    ],
-    trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
-]]
+###########################################################################
+##
+## LLVMCPU backend
+##
+###########################################################################
 
-# Test asm
+# LLVMCPU, non-data-tiling, no microkernels
 [iree_generated_trace_runner_test(
-    name = "e2e_matmul_mmt4d_%s_small" % lhs_rhs_type,
+    name = "e2e_matmul_nondt_%s_%s_small" % (lhs_rhs_type, acc_type),
     compiler_flags = [
-        "--iree-opt-data-tiling",
+        "--iree-opt-data-tiling=false",
     ],
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=small",
     ],
-    target_backends_and_drivers = [
-        ("llvm-cpu", "local-task"),
-    ],
-    target_cpu_features_variants = ["default"] +
-                                   ([
-                                       "arm_64:dotprod:+dotprod",
-                                       "arm_64:i8mm:+i8mm",
-                                   ] if lhs_rhs_type == "i8" else []),
-    trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
-]]
-
-[iree_generated_trace_runner_test(
-    name = "e2e_matmul_mmt4d_%s_large" % lhs_rhs_type,
-    compiler_flags = [
-        "--iree-opt-data-tiling",
-    ],
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=%s" % lhs_rhs_type,
-        "--shapes=large",
-    ],
     tags = [
-        # "--shapes=large" can cause timeouts on riscv emulator and sanitizers.
+        # f16/bf16 trigger internal LLVM assertion errors on riscv and wasm.
         "noriscv",
-        "noasan",
-        "notsan",
-    ],
+        "nowasm",
+    ] if (lhs_rhs_type == "f16" or lhs_rhs_type == "bf16") else [],
     target_backends_and_drivers = [
         ("llvm-cpu", "local-task"),
     ],
-    target_cpu_features_variants = ["default"] +
-                                   ([
-                                       "arm_64:dotprod:+dotprod",
-                                       "arm_64:i8mm:+i8mm",
-                                   ] if lhs_rhs_type == "i8" else []),
     trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f32", "f32"),
+    ("f16", "f16"),
+    ("f16", "f32"),
+    ("bf16", "bf16"),
+    ("bf16", "f32"),
 ]]
 
-# Test intrinsics. No need to run vmvx again, since it isn't affected by this
-# codegen flag. No need to run "large" sizes, since this only differs from other
-# tests in ways that are orthogonal to problem sizes.
-[iree_generated_trace_runner_test(
-    name = "e2e_matmul_mmt4d_%s_intrinsics_%s" % (lhs_rhs_type, size),
-    compiler_flags = [
-        "--iree-codegen-mmt4d-use-intrinsics",
-        "--iree-opt-data-tiling",
-    ],
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=%s" % lhs_rhs_type,
-        "--shapes=%s" % size,
-    ],
-    target_backends_and_drivers = [
-        ("llvm-cpu", "local-task"),
-    ],
-    target_cpu_features_variants = ["default"] +
-                                   ([
-                                       "arm_64:dotprod:+dotprod",
-                                       "arm_64:i8mm:+i8mm",
-                                   ] if lhs_rhs_type == "i8" else []),
-    trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
-] for size in [
-    "small",
-]]
-
-# Test VMVX+ukernel, direct (not mmt4d)
-[iree_generated_trace_runner_test(
-    name = "e2e_matmul_direct_%s_small_ukernel" % lhs_rhs_type,
-    compiler_flags = [
-        "--iree-vmvx-enable-microkernels",
-    ],
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=%s" % lhs_rhs_type,
-        "--shapes=small",
-    ],
-    target_backends_and_drivers = [
-        ("vmvx", "local-task"),
-    ],
-    trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
-]]
-
-# Test VMVX+ukernel, mmt4d, with target CPU features variants relevant to each
-# lhs_rhs_type.
-[iree_generated_trace_runner_test(
-    name = "e2e_matmul_mmt4d_%s_small_vmvx_ukernel" % lhs_rhs_type,
-    compiler_flags = [
-        "--iree-vmvx-enable-microkernels",
-        "--iree-opt-data-tiling",
-    ],
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=%s" % lhs_rhs_type,
-        "--shapes=small",
-    ],
-    target_backends_and_drivers = [
-        ("vmvx", "local-task"),
-    ],
-    trace_runner = "//tools:iree-e2e-matmul-test",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
-]]
-
-X86_64_AVX2_FMA = [
+X86_64_AVX2 = [
     "+avx",
     "+avx2",
     "+fma",
+    "+f16c",
 ]
 
-X86_64_AVX512_BASE = X86_64_AVX2_FMA + [
+X86_64_AVX512 = X86_64_AVX2 + [
     "+avx512f",
     "+avx512vl",
     "+avx512cd",
@@ -175,49 +69,132 @@
     "+avx512dq",
 ]
 
-X86_64_AVX512_VNNI = X86_64_AVX512_BASE + [
+X86_64_AVX512_VNNI = X86_64_AVX512 + [
     "+avx512vnni",
 ]
 
-# Test mmt4d with --iree-llvmcpu-enable-microkernels.
+X86_64_AVX512_BF16 = X86_64_AVX512 + [
+    "+avx512bf16",
+]
+
+# LLVMCPU, data-tiling + microkernels.
+# TODO(#15241, #15215): also test data-tiling alone without microkernels. This currently
+# fails (#15241), which needs to be resolved to unblock data-tiling-by-default (#15215).
 [iree_generated_trace_runner_test(
-    name = "e2e_matmul_mmt4d_%s_%s_ukernel" % (lhs_rhs_type, size),
+    name = "e2e_matmul_dt_uk_%s_%s_%s" % (lhs_rhs_type, acc_type, size),
     compiler_flags = [
+        "--iree-opt-data-tiling",
         "--iree-llvmcpu-enable-microkernels",
+    ],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
+        "--shapes=%s" % size,
+    ],
+    tags = ([
+        # "--shapes=large" can cause timeouts on sanitizers.
+        "noasan",
+        "notsan",
+    ] if size == "large" else []) + ([
+        # f16/bf16 trigger internal LLVM assertion errors on riscv and wasm.
+        # NOTE(review): "--shapes=large" can cause timeouts on RISC-V emulator, but large non-f16/bf16 tests are not tagged "noriscv" here - confirm.
+        "noriscv",
+        "nowasm",
+    ] if (lhs_rhs_type == "f16" or lhs_rhs_type == "bf16") else []),
+    target_backends_and_drivers = [
+        ("llvm-cpu", "local-task"),
+    ],
+    target_cpu_features_variants = ["default"] +
+                                   ([
+                                       "arm_64:dotprod:+dotprod",
+                                       "arm_64:i8mm:+i8mm",
+                                       "x86_64:avx512vnni:" + ",".join(X86_64_AVX512_VNNI),
+                                   ] if lhs_rhs_type == "i8" and acc_type == "i32" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                   ] if lhs_rhs_type == "f32" and acc_type == "f32" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "arm_64:fp16:+fp16",
+                                   ] if lhs_rhs_type == "f16" and acc_type == "f16" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "arm_64:fp16:+fp16fml",
+                                   ] if lhs_rhs_type == "f16" and acc_type == "f32" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "x86_64:avx512bf16:" + ",".join(X86_64_AVX512_BF16),
+                                       "arm_64:fp16:+bf16",
+                                   ] if lhs_rhs_type == "bf16" and acc_type == "bf16" else [
+                                       "x86_64:avx2:" + ",".join(X86_64_AVX2),
+                                       "x86_64:avx512:" + ",".join(X86_64_AVX512),
+                                       "x86_64:avx512bf16:" + ",".join(X86_64_AVX512_BF16),
+                                       "arm_64:fp16:+bf16",
+                                   ] if lhs_rhs_type == "bf16" and acc_type == "f32" else []),
+    trace_runner = "//tools:iree-e2e-matmul-test",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f32", "f32"),
+    ("f16", "f16"),
+    ("f16", "f32"),
+    ("bf16", "bf16"),
+    ("bf16", "f32"),
+] for size in [
+    "small",
+    "large",
+]]
+
+# Some e2e testing for --iree-codegen-enable-vector-padding=false.
+iree_generated_trace_runner_test(
+    name = "e2e_matmul_nondt_f32_small_no_padding",
+    compiler_flags = [
+        "--iree-codegen-enable-vector-padding=false",
+    ],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=f32",
+        "--shapes=small",
+    ],
+    target_backends_and_drivers = [
+        ("llvm-cpu", "local-task"),
+    ],
+    trace_runner = "//tools:iree-e2e-matmul-test",
+)
+
+###########################################################################
+##
+## VMVX backend
+##
+###########################################################################
+
+# VMVX, data-tiling + microkernels.
+[iree_generated_trace_runner_test(
+    name = "e2e_matmul_dt_uk_%s_small" % lhs_rhs_type,
+    compiler_flags = [
+        "--iree-vmvx-enable-microkernels",
         "--iree-opt-data-tiling",
     ],
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
-        "--shapes=%s" % size,
+        "--shapes=small",
     ],
-    tags = [
-        # "--shapes=large" can cause timeouts on riscv emulator and sanitizers.
-        "noriscv",
-        "noasan",
-        "notsan",
-    ] if size == "large" else [],
     target_backends_and_drivers = [
-        ("llvm-cpu", "local-task"),
+        ("vmvx", "local-task"),
     ],
-    target_cpu_features_variants = [
-        "default",
-        "x86_64:avx2_fma:" + ",".join(X86_64_AVX2_FMA),
-        "x86_64:avx512_base:" + ",".join(X86_64_AVX512_BASE),
-    ] + ([
-        "x86_64:avx512_vnni:" + ",".join(X86_64_AVX512_VNNI),
-        "arm_64:dotprod:+dotprod",
-        "arm_64:i8mm:+i8mm",
-    ] if lhs_rhs_type == "i8" else []),
     trace_runner = "//tools:iree-e2e-matmul-test",
 ) for lhs_rhs_type in [
     "i8",
     "f32",
-] for size in [
-    "small",
-    "large",
 ]]
 
+###########################################################################
+##
+## CUDA backend
+##
+###########################################################################
+
 iree_generated_trace_runner_test(
     name = "e2e_matmul_direct_f32_gpu_large_LLVMGPUMatmulSimt",
     generator = ":generate_e2e_matmul_tests",
@@ -415,13 +392,18 @@
     ],
     target_backends_and_drivers = [
         ("cuda", "cuda"),
-        ("llvm-cpu", "local-task"),
     ],
     trace_runner = "//tools:iree-e2e-matmul-test",
 ) for lhs_rhs_type in [
     "f32",
 ]]
 
+###########################################################################
+##
+## Vulkan backend
+##
+###########################################################################
+
 [iree_generated_trace_runner_test(
     name = "e2e_matmul_direct_{0}_gpu_large_valhall".format(lhs_rhs_type),
     compiler_flags = [
@@ -474,22 +456,6 @@
 ]]
 
 iree_generated_trace_runner_test(
-    name = "e2e_matmul_direct_f32_small_no_padding",
-    compiler_flags = [
-        "--iree-codegen-enable-vector-padding=false",
-    ],
-    generator = ":generate_e2e_matmul_tests",
-    generator_args = [
-        "--lhs_rhs_type=f32",
-        "--shapes=small",
-    ],
-    target_backends_and_drivers = [
-        ("llvm-cpu", "local-task"),
-    ],
-    trace_runner = "//tools:iree-e2e-matmul-test",
-)
-
-iree_generated_trace_runner_test(
     name = "e2e_matmul_direct_f16_gpu_large_rdna3",
     compiler_flags = [
         "--iree-vulkan-target-triple=rdna3-unknown-linux",
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 0fd02ec..913aa3d 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -12,11 +12,12 @@
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_direct_i8_small
+    e2e_matmul_nondt_i8_i32_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=small"
   TRACE_RUNNER
     iree-e2e-matmul-test
@@ -24,15 +25,20 @@
     "llvm-cpu"
   DRIVERS
     "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling=false"
+  LABELS
+
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_direct_f32_small
+    e2e_matmul_nondt_f32_f32_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=small"
   TRACE_RUNNER
     iree-e2e-matmul-test
@@ -40,15 +46,108 @@
     "llvm-cpu"
   DRIVERS
     "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling=false"
+  LABELS
+
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_i8_small
+    e2e_matmul_nondt_f16_f16_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling=false"
+  LABELS
+    "noriscv"
+    "nowasm"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_nondt_f16_f32_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling=false"
+  LABELS
+    "noriscv"
+    "nowasm"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_nondt_bf16_bf16_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling=false"
+  LABELS
+    "noriscv"
+    "nowasm"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_nondt_bf16_f32_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling=false"
+  LABELS
+    "noriscv"
+    "nowasm"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_i8_i32_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=small"
   TRACE_RUNNER
     iree-e2e-matmul-test
@@ -58,39 +157,24 @@
     "local-task"
   COMPILER_FLAGS
     "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+
   TARGET_CPU_FEATURES_VARIANTS
     "default"
     "arm_64:dotprod:+dotprod"
     "arm_64:i8mm:+i8mm"
+    "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_f32_small
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=f32"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-opt-data-tiling"
-  TARGET_CPU_FEATURES_VARIANTS
-    "default"
-)
-
-iree_generated_trace_runner_test(
-  NAME
-    e2e_matmul_mmt4d_i8_large
+    e2e_matmul_dt_uk_i8_i32_large
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=large"
   TRACE_RUNNER
     iree-e2e-matmul-test
@@ -100,23 +184,51 @@
     "local-task"
   COMPILER_FLAGS
     "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
   LABELS
-    "noriscv"
     "noasan"
     "notsan"
   TARGET_CPU_FEATURES_VARIANTS
     "default"
     "arm_64:dotprod:+dotprod"
     "arm_64:i8mm:+i8mm"
+    "x86_64:avx512vnni:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_f32_large
+    e2e_matmul_dt_uk_f32_f32_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_f32_f32_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=large"
   TRACE_RUNNER
     iree-e2e-matmul-test
@@ -126,21 +238,24 @@
     "local-task"
   COMPILER_FLAGS
     "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
   LABELS
-    "noriscv"
     "noasan"
     "notsan"
   TARGET_CPU_FEATURES_VARIANTS
     "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_i8_intrinsics_small
+    e2e_matmul_dt_uk_f16_f16_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
-    "--lhs_rhs_type=i8"
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
     "--shapes=small"
   TRACE_RUNNER
     iree-e2e-matmul-test
@@ -149,17 +264,229 @@
   DRIVERS
     "local-task"
   COMPILER_FLAGS
-    "--iree-codegen-mmt4d-use-intrinsics"
     "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noriscv"
+    "nowasm"
   TARGET_CPU_FEATURES_VARIANTS
     "default"
-    "arm_64:dotprod:+dotprod"
-    "arm_64:i8mm:+i8mm"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16:+fp16"
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_f32_intrinsics_small
+    e2e_matmul_dt_uk_f16_f16_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f16"
+    "--shapes=large"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16:+fp16"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_f16_f32_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16:+fp16fml"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_f16_f32_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--shapes=large"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "arm_64:fp16:+fp16fml"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_bf16_bf16_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_bf16_bf16_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=bf16"
+    "--shapes=large"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_bf16_f32_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_dt_uk_bf16_f32_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=large"
+  TRACE_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "llvm-cpu"
+  DRIVERS
+    "local-task"
+  COMPILER_FLAGS
+    "--iree-opt-data-tiling"
+    "--iree-llvmcpu-enable-microkernels"
+  LABELS
+    "noasan"
+    "notsan"
+    "noriscv"
+    "nowasm"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "x86_64:avx2:+avx,+avx2,+fma,+f16c"
+    "x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
+    "x86_64:avx512bf16:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512bf16"
+    "arm_64:fp16:+bf16"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_nondt_f32_small_no_padding
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
@@ -172,51 +499,12 @@
   DRIVERS
     "local-task"
   COMPILER_FLAGS
-    "--iree-codegen-mmt4d-use-intrinsics"
-    "--iree-opt-data-tiling"
-  TARGET_CPU_FEATURES_VARIANTS
-    "default"
+    "--iree-codegen-enable-vector-padding=false"
 )
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_direct_i8_small_ukernel
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=i8"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "vmvx"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-vmvx-enable-microkernels"
-)
-
-iree_generated_trace_runner_test(
-  NAME
-    e2e_matmul_direct_f32_small_ukernel
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=f32"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "vmvx"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-vmvx-enable-microkernels"
-)
-
-iree_generated_trace_runner_test(
-  NAME
-    e2e_matmul_mmt4d_i8_small_vmvx_ukernel
+    e2e_matmul_dt_uk_i8_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
@@ -235,7 +523,7 @@
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_f32_small_vmvx_ukernel
+    e2e_matmul_dt_uk_f32_small
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
@@ -254,116 +542,6 @@
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_mmt4d_i8_small_ukernel
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=i8"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-llvmcpu-enable-microkernels"
-    "--iree-opt-data-tiling"
-  LABELS
-
-  TARGET_CPU_FEATURES_VARIANTS
-    "default"
-    "x86_64:avx2_fma:+avx,+avx2,+fma"
-    "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
-    "x86_64:avx512_vnni:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
-    "arm_64:dotprod:+dotprod"
-    "arm_64:i8mm:+i8mm"
-)
-
-iree_generated_trace_runner_test(
-  NAME
-    e2e_matmul_mmt4d_i8_large_ukernel
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=i8"
-    "--shapes=large"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-llvmcpu-enable-microkernels"
-    "--iree-opt-data-tiling"
-  LABELS
-    "noriscv"
-    "noasan"
-    "notsan"
-  TARGET_CPU_FEATURES_VARIANTS
-    "default"
-    "x86_64:avx2_fma:+avx,+avx2,+fma"
-    "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
-    "x86_64:avx512_vnni:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq,+avx512vnni"
-    "arm_64:dotprod:+dotprod"
-    "arm_64:i8mm:+i8mm"
-)
-
-iree_generated_trace_runner_test(
-  NAME
-    e2e_matmul_mmt4d_f32_small_ukernel
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=f32"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-llvmcpu-enable-microkernels"
-    "--iree-opt-data-tiling"
-  LABELS
-
-  TARGET_CPU_FEATURES_VARIANTS
-    "default"
-    "x86_64:avx2_fma:+avx,+avx2,+fma"
-    "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
-)
-
-iree_generated_trace_runner_test(
-  NAME
-    e2e_matmul_mmt4d_f32_large_ukernel
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=f32"
-    "--shapes=large"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-llvmcpu-enable-microkernels"
-    "--iree-opt-data-tiling"
-  LABELS
-    "noriscv"
-    "noasan"
-    "notsan"
-  TARGET_CPU_FEATURES_VARIANTS
-    "default"
-    "x86_64:avx2_fma:+avx,+avx2,+fma"
-    "x86_64:avx512_base:+avx,+avx2,+fma,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
-)
-
-iree_generated_trace_runner_test(
-  NAME
     e2e_matmul_direct_f32_gpu_large_LLVMGPUMatmulSimt
   GENERATOR
     "generate_e2e_matmul_tests.py"
@@ -545,10 +723,8 @@
     iree-e2e-matmul-test
   TARGET_BACKENDS
     "cuda"
-    "llvm-cpu"
   DRIVERS
     "cuda"
-    "local-task"
   COMPILER_FLAGS
     "--iree-flow-split-matmul-reduction=4"
   LABELS
@@ -694,24 +870,6 @@
 
 iree_generated_trace_runner_test(
   NAME
-    e2e_matmul_direct_f32_small_no_padding
-  GENERATOR
-    "generate_e2e_matmul_tests.py"
-  GENERATOR_ARGS
-    "--lhs_rhs_type=f32"
-    "--shapes=small"
-  TRACE_RUNNER
-    iree-e2e-matmul-test
-  TARGET_BACKENDS
-    "llvm-cpu"
-  DRIVERS
-    "local-task"
-  COMPILER_FLAGS
-    "--iree-codegen-enable-vector-padding=false"
-)
-
-iree_generated_trace_runner_test(
-  NAME
     e2e_matmul_direct_f16_gpu_large_rdna3
   GENERATOR
     "generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index bf0df7f..631fd6a 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -22,10 +22,12 @@
 # as this also includes accumulator-specific types like i32.
 @enum.unique
 class MatrixElemTypeId(enum.Enum):
+    NONE = ""
     I8 = "i8"
     I32 = "i32"
     F32 = "f32"
     F16 = "f16"
+    BF16 = "bf16"
 
 
 # Enumerates of the collections of shapes that we can generate tests for.
@@ -613,11 +615,19 @@
     parser.add_argument(
         "--lhs_rhs_type",
         type=str,
-        choices=["i8", "f32", "f16"],
+        choices=["i8", "f32", "f16", "bf16"],
         help="Numeric type of input matrices",
         required=True,
     )
     parser.add_argument(
+        "--acc_type",
+        type=str,
+        choices=["i32", "f32", "f16", "bf16"],
+        help="Numeric type of accumulator and result matrices (inferred if omitted)",
+        default="",
+        required=False,
+    )
+    parser.add_argument(
         "--shapes",
         type=str,
         choices=[s.value for s in ShapesId],
@@ -632,7 +642,6 @@
         default="",
         required=False,
     )
-
     parser.add_argument(
         "--module_path",
         type=str,
@@ -704,16 +713,18 @@
 # type, so we do that. That is temporary: eventually there will be cases
 # where the same input types are used with different accumulator types, e.g.
 # f16 inputs with both f16 and f32 accumulator.
-def infer_acc_type(lhs_rhs_type: MatrixElemTypeId):
+def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
+    if acc_type != MatrixElemTypeId.NONE:
+        return acc_type
     if lhs_rhs_type == MatrixElemTypeId.I8:
         return MatrixElemTypeId.I32
-    else:
-        return lhs_rhs_type
+    return lhs_rhs_type
 
 
 def main(args):
     lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
-    acc_type = infer_acc_type(lhs_rhs_type)
+    acc_type = MatrixElemTypeId(args.acc_type)
+    acc_type = infer_acc_type(lhs_rhs_type, acc_type)
     shapes_id = ShapesId(args.shapes)
     compilation_info_id = CompilationInfoId(args.compilation_info)
     (function_definitions, traces) = generate(
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 758ae35..1c0171b 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -59,6 +59,8 @@
   IREE_E2E_TEST_VALUE_TYPE_F32 = 6,
   // double.
   IREE_E2E_TEST_VALUE_TYPE_F64 = 7,
+  // bfloat16
+  IREE_E2E_TEST_VALUE_TYPE_BF16 = 8,
 } iree_e2e_test_value_type_t;
 
 // Maximum size, in bytes, of any value type we can represent.
@@ -74,6 +76,7 @@
     int64_t i64;
     float f32;
     uint16_t f16_u16;
+    uint16_t bf16_u16;
     double f64;
     uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE];  // max size of all
                                                               // value types
@@ -137,6 +140,14 @@
   return result;
 }
 
+static inline iree_e2e_test_value_t iree_e2e_test_value_make_bf16(
+    uint16_t value) {
+  iree_e2e_test_value_t result;
+  result.type = IREE_E2E_TEST_VALUE_TYPE_BF16;
+  result.bf16_u16 = value;
+  return result;
+}
+
 static inline iree_e2e_test_value_t iree_e2e_test_value_make_f32(float value) {
   iree_e2e_test_value_t result;
   result.type = IREE_E2E_TEST_VALUE_TYPE_F32;
@@ -155,6 +166,12 @@
   return value->f16_u16;
 }
 
+// TODO(#5542): check the value type before accessing the union.
+static inline uint16_t iree_e2e_test_value_get_bf16(
+    iree_e2e_test_value_t* value) {
+  return value->bf16_u16;
+}
+
 static inline iree_e2e_test_value_t iree_e2e_test_value_make_f64(double value) {
   iree_e2e_test_value_t result;
   result.type = IREE_E2E_TEST_VALUE_TYPE_F64;
@@ -343,10 +360,13 @@
   switch (element_type) {
     WRITE_INT_ELEMENT_CASE(INT_8, int8_t)
     WRITE_INT_ELEMENT_CASE(INT_32, int32_t)
+    WRITE_INT_ELEMENT_CASE(FLOAT_32, float)
     case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
       *(uint16_t*)dst = iree_math_f32_to_f16((float)value);
       break;
-      WRITE_INT_ELEMENT_CASE(FLOAT_32, float)
+    case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+      *(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
+      break;
     default:
       IREE_ASSERT(false, "unhandled element type");
       break;
@@ -382,6 +402,9 @@
   } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
     ((uint16_t*)data)[index] = iree_math_f32_to_f16((float)value);
     return;
+  } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
+    ((uint16_t*)data)[index] = iree_math_f32_to_bf16((float)value);
+    return;
   } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     ((float*)data)[index] = value;
     return;
@@ -405,6 +428,8 @@
     return iree_e2e_test_value_make_i32(((int32_t*)data)[index]);
   } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
     return iree_e2e_test_value_make_f16(((uint16_t*)data)[index]);
+  } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
+    return iree_e2e_test_value_make_bf16(((uint16_t*)data)[index]);
   } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     return iree_e2e_test_value_make_f32(((float*)data)[index]);
   }
@@ -483,9 +508,9 @@
   static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE(  \
       iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,     \
       iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,      \
-      iree_hal_element_type_t acc_type, LHSTYPE* lhs_data, RHSTYPE* rhs_data,  \
-      ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m,               \
-      iree_hal_dim_t n) {                                                      \
+      iree_hal_element_type_t acc_type, const LHSTYPE* lhs_data,               \
+      const RHSTYPE* rhs_data, const ACCTYPE* acc_data, RESTYPE* result_data,  \
+      iree_hal_dim_t m, iree_hal_dim_t n) {                                    \
     ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0;                     \
     for (iree_hal_dim_t k = 0; k < k_size; ++k) {                              \
       LHSTYPE lhs_value = lhs_data[k + m * k_size];                            \
@@ -505,15 +530,15 @@
 // [i32 <= i8 * i8 + i32]
 IREE_TRACE_REPLAY_REFERENCE_MATMUL(int8_t, int8_t, int32_t, int32_t)
 
-// Reference mamtul for the half_t input, half_t accumlation, and half_t result.
+// Reference matmul for the f16 input, f16 accumulation, and f16 result.
 // [f16 <= f16 * f16 + f16]
 static void reference_matmul_f16_f16_f16_f16(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, uint16_t* lhs_data, uint16_t* rhs_data,
-    uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
-    iree_hal_dim_t n) {
-  float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0;
+    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+    const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
+    iree_hal_dim_t m, iree_hal_dim_t n) {
+  float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0.f;
   for (iree_hal_dim_t k = 0; k < k_size; ++k) {
     acc = iree_math_round_to_nearest_f16(
         iree_math_round_to_nearest_f16(
@@ -524,6 +549,57 @@
   result_data[n + m * n_size] = iree_math_f32_to_f16(acc);
 }
 
+// Reference matmul for the f16 input, f32 accumulation, and f32 result.
+// [f32 <= f16 * f16 + f32]
+static void reference_matmul_f16_f16_f32_f32(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+    const uint16_t* rhs_data, const float* acc_data, float* result_data,
+    iree_hal_dim_t m, iree_hal_dim_t n) {
+  float acc = acc_data ? acc_data[n + m * n_size] : 0.f;
+  for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
+           iree_math_f16_to_f32(rhs_data[n + k * n_size]);
+  }
+  result_data[n + m * n_size] = acc;
+}
+
+// Reference matmul for the bf16 input, bf16 accumulation, and bf16 result.
+// [bf16 <= bf16 * bf16 + bf16]
+static void reference_matmul_bf16_bf16_bf16_bf16(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+    const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
+    iree_hal_dim_t m, iree_hal_dim_t n) {
+  float acc = acc_data ? iree_math_bf16_to_f32(acc_data[n + m * n_size]) : 0.f;
+  for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    acc = iree_math_round_to_nearest_bf16(
+        iree_math_round_to_nearest_bf16(
+            (iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
+             iree_math_bf16_to_f32(rhs_data[n + k * n_size]))) +
+        acc);
+  }
+  result_data[n + m * n_size] = iree_math_f32_to_bf16(acc);
+}
+
+// Reference matmul for the bf16 input, f32 accumulation, and f32 result.
+// [f32 <= bf16 * bf16 + f32]
+static void reference_matmul_bf16_bf16_f32_f32(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
+    const uint16_t* rhs_data, const float* acc_data, float* result_data,
+    iree_hal_dim_t m, iree_hal_dim_t n) {
+  float acc = acc_data ? acc_data[n + m * n_size] : 0.f;
+  for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
+           iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
+  }
+  result_data[n + m * n_size] = acc;
+}
+
 // Helper for reference_matmul.
 // Computes one element in the result matrix.
 static void reference_matmul_element(
@@ -535,21 +611,44 @@
       rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
       acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     reference_matmul_float_float_float_float(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, (float*)lhs_data,
-        (float*)rhs_data, (float*)acc_data, (float*)result_data, m, n);
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        (const float*)lhs_data, (const float*)rhs_data, (const float*)acc_data,
+        (float*)result_data, m, n);
   } else if (iree_hal_element_type_is_integer(lhs_type, 8) &&
              iree_hal_element_type_is_integer(rhs_type, 8) &&
              iree_hal_element_type_is_integer(acc_type, 32)) {
     reference_matmul_int8_t_int8_t_int32_t_int32_t(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, (int8_t*)lhs_data,
-        (int8_t*)rhs_data, (int32_t*)acc_data, (int32_t*)result_data, m, n);
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        (const int8_t*)lhs_data, (const int8_t*)rhs_data,
+        (const int32_t*)acc_data, (int32_t*)result_data, m, n);
   } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
-    reference_matmul_f16_f16_f16_f16(m_size, k_size, n_size, lhs_type, rhs_type,
-                                     acc_type, (uint16_t*)lhs_data,
-                                     (uint16_t*)rhs_data, (uint16_t*)acc_data,
-                                     (uint16_t*)result_data, m, n);
+    reference_matmul_f16_f16_f16_f16(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+        (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_matmul_f16_f16_f32_f32(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+        (const float*)acc_data, (float*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
+    reference_matmul_bf16_bf16_bf16_bf16(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+        (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
+  } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+             rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
+             acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    reference_matmul_bf16_bf16_f32_f32(
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
+        (const float*)acc_data, (float*)result_data, m, n);
   } else {
     iree_status_abort(
         iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
@@ -682,6 +781,12 @@
       return fabsf(iree_math_f16_to_f32(actual.f16_u16) -
                    iree_math_f16_to_f32(expected.f16_u16)) <
              FLAG_acceptable_fp_delta;
+    case IREE_E2E_TEST_VALUE_TYPE_BF16:
+      if (actual.bf16_u16 == expected.bf16_u16) return true;
+      if (FLAG_require_exact_results) return false;
+      return fabsf(iree_math_bf16_to_f32(actual.bf16_u16) -
+                   iree_math_bf16_to_f32(expected.bf16_u16)) <
+             FLAG_acceptable_fp_delta;
     case IREE_E2E_TEST_VALUE_TYPE_F32:
       if (actual.f32 == expected.f32) return true;
       if (FLAG_require_exact_results) return false;
@@ -843,12 +948,7 @@
   // We have a lot more freedom to pick k_start, k_end, since these parameters
   // only affect which regions of the input lhs and rhs matrices are printed.
   // If we were only testing random lhs and rhs, we would just pick
-  // k_start = 0 and any reasonable k_end value. Since we are often using
-  // identity matrices for lhs and rhs, and we expect the majority of
-  // test failures to occur with such identity matrices, we try to pick
-  // k_start and k_end so that nontrivial regions of identity matrices will be
-  // printed. That means that we try to have [k_start, k_end) intervals
-  // overlap [m_start, m_end) and [n_start, n_end).
+  // k_start = 0 and any reasonable k_end value.
   int k_start = iree_max(0, iree_min(m_start, n_start));
   int k_end = iree_min(k_size, iree_max(m_end, n_end));
   // [k_start, k_end) could be arbitrarily long at this point. Constrain it a
@@ -978,29 +1078,6 @@
  *
  *****************************************************************************/
 
-static iree_status_t make_identity_matrix_callback(
-    iree_hal_buffer_mapping_t* mapping, void* user_data) {
-  iree_hal_buffer_view_t* src = (iree_hal_buffer_view_t*)user_data;
-  iree_hal_element_type_t elem_type = iree_hal_buffer_view_element_type(src);
-  iree_host_size_t elem_byte_count =
-      iree_hal_element_dense_byte_count(elem_type);
-  iree_hal_dim_t dims[2] = {0};
-  IREE_RETURN_IF_ERROR(get_matrix_shape(src, dims));
-  int rows = dims[0];
-  int cols = dims[1];
-  // Write 1 to matrix elements on the main diagonal.
-  int diagonal_size = iree_min(rows, cols);
-  memset(mapping->contents.data, 0, mapping->contents.data_length);
-  intptr_t diagonal_elem_addr = (intptr_t)mapping->contents.data;
-  for (int i = 0; i < diagonal_size; ++i) {
-    write_int_element(elem_type, 1, (void*)diagonal_elem_addr);
-    // Due to the row-major storage, the diagonal entries are every
-    // (cols + 1)-th buffer elements.
-    diagonal_elem_addr += elem_byte_count * (cols + 1);
-  }
-  return iree_ok_status();
-}
-
 // Deep-copies device-local list of buffer_views |src| into |dst|.
 static iree_status_t copy_device_buffer_views_to_device(
     iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,