e2e matmul test improvements (#13657)

* Test the non-accumulating matmul case (fill-zero accumulator), as shown in
the sketches below. See the discussion in #13582 with @ThomasRaoux.
* Drop the `MIXED` dynamicity: it was effectively unused, having already
been removed from the set of test variants actually iterated over. Also
drop the helpers that existed only for it, such as `pseudorandom_bool`.
* Drop the logic that re-ran failing matmul tests on simpler matrices to
diagnose numerical errors. Although potentially useful for pinpointing a
numerical issue, it was confusing when debugging (as once discussed with
@manishucsd), and it made the test-runner code substantially more complex.
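
For reference, here is a minimal sketch of the kind of function the generator
now emits for a non-accumulating testcase with static shapes. The shapes
(2x3x4) and the f32 element type are illustrative assumptions, not copied from
an actual generated test; the structure mirrors the `accumulate=False`
static-shape branch added to `generate_function`:

```python
# Illustrative sketch only: shapes and element type are assumptions.
# Mirrors the accumulate=False static-shape branch of generate_function()
# in generate_e2e_matmul_tests.py: the accumulator is materialized with
# tensor.empty and zero-filled with linalg.fill before linalg.matmul.
non_accumulating_matmul_example = """
func.func @matmul_2x3xf32_times_3x4xf32_into_2x4xf32(
    %lhs: tensor<2x3xf32>, %rhs: tensor<3x4xf32>) -> tensor<2x4xf32> {
  %init_acc = tensor.empty() : tensor<2x4xf32>
  %c0_acc_type = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%c0_acc_type : f32) outs(%init_acc : tensor<2x4xf32>) -> tensor<2x4xf32>
  %result = linalg.matmul ins(%lhs, %rhs : tensor<2x3xf32>, tensor<3x4xf32>) outs(%acc : tensor<2x4xf32>) -> tensor<2x4xf32>
  return %result : tensor<2x4xf32>
}
"""
```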
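
On the runner side, `tools/iree-e2e-matmul-test.c` now accepts calls with
either two inputs (lhs, rhs) or three (lhs, rhs, acc), and the reference
computation treats a missing accumulator as zero. A minimal Python sketch of
that per-element reference logic, assuming row-major nested lists (the actual
implementation is the reference-matmul C macro changed in the diff below):

```python
# Sketch of the per-element reference matmul, assuming nested lists.
# acc is None for the non-accumulating (C = A * B) case, mirroring the
# `acc_data ? acc_data[n + m * n_size] : 0` change in the C macro.
def reference_matmul_element(lhs, rhs, acc, m, n, k_size):
    value = acc[m][n] if acc is not None else 0
    for k in range(k_size):
        value += lhs[m][k] * rhs[k][n]
    return value
```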
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
index 11585c4..bb005e4 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
@@ -884,6 +884,12 @@
   }
 
   LogicalResult handle2DTile(OpInfo &info, PatternRewriter &rewriter) const {
+    Type scalarType = info.scalar.getType();
+    if (!scalarType.isIntOrFloat() ||
+        scalarType.getIntOrFloatBitWidth() != 32) {
+      return rewriter.notifyMatchFailure(info.op,
+                                         "handling only 32-bit scalar types");
+    }
     auto loc = info.op.getLoc();
     StridedBufferDescriptor &outDesc = info.outAnal.getDesc(rewriter);
     Value m = outDesc.sizes[0];
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index c409e8a..e8b6100 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -66,11 +66,15 @@
 
 # Describes the shape of a matrix multiplication in the usual convention:
 # the LHS is {m}x{k}, the RHS is {k}x{n}, the accumulator/result is {m}x{n}.
+# The extra `accumulate` boolean tells whether the matmul is accumulating into
+# an existing accumulator (C += A * B) or just overwriting the result
+# (C = A * B).
 @dataclasses.dataclass
 class TestShape:
   m: int
   k: int
   n: int
+  accumulate: bool
 
 
 # Describes how to construct compilation info for the testcase.
@@ -105,32 +109,36 @@
   if shapes_id == ShapesId.SMALL:
     return [
         # square matrices. Start by the simplest case of 1x1x1.
-        TestShape(m=1, k=1, n=1),
+        TestShape(m=1, k=1, n=1, accumulate=True),
+        TestShape(m=1, k=1, n=1, accumulate=False),
         # test 9x9x9 because as many kernel M0/K0/N0 dims are equal to 8,
         # this will often be the smallest value that exercises something above
         # the kernel's size.
-        TestShape(m=9, k=9, n=9),
+        TestShape(m=9, k=9, n=9, accumulate=True),
         # rectangular matrices.
         # >= 2x differences between M/N/K dims may exercise tiling corner cases
         # not exercised by nearly-square matrices.
-        TestShape(m=6, k=13, n=3),
-        TestShape(m=15, k=37, n=7),
-        TestShape(m=81, k=19, n=41),
+        TestShape(m=6, k=13, n=3, accumulate=True),
+        TestShape(m=15, k=37, n=7, accumulate=False),
+        TestShape(m=81, k=19, n=41, accumulate=True),
         # shapes involving vectors (i.e. most rectangular cases)
         # This is particularly relevant because we have dedicated kernels for
         # the matrix*vector / vector*matrix case.
-        TestShape(m=1, k=10, n=10),  # vector*matrix
-        TestShape(m=10, k=1, n=10),  # outer-product
-        TestShape(m=10, k=10, n=1),  # matrix*vector
+        TestShape(m=1, k=10, n=10, accumulate=True),  # vector*matrix
+        TestShape(m=1, k=10, n=10, accumulate=False),  # vector*matrix
+        TestShape(m=10, k=1, n=10, accumulate=True),  # outer-product
+        TestShape(m=10, k=10, n=1, accumulate=True),  # matrix*vector
+        TestShape(m=10, k=10, n=1, accumulate=False),  # matrix*vector
     ]
   if shapes_id == ShapesId.LARGE:
     return [
         # some random large sizes
-        TestShape(m=123, k=456, n=789),
-        TestShape(m=654, k=321, n=234),
+        TestShape(m=123, k=456, n=789, accumulate=True),
+        TestShape(m=654, k=321, n=234, accumulate=False),
         # shapes involving vectors (i.e. most rectangular cases)
-        TestShape(m=1, k=1000, n=1000),  # large vector*matrix
-        TestShape(m=1000, k=1000, n=1),  # large matrix*vector
+        TestShape(m=1, k=1000, n=1000, accumulate=True),  # large vector*matrix
+        TestShape(m=1000, k=1000, n=1, accumulate=True),  # large matrix*vector
+        TestShape(m=1000, k=1000, n=1, accumulate=False),  # large matrix*vector
         # Be conservative in adding larger shapes. They can result in
         # high latency tests. If you have to, consider splitting them
         # out in a way that constrains the latency impact, e.g. by
@@ -138,7 +146,11 @@
         # (see get_test_generators).
     ]
   if shapes_id == ShapesId.GPU_LARGE:
-    return [TestShape(m=256, k=128, n=512)]
+    return [
+        TestShape(m=256, k=128, n=512, accumulate=True),
+        TestShape(m=256, k=128, n=512, accumulate=False),
+    ]
+
   raise ValueError(shapes_id)
 
 
@@ -252,16 +264,6 @@
 local_pseudorandom_state = 1
 
 
-# Returns a pseudorandom boolean
-def pseudorandom_bool():
-  global local_pseudorandom_state
-  # Same as C++ std::minstd_rand.
-  # Using a local pseudorandom generator implementation ensures that it's
-  # completely reproducible, across runs and across machines.
-  local_pseudorandom_state = (local_pseudorandom_state * 48271) % 2147483647
-  return local_pseudorandom_state > 1073741824
-
-
 # A shape dimension value, i.e. a size value that could appear in a MLIR type
 # such as 'tensor<?x4xf32>'. None means a dynamic size, similar to '?' in MLIR.
 @dataclasses.dataclass
@@ -276,8 +278,6 @@
     return DimSize(None)
   elif dynamicity == Dynamicity.STATIC:
     return DimSize(x)
-  elif dynamicity == Dynamicity.MIXED:
-    return DimSize(x if pseudorandom_bool() else None)
   else:
     raise ValueError(dynamicity)
 
@@ -320,28 +320,6 @@
       acc_rows=shape_dim(shape.m, dynamicity),
       acc_cols=shape_dim(shape.n, dynamicity),
   )
-  # In the mixed-shapes case, we have just randomly picked each of the above 6
-  # values independently, making it likely that we got discrepancies where some
-  # of the M, K, N dimensions of the problem appear as a static dim in some
-  # matrix and as a dynamic dim in another matrix, e.g. for the M dimension we
-  # might have lhs_rows=dynamic, acc_rows=3. We should be testing both such
-  # 'wild' mixed-shapes cases, and more 'tame' mixed-shapes cases where each of
-  # the M, K, N dimensions is consistently either static or dynamic in both
-  # matrices (among lhs, rhs, acc) where it appears. If we don't do anything
-  # about it, as there is a 1/2 chance of discrepancy for each of the M, K, N
-  # dimensions, 7/8 of the mixed-shapes testcases will be 'wild'. Given our
-  # limited number of overall mixed-shapes testcases, this risks not testing
-  # the 'tame' case at all.
-  #
-  # At the moment there is an additional practical reason to care about this
-  # here: the matmul-to-mmt4d transformation currently bails on most 'wild'
-  # cases. So we care even more to test 'tame' cases so that mmt4d is tested
-  # on some mixed-shapes cases.
-  if pseudorandom_bool():
-    # Ensure that we are generating a 'tame' case.
-    shapes.acc_rows = shapes.lhs_rows
-    shapes.acc_cols = shapes.rhs_cols
-    shapes.rhs_rows = shapes.lhs_cols
   return shapes
 
 
@@ -351,6 +329,7 @@
     lhs_rhs_type: MatrixElemTypeId,
     acc_type: MatrixElemTypeId,
     shapes: TestInputMatricesShapes,
+    accumulate: bool,
     compilation_info: typing.Optional[CompilationInfo] = None):
   input_t = lhs_rhs_type.value
   acc_t = acc_type.value
@@ -369,7 +348,8 @@
     ]) + "_" + "_".join([str(a) for a in compilation_info.workgroup_size])
     info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}"
 
-  return f"matmul_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+  matmul_kind = "matmul_accumulate" if accumulate else "matmul"
+  return f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
 
 
 # Represents a generated test function.
@@ -390,7 +370,7 @@
     compilation_info: typing.Optional[CompilationInfo] = None):
   shapes = generate_shapes(shape, dynamicity)
   func_name = generate_function_name(lhs_rhs_type, acc_type, shapes,
-                                     compilation_info)
+                                     shape.accumulate, compilation_info)
   lhs_m = int_or_question_mark(shapes.lhs_rows)
   lhs_k = int_or_question_mark(shapes.lhs_cols)
   rhs_k = int_or_question_mark(shapes.rhs_rows)
@@ -422,11 +402,36 @@
     func_definition = func_definition + compilation_info_string
     generate_function.compilation_index += 1
 
-  func_definition = func_definition + (
-      f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
-      f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
-      f"  return %result: {acc_tensor_type}\n"
-      f"}}\n")
+  if shape.accumulate:
+    func_definition = func_definition + (
+        f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
+        f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+        f"  return %result: {acc_tensor_type}\n"
+        f"}}\n")
+  else:
+    literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0"
+    if acc_m == "?":
+      func_definition = func_definition + (
+          f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n"
+          f"  %c0 = arith.constant 0 : index\n"
+          f"  %c1 = arith.constant 1 : index\n"
+          f"  %acc_dim0 = tensor.dim %lhs, %c0 : {lhs_tensor_type}\n"
+          f"  %acc_dim1 = tensor.dim %rhs, %c1 : {rhs_tensor_type}\n"
+          f"  %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
+          f"  %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
+          f"  %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
+          f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+          f"  return %result: {acc_tensor_type}\n"
+          f"}}\n")
+    else:
+      func_definition = func_definition + (
+          f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n"
+          f"  %init_acc = tensor.empty() : {acc_tensor_type}\n"
+          f"  %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
+          f"  %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
+          f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+          f"  return %result: {acc_tensor_type}\n"
+          f"}}\n")
   return MLIRFunction(
       name=func_name,
       definition=func_definition,
@@ -474,23 +480,24 @@
 # as a dictionary to be passed to yaml.dump.
 def generate_trace(func_name: str, lhs_rhs_type: MatrixElemTypeId,
                    acc_type: MatrixElemTypeId, shape: TestShape):
-  lhs_arg = generate_trace_matrix_arg([shape.m, shape.k], lhs_rhs_type,
-                                      MatrixGenerator.RANDOM)
-  rhs_arg = generate_trace_matrix_arg([shape.k, shape.n], lhs_rhs_type,
-                                      MatrixGenerator.RANDOM)
-  acc_arg = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
-                                      MatrixGenerator.RANDOM)
-  result_arg = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
-                                         MatrixGenerator.ZERO)
+  args = [
+      generate_trace_matrix_arg([shape.m, shape.k], lhs_rhs_type,
+                                MatrixGenerator.RANDOM),
+      generate_trace_matrix_arg([shape.k, shape.n], lhs_rhs_type,
+                                MatrixGenerator.RANDOM),
+  ]
+  if shape.accumulate:
+    args.append(
+        generate_trace_matrix_arg([shape.m, shape.n], acc_type,
+                                  MatrixGenerator.RANDOM))
+
+  result = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
+                                     MatrixGenerator.ZERO)
   return {
       "type": "call",
       "function": "module." + func_name,
-      "args": [
-          lhs_arg,
-          rhs_arg,
-          acc_arg,
-      ],
-      "results": [result_arg,],
+      "args": args,
+      "results": [result],
   }
 
 
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 5b9bb1f..ed941cb 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -431,21 +431,35 @@
   iree_hal_dim_t result_dims[2] = {0};
   IREE_RETURN_IF_ERROR(get_matrix_shape(lhs, lhs_dims));
   IREE_RETURN_IF_ERROR(get_matrix_shape(rhs, rhs_dims));
-  IREE_RETURN_IF_ERROR(get_matrix_shape(acc, acc_dims));
   IREE_RETURN_IF_ERROR(get_matrix_shape(result, result_dims));
   *m_size = lhs_dims[0];
   *k_size = lhs_dims[1];
   *n_size = rhs_dims[1];
-  if (!(lhs_dims[0] == *m_size && lhs_dims[1] == *k_size &&
-        rhs_dims[0] == *k_size && rhs_dims[1] == *n_size &&
-        acc_dims[0] == *m_size && acc_dims[1] == *n_size &&
-        result_dims[0] == *m_size && result_dims[1] == *n_size)) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "mismatched matrix shapes in matmul: %" PRIdim "x%" PRIdim " * %" PRIdim
-        "x%" PRIdim " + %" PRIdim "x%" PRIdim " -> %" PRIdim "x%" PRIdim,
-        lhs_dims[0], lhs_dims[1], rhs_dims[0], rhs_dims[1], acc_dims[0],
-        acc_dims[1], result_dims[0], result_dims[1]);
+  if (acc) {
+    IREE_RETURN_IF_ERROR(get_matrix_shape(acc, acc_dims));
+    if (!(lhs_dims[0] == *m_size && lhs_dims[1] == *k_size &&
+          rhs_dims[0] == *k_size && rhs_dims[1] == *n_size &&
+          acc_dims[0] == *m_size && acc_dims[1] == *n_size &&
+          result_dims[0] == *m_size && result_dims[1] == *n_size)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "mismatched matrix shapes in matmul: %" PRIdim "x%" PRIdim
+          " * %" PRIdim "x%" PRIdim " + %" PRIdim "x%" PRIdim " -> %" PRIdim
+          "x%" PRIdim,
+          lhs_dims[0], lhs_dims[1], rhs_dims[0], rhs_dims[1], acc_dims[0],
+          acc_dims[1], result_dims[0], result_dims[1]);
+    }
+  } else {
+    if (!(lhs_dims[0] == *m_size && lhs_dims[1] == *k_size &&
+          rhs_dims[0] == *k_size && rhs_dims[1] == *n_size &&
+          result_dims[0] == *m_size && result_dims[1] == *n_size)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "mismatched matrix shapes in matmul: %" PRIdim
+                              "x%" PRIdim " * %" PRIdim "x%" PRIdim
+                              " -> %" PRIdim "x%" PRIdim,
+                              lhs_dims[0], lhs_dims[1], rhs_dims[0],
+                              rhs_dims[1], result_dims[0], result_dims[1]);
+    }
   }
   return iree_ok_status();
 }
@@ -457,7 +471,7 @@
       iree_hal_element_type_t acc_type, LHSTYPE* lhs_data, RHSTYPE* rhs_data,  \
       ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m,               \
       iree_hal_dim_t n) {                                                      \
-    ACCTYPE acc = acc_data[n + m * n_size];                                    \
+    ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0;                     \
     for (iree_hal_dim_t k = 0; k < k_size; ++k) {                              \
       LHSTYPE lhs_value = lhs_data[k + m * k_size];                            \
       RHSTYPE rhs_value = rhs_data[n + k * n_size];                            \
@@ -484,7 +498,7 @@
     iree_hal_element_type_t acc_type, uint16_t* lhs_data, uint16_t* rhs_data,
     uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
     iree_hal_dim_t n) {
-  float acc = iree_math_f16_to_f32(acc_data[n + m * n_size]);
+  float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0;
   for (iree_hal_dim_t k = 0; k < k_size; ++k) {
     acc = iree_math_round_to_nearest_f16(
         iree_math_round_to_nearest_f16(
@@ -536,8 +550,9 @@
   iree_hal_buffer_view_t* acc = NULL;
   IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 0, &lhs));
   IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 1, &rhs));
-  IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
-
+  if (iree_vm_list_size(input_list) == 3) {
+    IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
+  }
   iree_hal_dim_t m_size, k_size, n_size;
   IREE_RETURN_IF_ERROR(
       get_matmul_sizes(lhs, rhs, acc, result, &m_size, &k_size, &n_size));
@@ -549,24 +564,29 @@
       lhs, IREE_HAL_MEMORY_ACCESS_READ, &lhs_mapping));
   IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
       rhs, IREE_HAL_MEMORY_ACCESS_READ, &rhs_mapping));
-  IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
-      acc, IREE_HAL_MEMORY_ACCESS_READ, &acc_mapping));
+  if (acc) {
+    IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
+        acc, IREE_HAL_MEMORY_ACCESS_READ, &acc_mapping));
+  }
   IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
       result, IREE_HAL_MEMORY_ACCESS_WRITE, &result_mapping));
   iree_hal_element_type_t lhs_type = iree_hal_buffer_view_element_type(lhs);
   iree_hal_element_type_t rhs_type = iree_hal_buffer_view_element_type(rhs);
-  iree_hal_element_type_t acc_type = iree_hal_buffer_view_element_type(acc);
+  iree_hal_element_type_t acc_type = iree_hal_buffer_view_element_type(result);
   for (iree_hal_dim_t m = 0; m < m_size; ++m) {
     for (iree_hal_dim_t n = 0; n < n_size; ++n) {
-      reference_matmul_element(
-          m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
-          lhs_mapping.contents.data, rhs_mapping.contents.data,
-          acc_mapping.contents.data, result_mapping.contents.data, m, n);
+      reference_matmul_element(m_size, k_size, n_size, lhs_type, rhs_type,
+                               acc_type, lhs_mapping.contents.data,
+                               rhs_mapping.contents.data,
+                               acc ? acc_mapping.contents.data : NULL,
+                               result_mapping.contents.data, m, n);
     }
   }
   IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&lhs_mapping));
   IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&rhs_mapping));
-  IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&acc_mapping));
+  if (acc) {
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&acc_mapping));
+  }
   IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&result_mapping));
   return iree_ok_status();
 }
@@ -813,9 +833,11 @@
   print_matrix(file, "right-hand side", PRECISION_LOW, k_start, k_end, n_start,
                n_end, rhs, NULL, NULL);
   fprintf(file, "\n");
-  print_matrix(file, "input accumulator", PRECISION_LOW, m_start, m_end,
-               n_start, n_end, acc, NULL, NULL);
-  fprintf(file, "\n");
+  if (acc) {
+    print_matrix(file, "input accumulator", PRECISION_LOW, m_start, m_end,
+                 n_start, n_end, acc, NULL, NULL);
+    fprintf(file, "\n");
+  }
   print_matrix(file, "expected result", PRECISION_LOW, m_start, m_end, n_start,
                n_end, expected_result, actual_result, emoji(true));
   fprintf(file, "\n");
@@ -874,7 +896,9 @@
   iree_hal_buffer_view_t* acc = NULL;
   IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 0, &lhs));
   IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 1, &rhs));
-  IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
+  if (iree_vm_list_size(input_list) == 3) {
+    IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
+  }
 
   iree_hal_dim_t m_size, k_size, n_size;
   IREE_RETURN_IF_ERROR(get_matmul_sizes(lhs, rhs, acc, actual_result, &m_size,
@@ -895,24 +919,10 @@
  * helpers for it.
  *
  * |replay_event_call_matmul| calls |do_matmul_and_check_results| to actually
- * perform a matmul. In normal cases, each |replay_event_call_matmul| performs
- * one call to |do_matmul_and_check_results|, but when that generates an error,
- * it will make additional calls to |do_matmul_and_check_results| to evaluate
- * variants of the failed testcase to generate a more helpful log.
- *
- * The |matrix_mask_t| stuff is only used to generate these variants of failed
- * testcases.
+ * perform a matmul.
  *
  *****************************************************************************/
 
-// Enumerates ways that we may mask matrices in list of matrix inputs to matmul
-// testcases.
-typedef enum {
-  MATRIX_MASK_NONE,      // no-op: leave the existing matrix unchanged.
-  MATRIX_MASK_ZERO,      // overwrite the matrix with zeros.
-  MATRIX_MASK_IDENTITY,  // overwrite with (general rectangular) identity matrix
-} matrix_mask_t;
-
 static iree_status_t make_identity_matrix_callback(
     iree_hal_buffer_mapping_t* mapping, void* user_data) {
   iree_hal_buffer_view_t* src = (iree_hal_buffer_view_t*)user_data;
@@ -936,53 +946,10 @@
   return iree_ok_status();
 }
 
-// Allocates device-local |dst| and initializes it as an identity-matrix shaped
-// like |src|.
-static iree_status_t make_device_identity_matrix_like(
+// Deep-copies a device-local list of buffer_views |src_list| into |dst_list|.
+static iree_status_t copy_device_buffer_views_to_device(
     iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,
-    iree_hal_buffer_view_t* src, iree_hal_buffer_view_t** dst) {
-  return iree_hal_buffer_view_generate_buffer(
-      hal_allocator, iree_hal_buffer_view_shape_rank(src),
-      iree_hal_buffer_view_shape_dims(src),
-      iree_hal_buffer_view_element_type(src),
-      iree_hal_buffer_view_encoding_type(src),
-      (iree_hal_buffer_params_t){
-          .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
-          .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
-      },
-      make_identity_matrix_callback, src, dst);
-}
-
-// Allocates device-local |dst| shaped like |src|, and:
-// - If |mask| is MATRIX_MASK_NONE, copies device-local |src| into |dst|.
-// - If |mask| is MATRIX_MASK_ZERO, leaves |dst| zero-filled.
-// - If |mask| is MATRIX_MASK_IDENTITY, makes |dst| an identity-matrix
-static iree_status_t mask_and_copy_device_buffer_view_to_device(
-    iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,
-    iree_hal_buffer_view_t* src, matrix_mask_t mask,
-    iree_hal_buffer_view_t** dst) {
-  if (mask == MATRIX_MASK_NONE) {
-    IREE_RETURN_IF_ERROR(
-        copy_device_buffer_view_to_device(device, hal_allocator, src, dst));
-  } else if (mask == MATRIX_MASK_ZERO) {
-    IREE_RETURN_IF_ERROR(allocate_device_buffer_view_like(
-        hal_allocator, src, iree_const_byte_span_empty(), dst));
-  } else if (mask == MATRIX_MASK_IDENTITY) {
-    IREE_RETURN_IF_ERROR(
-        make_device_identity_matrix_like(device, hal_allocator, src, dst));
-  } else {
-    iree_status_abort(iree_make_status(IREE_STATUS_INTERNAL, "bad mask enum"));
-  }
-  return iree_ok_status();
-}
-
-// Deep-copies device-local list of buffer_views |src| into |dst|, applying
-// mask[i] to the i-th list element as in
-// |mask_and_copy_device_buffer_view_to_device|.
-// Requirement: |mask| must point to an array of the same length as |src|.
-static iree_status_t mask_and_copy_device_buffer_views_to_device(
-    iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,
-    iree_vm_list_t* src_list, matrix_mask_t* mask, iree_vm_list_t** dst_list) {
+    iree_vm_list_t* src_list, iree_vm_list_t** dst_list) {
   iree_vm_type_def_t elem_type = iree_vm_list_element_type(src_list);
   iree_host_size_t size = iree_vm_list_size(src_list);
   iree_allocator_t allocator = iree_hal_allocator_host_allocator(hal_allocator);
@@ -993,8 +960,8 @@
     iree_hal_buffer_view_t* src = NULL;
     IREE_RETURN_IF_ERROR(get_item_as_buffer_view(src_list, i, &src));
     iree_hal_buffer_view_t* dst = NULL;
-    IREE_RETURN_IF_ERROR(mask_and_copy_device_buffer_view_to_device(
-        device, hal_allocator, src, mask[i], &dst));
+    IREE_RETURN_IF_ERROR(
+        copy_device_buffer_view_to_device(device, hal_allocator, src, &dst));
     iree_vm_ref_t dst_ref = {0};
     IREE_RETURN_IF_ERROR(
         iree_vm_ref_wrap_assign(dst, iree_hal_buffer_view_type(), &dst_ref));
@@ -1004,16 +971,13 @@
 }
 
 // Performs one matmul test, on the device-local input matrices given in
-// |original_device_inputs|, applying the masks given in |mask| as in
-// |mask_and_copy_device_buffer_view_to_device|.
-// Both |input_list| and |mask| should have length 3. The 3 input matrices are
-// LHS, RHS, Accumulator, in that order.
+// |original_device_inputs|.
 //
 // The contents of |original_device_inputs| are preserved, even if the
 // |function| would overwrite input-output arguments (e.g. the accumulator).
 static iree_status_t do_matmul_and_check_results(
     FILE* file, iree_trace_replay_t* replay, iree_vm_function_t function,
-    matrix_mask_t* mask, iree_vm_list_t* original_device_inputs) {
+    iree_vm_list_t* original_device_inputs) {
   iree_hal_allocator_t* device_allocator =
       iree_hal_device_allocator(replay->device);
 
@@ -1023,8 +987,8 @@
   // linalg.matmul. We need to preserve the original test inputs to perform
   // reruns on variants in the failure case (see |replay_event_call_matmul|).
   iree_vm_list_t* device_inputs = NULL;
-  IREE_CHECK_OK(mask_and_copy_device_buffer_views_to_device(
-      replay->device, device_allocator, original_device_inputs, mask,
+  IREE_CHECK_OK(copy_device_buffer_views_to_device(
+      replay->device, device_allocator, original_device_inputs,
       &device_inputs));
 
   // Perform a deep copy of the device-local inputs into host-local buffers.
@@ -1074,19 +1038,6 @@
   return status;
 }
 
-const char* matrix_form(matrix_mask_t mask) {
-  switch (mask) {
-    case MATRIX_MASK_NONE:
-      return "GENERAL";
-    case MATRIX_MASK_ZERO:
-      return "ZERO";
-    case MATRIX_MASK_IDENTITY:
-      return "IDENTITY";
-  }
-  assert(false);
-  return NULL;
-}
-
 // Prints to |file| a message about the matmul shape. Useful as testcases
 // otherwise only print the function name, and in the dynamic-shapes cases, that
 // doesn't tell the actual shape.
@@ -1106,8 +1057,6 @@
 }
 
 // Special handler for function calls in a e2e matmul test trace.
-// Assumes that all calls are to functions that take 3 inputs (lhs, rhs, acc)
-// and return the result of a matmul (lhs*rhs+acc).
 static iree_status_t replay_event_call_matmul(iree_trace_replay_t* replay,
                                               yaml_document_t* document,
                                               yaml_node_t* event_node) {
@@ -1126,69 +1075,8 @@
 
   IREE_CHECK_OK(print_matmul_shape(stderr, device_inputs));
 
-  // Perform the matmul test. So far we are using pseudorandom matrices (as
-  // specified in the YAML trace and interpreted above in
-  // |iree_trace_replay_event_call_prepare|). So this is a test on general
-  // random matrices: great for test coverage (if this succeeds, any variant on
-  // more special matrices would also succeed) but bad for debugging (if this
-  // fails, having to debug that would involve staring at arrays of random
-  // numbers). So for now we pass NULL as the |file| param, keeping errors
-  // silent for now.
-  matrix_mask_t none_masks[3] = {MATRIX_MASK_NONE, MATRIX_MASK_NONE,
-                                 MATRIX_MASK_NONE};
-  iree_status_t status = do_matmul_and_check_results(NULL, replay, function,
-                                                     none_masks, device_inputs);
-  if (!iree_status_is_ok(status)) {
-    // The matmul test failed. So whatever we do now is only for the sake of
-    // generating the most undertandable possible error log. We are going to
-    // retry the matmul but on more special, easy-to-understand matrices,
-    // gradually increasing generality, and we will abort and log details on
-    // the first error that we encounter.
-    iree_string_builder_t sb;
-    iree_string_builder_initialize(replay->host_allocator, &sb);
-    matrix_mask_t all_debug_masks[6][3] = {
-        // Try Zero * Zero + Zero. Expected result: Zero.
-        {MATRIX_MASK_ZERO, MATRIX_MASK_ZERO, MATRIX_MASK_ZERO},
-        // Try Identity * Identity + Zero. Expected result: Identity.
-        {MATRIX_MASK_IDENTITY, MATRIX_MASK_IDENTITY, MATRIX_MASK_ZERO},
-        // Try RandomLHS * Identity + Zero. Expected result: RandomLHS.
-        {MATRIX_MASK_NONE, MATRIX_MASK_IDENTITY, MATRIX_MASK_ZERO},
-        // Try Identity * RandomRHS + Zero. Expected result: RandomRHS.
-        {MATRIX_MASK_IDENTITY, MATRIX_MASK_NONE, MATRIX_MASK_ZERO},
-        // Try Identity * Identity + RandomAccum.
-        // Expected result: Identity + RandomAccum.
-        {MATRIX_MASK_IDENTITY, MATRIX_MASK_IDENTITY, MATRIX_MASK_NONE},
-        // Finally run the general case again. If none of the above special
-        // cases
-        // failed, then that at least must fail, since we already ran that and
-        // it
-        // had failed.
-        {MATRIX_MASK_NONE, MATRIX_MASK_NONE, MATRIX_MASK_NONE}};
-    bool reproduced_failure = false;
-    for (int i = 0; i < IREE_ARRAYSIZE(all_debug_masks); ++i) {
-      matrix_mask_t* masks = all_debug_masks[i];
-      iree_status_code_t rerun_status =
-          iree_status_consume_code(do_matmul_and_check_results(
-              stderr, replay, function, masks, device_inputs));
-      bool good = iree_status_is_ok(rerun_status);
-      reproduced_failure |= !good;
-      iree_string_builder_append_format(
-          &sb, "%s LHS:%-10s * RHS:%-10s + ACCUMULATOR:%-10s\n", emoji(good),
-          matrix_form(masks[0]), matrix_form(masks[1]), matrix_form(masks[2]));
-      if (!good) break;
-    }
-    if (!reproduced_failure) {
-      iree_status_abort(iree_make_status(
-          IREE_STATUS_INTERNAL,
-          "Internal error: a matmul test failed, but subsequent reruns for "
-          "logging purposes were not able to reproduce the failure."));
-    }
-    fprintf(stderr,
-            "Summary of reruns, pinpointing how general matrices need to be to "
-            "reproduce this failure:\n%s\n",
-            iree_string_builder_buffer(&sb));
-    iree_string_builder_deinitialize(&sb);
-  }
+  iree_status_t status =
+      do_matmul_and_check_results(stderr, replay, function, device_inputs);
 
   // Clean up.
   iree_vm_list_release(device_inputs);