e2e matmul test improvements (#13657)
* Test the non-accumulating matmul case (fill-zero accumulator); see the
sketch below and the discussion in #13582 with @ThomasRaoux.
* Drop the `MIXED` dynamicity: it was effectively unused, having already been
removed from the set of test variants that are actually iterated over. Also
drop helpers that were only used for it, such as `pseudorandom_bool`.
* Drop the logic that, on a numerical error, re-ran matmul tests on simpler
matrices. Although potentially useful for pinpointing the numerical issue, it
was confusing when debugging (as discussed once with @manishucsd), and it made
the test-runner code substantially more complex.
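
For reference, the non-accumulating variant makes the generator emit a
zero-filled local accumulator instead of taking one as a function argument.
Below is a minimal sketch of the two generated forms for an illustrative
static `f32` shape; the function names and shapes are hypothetical examples,
not copied generator output:

```mlir
// Accumulating case (C += A * B): the accumulator is a function argument.
func.func @matmul_accumulate_8x4xf32_times_4x8xf32_into_8x8xf32(
    %lhs: tensor<8x4xf32>, %rhs: tensor<4x8xf32>, %acc: tensor<8x8xf32>) -> tensor<8x8xf32> {
  %result = linalg.matmul ins(%lhs, %rhs : tensor<8x4xf32>, tensor<4x8xf32>)
                          outs(%acc : tensor<8x8xf32>) -> tensor<8x8xf32>
  return %result : tensor<8x8xf32>
}

// Non-accumulating case (C = A * B): a zero accumulator is materialized locally.
func.func @matmul_8x4xf32_times_4x8xf32_into_8x8xf32(
    %lhs: tensor<8x4xf32>, %rhs: tensor<4x8xf32>) -> tensor<8x8xf32> {
  %init_acc = tensor.empty() : tensor<8x8xf32>
  %c0 = arith.constant 0.0 : f32
  %acc = linalg.fill ins(%c0 : f32) outs(%init_acc : tensor<8x8xf32>) -> tensor<8x8xf32>
  %result = linalg.matmul ins(%lhs, %rhs : tensor<8x4xf32>, tensor<4x8xf32>)
                          outs(%acc : tensor<8x8xf32>) -> tensor<8x8xf32>
  return %result : tensor<8x8xf32>
}
```

On the runner side, iree-e2e-matmul-test distinguishes the two forms by the
size of the input list (2 vs. 3 arguments) and skips mapping and printing the
accumulator when it is absent.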
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
index 11585c4..bb005e4 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/LowerLinalgMicrokernels.cpp
@@ -884,6 +884,12 @@
}
LogicalResult handle2DTile(OpInfo &info, PatternRewriter &rewriter) const {
+ Type scalarType = info.scalar.getType();
+ if (!scalarType.isIntOrFloat() ||
+ scalarType.getIntOrFloatBitWidth() != 32) {
+ return rewriter.notifyMatchFailure(info.op,
+ "handling only 32-bit scalar types");
+ }
auto loc = info.op.getLoc();
StridedBufferDescriptor &outDesc = info.outAnal.getDesc(rewriter);
Value m = outDesc.sizes[0];
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index c409e8a..e8b6100 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -66,11 +66,15 @@
# Describes the shape of a matrix multiplication in the usual convention:
# the LHS is {m}x{k}, the RHS is {k}x{n}, the accumulator/result is {m}x{n}.
+# The extra `accumulate` boolean tells whether the matmul is accumulating into
+# an existing accumulator (C += A * B) or just overwriting the result
+# (C = A * B).
@dataclasses.dataclass
class TestShape:
m: int
k: int
n: int
+ accumulate: bool
# Describes how to construct compilation info for the testcase.
@@ -105,32 +109,36 @@
if shapes_id == ShapesId.SMALL:
return [
# square matrices. Start by the simplest case of 1x1x1.
- TestShape(m=1, k=1, n=1),
+ TestShape(m=1, k=1, n=1, accumulate=True),
+ TestShape(m=1, k=1, n=1, accumulate=False),
# test 9x9x9 because, as many kernel M0/K0/N0 dims are equal to 8,
# this will often be the smallest value that exercises something above
# the kernel's size.
- TestShape(m=9, k=9, n=9),
+ TestShape(m=9, k=9, n=9, accumulate=True),
# rectangular matrices.
# >= 2x differences between M/N/K dims may exercise tiling corner cases
# not exercised by nearly-square matrices.
- TestShape(m=6, k=13, n=3),
- TestShape(m=15, k=37, n=7),
- TestShape(m=81, k=19, n=41),
+ TestShape(m=6, k=13, n=3, accumulate=True),
+ TestShape(m=15, k=37, n=7, accumulate=False),
+ TestShape(m=81, k=19, n=41, accumulate=True),
# shapes involving vectors (i.e. most rectangular cases)
# This is particularly relevant because we have dedicated kernels for
# the matrix*vector / vector*matrix case.
- TestShape(m=1, k=10, n=10), # vector*matrix
- TestShape(m=10, k=1, n=10), # outer-product
- TestShape(m=10, k=10, n=1), # matrix*vector
+ TestShape(m=1, k=10, n=10, accumulate=True), # vector*matrix
+ TestShape(m=1, k=10, n=10, accumulate=False), # vector*matrix
+ TestShape(m=10, k=1, n=10, accumulate=True), # outer-product
+ TestShape(m=10, k=10, n=1, accumulate=True), # matrix*vector
+ TestShape(m=10, k=10, n=1, accumulate=False), # matrix*vector
]
if shapes_id == ShapesId.LARGE:
return [
# some random large sizes
- TestShape(m=123, k=456, n=789),
- TestShape(m=654, k=321, n=234),
+ TestShape(m=123, k=456, n=789, accumulate=True),
+ TestShape(m=654, k=321, n=234, accumulate=False),
# shapes involving vectors (i.e. most rectangular cases)
- TestShape(m=1, k=1000, n=1000), # large vector*matrix
- TestShape(m=1000, k=1000, n=1), # large matrix*vector
+ TestShape(m=1, k=1000, n=1000, accumulate=True), # large vector*matrix
+ TestShape(m=1000, k=1000, n=1, accumulate=True), # large matrix*vector
+ TestShape(m=1000, k=1000, n=1, accumulate=False), # large matrix*vector
# Be conservative in adding larger shapes. They can result in
# high latency tests. If you have to, consider splitting them
# out in a way that constrains the latency impact, e.g. by
@@ -138,7 +146,11 @@
# (see get_test_generators).
]
if shapes_id == ShapesId.GPU_LARGE:
- return [TestShape(m=256, k=128, n=512)]
+ return [
+ TestShape(m=256, k=128, n=512, accumulate=True),
+ TestShape(m=256, k=128, n=512, accumulate=False),
+ ]
+
raise ValueError(shapes_id)
@@ -252,16 +264,6 @@
local_pseudorandom_state = 1
-# Returns a pseudorandom boolean
-def pseudorandom_bool():
- global local_pseudorandom_state
- # Same as C++ std::minstd_rand.
- # Using a local pseudorandom generator implementation ensures that it's
- # completely reproducible, across runs and across machines.
- local_pseudorandom_state = (local_pseudorandom_state * 48271) % 2147483647
- return local_pseudorandom_state > 1073741824
-
-
# A shape dimension value, i.e. a size value that could appear in a MLIR type
# such as 'tensor<?x4xf32>'. None means a dynamic size, similar to '?' in MLIR.
@dataclasses.dataclass
@@ -276,8 +278,6 @@
return DimSize(None)
elif dynamicity == Dynamicity.STATIC:
return DimSize(x)
- elif dynamicity == Dynamicity.MIXED:
- return DimSize(x if pseudorandom_bool() else None)
else:
raise ValueError(dynamicity)
@@ -320,28 +320,6 @@
acc_rows=shape_dim(shape.m, dynamicity),
acc_cols=shape_dim(shape.n, dynamicity),
)
- # In the mixed-shapes case, we have just randomly picked each of the above 6
- # values independently, making it likely that we got discrepancies where some
- # of the M, K, N dimensions of the problem appear as a static dim in some
- # matrix and as a dynamic dim in another matrix, e.g. for the M dimension we
- # might have lhs_rows=dynamic, acc_rows=3. We should be testing both such
- # 'wild' mixed-shapes cases, and more 'tame' mixed-shapes cases where each of
- # the M, K, N dimensions is consistently either static or dynamic in both
- # matrices (among lhs, rhs, acc) where it appears. If we don't do anything
- # about it, as there is a 1/2 chance of discrepancy for each of the M, K, N
- # dimensions, 7/8 of the mixed-shapes testcases will be 'wild'. Given our
- # limited number of overall mixed-shapes testcases, this risks not testing
- # the 'tame' case at all.
- #
- # At the moment there is an additional practical reason to care about this
- # here: the matmul-to-mmt4d transformation currently bails on most 'wild'
- # cases. So we care even more to test 'tame' cases so that mmt4d is tested
- # on some mixed-shapes cases.
- if pseudorandom_bool():
- # Ensure that we are generating a 'tame' case.
- shapes.acc_rows = shapes.lhs_rows
- shapes.acc_cols = shapes.rhs_cols
- shapes.rhs_rows = shapes.lhs_cols
return shapes
@@ -351,6 +329,7 @@
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shapes: TestInputMatricesShapes,
+ accumulate: bool,
compilation_info: typing.Optional[CompilationInfo] = None):
input_t = lhs_rhs_type.value
acc_t = acc_type.value
@@ -369,7 +348,8 @@
]) + "_" + "_".join([str(a) for a in compilation_info.workgroup_size])
info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}"
- return f"matmul_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+ matmul_kind = "matmul_accumulate" if accumulate else "matmul"
+ return f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
# Represents a generated test function.
@@ -390,7 +370,7 @@
compilation_info: typing.Optional[CompilationInfo] = None):
shapes = generate_shapes(shape, dynamicity)
func_name = generate_function_name(lhs_rhs_type, acc_type, shapes,
- compilation_info)
+ shape.accumulate, compilation_info)
lhs_m = int_or_question_mark(shapes.lhs_rows)
lhs_k = int_or_question_mark(shapes.lhs_cols)
rhs_k = int_or_question_mark(shapes.rhs_rows)
@@ -422,11 +402,37 @@
func_definition = func_definition + compilation_info_string
generate_function.compilation_index += 1
- func_definition = func_definition + (
- f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
- f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
- f" return %result: {acc_tensor_type}\n"
- f"}}\n")
+ if shape.accumulate:
+ func_definition = func_definition + (
+ f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
+ f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" return %result: {acc_tensor_type}\n"
+ f"}}\n")
+ else:
+ literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0"
+ acc_dyn_sizes = []
+ if acc_m == "?":
+ func_definition = func_definition + (
+ f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n"
+ f" %c0 = arith.constant 0 : index\n"
+ f" %c1 = arith.constant 1 : index\n"
+ f" %acc_dim0 = tensor.dim %lhs, %c0 : {lhs_tensor_type}\n"
+ f" %acc_dim1 = tensor.dim %rhs, %c1 : {rhs_tensor_type}\n"
+ f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
+ f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
+ f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" return %result: {acc_tensor_type}\n"
+ f"}}\n")
+ else:
+ func_definition = func_definition + (
+ f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n"
+ f" %init_acc = tensor.empty() : {acc_tensor_type}\n"
+ f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
+ f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" return %result: {acc_tensor_type}\n"
+ f"}}\n")
return MLIRFunction(
name=func_name,
definition=func_definition,
@@ -474,23 +480,24 @@
# as a dictionary to be passed to yaml.dump.
def generate_trace(func_name: str, lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId, shape: TestShape):
- lhs_arg = generate_trace_matrix_arg([shape.m, shape.k], lhs_rhs_type,
- MatrixGenerator.RANDOM)
- rhs_arg = generate_trace_matrix_arg([shape.k, shape.n], lhs_rhs_type,
- MatrixGenerator.RANDOM)
- acc_arg = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
- MatrixGenerator.RANDOM)
- result_arg = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
- MatrixGenerator.ZERO)
+ args = [
+ generate_trace_matrix_arg([shape.m, shape.k], lhs_rhs_type,
+ MatrixGenerator.RANDOM),
+ generate_trace_matrix_arg([shape.k, shape.n], lhs_rhs_type,
+ MatrixGenerator.RANDOM),
+ ]
+ if shape.accumulate:
+ args.append(
+ generate_trace_matrix_arg([shape.m, shape.n], acc_type,
+ MatrixGenerator.RANDOM))
+
+ result = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
+ MatrixGenerator.ZERO)
return {
"type": "call",
"function": "module." + func_name,
- "args": [
- lhs_arg,
- rhs_arg,
- acc_arg,
- ],
- "results": [result_arg,],
+ "args": args,
+ "results": [result],
}
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index 5b9bb1f..ed941cb 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -431,21 +431,35 @@
iree_hal_dim_t result_dims[2] = {0};
IREE_RETURN_IF_ERROR(get_matrix_shape(lhs, lhs_dims));
IREE_RETURN_IF_ERROR(get_matrix_shape(rhs, rhs_dims));
- IREE_RETURN_IF_ERROR(get_matrix_shape(acc, acc_dims));
IREE_RETURN_IF_ERROR(get_matrix_shape(result, result_dims));
*m_size = lhs_dims[0];
*k_size = lhs_dims[1];
*n_size = rhs_dims[1];
- if (!(lhs_dims[0] == *m_size && lhs_dims[1] == *k_size &&
- rhs_dims[0] == *k_size && rhs_dims[1] == *n_size &&
- acc_dims[0] == *m_size && acc_dims[1] == *n_size &&
- result_dims[0] == *m_size && result_dims[1] == *n_size)) {
- return iree_make_status(
- IREE_STATUS_INVALID_ARGUMENT,
- "mismatched matrix shapes in matmul: %" PRIdim "x%" PRIdim " * %" PRIdim
- "x%" PRIdim " + %" PRIdim "x%" PRIdim " -> %" PRIdim "x%" PRIdim,
- lhs_dims[0], lhs_dims[1], rhs_dims[0], rhs_dims[1], acc_dims[0],
- acc_dims[1], result_dims[0], result_dims[1]);
+ if (acc) {
+ IREE_RETURN_IF_ERROR(get_matrix_shape(acc, acc_dims));
+ if (!(lhs_dims[0] == *m_size && lhs_dims[1] == *k_size &&
+ rhs_dims[0] == *k_size && rhs_dims[1] == *n_size &&
+ acc_dims[0] == *m_size && acc_dims[1] == *n_size &&
+ result_dims[0] == *m_size && result_dims[1] == *n_size)) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "mismatched matrix shapes in matmul: %" PRIdim "x%" PRIdim
+ " * %" PRIdim "x%" PRIdim " + %" PRIdim "x%" PRIdim " -> %" PRIdim
+ "x%" PRIdim,
+ lhs_dims[0], lhs_dims[1], rhs_dims[0], rhs_dims[1], acc_dims[0],
+ acc_dims[1], result_dims[0], result_dims[1]);
+ }
+ } else {
+ if (!(lhs_dims[0] == *m_size && lhs_dims[1] == *k_size &&
+ rhs_dims[0] == *k_size && rhs_dims[1] == *n_size &&
+ result_dims[0] == *m_size && result_dims[1] == *n_size)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "mismatched matrix shapes in matmul: %" PRIdim
+ "x%" PRIdim " * %" PRIdim "x%" PRIdim
+ " -> %" PRIdim "x%" PRIdim,
+ lhs_dims[0], lhs_dims[1], rhs_dims[0],
+ rhs_dims[1], result_dims[0], result_dims[1]);
+ }
}
return iree_ok_status();
}
@@ -457,7 +471,7 @@
iree_hal_element_type_t acc_type, LHSTYPE* lhs_data, RHSTYPE* rhs_data, \
ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m, \
iree_hal_dim_t n) { \
- ACCTYPE acc = acc_data[n + m * n_size]; \
+ ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0; \
for (iree_hal_dim_t k = 0; k < k_size; ++k) { \
LHSTYPE lhs_value = lhs_data[k + m * k_size]; \
RHSTYPE rhs_value = rhs_data[n + k * n_size]; \
@@ -484,7 +498,7 @@
iree_hal_element_type_t acc_type, uint16_t* lhs_data, uint16_t* rhs_data,
uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
iree_hal_dim_t n) {
- float acc = iree_math_f16_to_f32(acc_data[n + m * n_size]);
+ float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
acc = iree_math_round_to_nearest_f16(
iree_math_round_to_nearest_f16(
@@ -536,8 +550,9 @@
iree_hal_buffer_view_t* acc = NULL;
IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 0, &lhs));
IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 1, &rhs));
- IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
-
+ if (iree_vm_list_size(input_list) == 3) {
+ IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
+ }
iree_hal_dim_t m_size, k_size, n_size;
IREE_RETURN_IF_ERROR(
get_matmul_sizes(lhs, rhs, acc, result, &m_size, &k_size, &n_size));
@@ -549,24 +564,29 @@
lhs, IREE_HAL_MEMORY_ACCESS_READ, &lhs_mapping));
IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
rhs, IREE_HAL_MEMORY_ACCESS_READ, &rhs_mapping));
- IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
- acc, IREE_HAL_MEMORY_ACCESS_READ, &acc_mapping));
+ if (acc) {
+ IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
+ acc, IREE_HAL_MEMORY_ACCESS_READ, &acc_mapping));
+ }
IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
result, IREE_HAL_MEMORY_ACCESS_WRITE, &result_mapping));
iree_hal_element_type_t lhs_type = iree_hal_buffer_view_element_type(lhs);
iree_hal_element_type_t rhs_type = iree_hal_buffer_view_element_type(rhs);
- iree_hal_element_type_t acc_type = iree_hal_buffer_view_element_type(acc);
+ iree_hal_element_type_t acc_type = iree_hal_buffer_view_element_type(result);
for (iree_hal_dim_t m = 0; m < m_size; ++m) {
for (iree_hal_dim_t n = 0; n < n_size; ++n) {
- reference_matmul_element(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
- lhs_mapping.contents.data, rhs_mapping.contents.data,
- acc_mapping.contents.data, result_mapping.contents.data, m, n);
+ reference_matmul_element(m_size, k_size, n_size, lhs_type, rhs_type,
+ acc_type, lhs_mapping.contents.data,
+ rhs_mapping.contents.data,
+ acc ? acc_mapping.contents.data : NULL,
+ result_mapping.contents.data, m, n);
}
}
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&lhs_mapping));
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&rhs_mapping));
- IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&acc_mapping));
+ if (acc) {
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&acc_mapping));
+ }
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&result_mapping));
return iree_ok_status();
}
@@ -813,9 +833,11 @@
print_matrix(file, "right-hand side", PRECISION_LOW, k_start, k_end, n_start,
n_end, rhs, NULL, NULL);
fprintf(file, "\n");
- print_matrix(file, "input accumulator", PRECISION_LOW, m_start, m_end,
- n_start, n_end, acc, NULL, NULL);
- fprintf(file, "\n");
+ if (acc) {
+ print_matrix(file, "input accumulator", PRECISION_LOW, m_start, m_end,
+ n_start, n_end, acc, NULL, NULL);
+ fprintf(file, "\n");
+ }
print_matrix(file, "expected result", PRECISION_LOW, m_start, m_end, n_start,
n_end, expected_result, actual_result, emoji(true));
fprintf(file, "\n");
@@ -874,7 +896,9 @@
iree_hal_buffer_view_t* acc = NULL;
IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 0, &lhs));
IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 1, &rhs));
- IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
+ if (iree_vm_list_size(input_list) == 3) {
+ IREE_RETURN_IF_ERROR(get_item_as_buffer_view(input_list, 2, &acc));
+ }
iree_hal_dim_t m_size, k_size, n_size;
IREE_RETURN_IF_ERROR(get_matmul_sizes(lhs, rhs, acc, actual_result, &m_size,
@@ -895,24 +919,10 @@
* helpers for it.
*
* |replay_event_call_matmul| calls |do_matmul_and_check_results| to actually
- * perform a matmul. In normal cases, each |replay_event_call_matmul| performs
- * one call to |do_matmul_and_check_results|, but when that generates an error,
- * it will make additional calls to |do_matmul_and_check_results| to evaluate
- * variants of the failed testcase to generate a more helpful log.
- *
- * The |matrix_mask_t| stuff is only used to generate these variants of failed
- * testcases.
+ * perform a matmul.
*
*****************************************************************************/
-// Enumerates ways that we may mask matrices in list of matrix inputs to matmul
-// testcases.
-typedef enum {
- MATRIX_MASK_NONE, // no-op: leave the existing matrix unchanged.
- MATRIX_MASK_ZERO, // overwrite the matrix with zeros.
- MATRIX_MASK_IDENTITY, // overwrite with (general rectangular) identity matrix
-} matrix_mask_t;
-
static iree_status_t make_identity_matrix_callback(
iree_hal_buffer_mapping_t* mapping, void* user_data) {
iree_hal_buffer_view_t* src = (iree_hal_buffer_view_t*)user_data;
@@ -936,53 +946,10 @@
return iree_ok_status();
}
-// Allocates device-local |dst| and initializes it as an identity-matrix shaped
-// like |src|.
-static iree_status_t make_device_identity_matrix_like(
+// Deep-copies device-local list of buffer_views |src| into |dst|.
+static iree_status_t copy_device_buffer_views_to_device(
iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,
- iree_hal_buffer_view_t* src, iree_hal_buffer_view_t** dst) {
- return iree_hal_buffer_view_generate_buffer(
- hal_allocator, iree_hal_buffer_view_shape_rank(src),
- iree_hal_buffer_view_shape_dims(src),
- iree_hal_buffer_view_element_type(src),
- iree_hal_buffer_view_encoding_type(src),
- (iree_hal_buffer_params_t){
- .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
- .usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
- },
- make_identity_matrix_callback, src, dst);
-}
-
-// Allocates device-local |dst| shaped like |src|, and:
-// - If |mask| is MATRIX_MASK_NONE, copies device-local |src| into |dst|.
-// - If |mask| is MATRIX_MASK_ZERO, leaves |dst| zero-filled.
-// - If |mask| is MATRIX_MASK_IDENTITY, makes |dst| an identity-matrix
-static iree_status_t mask_and_copy_device_buffer_view_to_device(
- iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,
- iree_hal_buffer_view_t* src, matrix_mask_t mask,
- iree_hal_buffer_view_t** dst) {
- if (mask == MATRIX_MASK_NONE) {
- IREE_RETURN_IF_ERROR(
- copy_device_buffer_view_to_device(device, hal_allocator, src, dst));
- } else if (mask == MATRIX_MASK_ZERO) {
- IREE_RETURN_IF_ERROR(allocate_device_buffer_view_like(
- hal_allocator, src, iree_const_byte_span_empty(), dst));
- } else if (mask == MATRIX_MASK_IDENTITY) {
- IREE_RETURN_IF_ERROR(
- make_device_identity_matrix_like(device, hal_allocator, src, dst));
- } else {
- iree_status_abort(iree_make_status(IREE_STATUS_INTERNAL, "bad mask enum"));
- }
- return iree_ok_status();
-}
-
-// Deep-copies device-local list of buffer_views |src| into |dst|, applying
-// mask[i] to the i-th list element as in
-// |mask_and_copy_device_buffer_view_to_device|.
-// Requirement: |mask| must point to an array of the same length as |src|.
-static iree_status_t mask_and_copy_device_buffer_views_to_device(
- iree_hal_device_t* device, iree_hal_allocator_t* hal_allocator,
- iree_vm_list_t* src_list, matrix_mask_t* mask, iree_vm_list_t** dst_list) {
+ iree_vm_list_t* src_list, iree_vm_list_t** dst_list) {
iree_vm_type_def_t elem_type = iree_vm_list_element_type(src_list);
iree_host_size_t size = iree_vm_list_size(src_list);
iree_allocator_t allocator = iree_hal_allocator_host_allocator(hal_allocator);
@@ -993,8 +960,8 @@
iree_hal_buffer_view_t* src = NULL;
IREE_RETURN_IF_ERROR(get_item_as_buffer_view(src_list, i, &src));
iree_hal_buffer_view_t* dst = NULL;
- IREE_RETURN_IF_ERROR(mask_and_copy_device_buffer_view_to_device(
- device, hal_allocator, src, mask[i], &dst));
+ IREE_RETURN_IF_ERROR(
+ copy_device_buffer_view_to_device(device, hal_allocator, src, &dst));
iree_vm_ref_t dst_ref = {0};
IREE_RETURN_IF_ERROR(
iree_vm_ref_wrap_assign(dst, iree_hal_buffer_view_type(), &dst_ref));
@@ -1004,16 +971,13 @@
}
// Performs one matmul test, on the device-local input matrices given in
-// |original_device_inputs|, applying the masks given in |mask| as in
-// |mask_and_copy_device_buffer_view_to_device|.
-// Both |input_list| and |mask| should have length 3. The 3 input matrices are
-// LHS, RHS, Accumulator, in that order.
+// |original_device_inputs|.
//
// The contents of |original_device_inputs| are preserved, even if the
// |function| would overwrite input-output arguments (e.g. the accumulator).
static iree_status_t do_matmul_and_check_results(
FILE* file, iree_trace_replay_t* replay, iree_vm_function_t function,
- matrix_mask_t* mask, iree_vm_list_t* original_device_inputs) {
+ iree_vm_list_t* original_device_inputs) {
iree_hal_allocator_t* device_allocator =
iree_hal_device_allocator(replay->device);
@@ -1023,8 +987,8 @@
// linalg.matmul. We need to preserve the original test inputs to perform
// reruns on variants in the failure case (see |replay_event_call_matmul|).
iree_vm_list_t* device_inputs = NULL;
- IREE_CHECK_OK(mask_and_copy_device_buffer_views_to_device(
- replay->device, device_allocator, original_device_inputs, mask,
+ IREE_CHECK_OK(copy_device_buffer_views_to_device(
+ replay->device, device_allocator, original_device_inputs,
&device_inputs));
// Perform a deep copy of the device-local inputs into host-local buffers.
@@ -1074,19 +1038,6 @@
return status;
}
-const char* matrix_form(matrix_mask_t mask) {
- switch (mask) {
- case MATRIX_MASK_NONE:
- return "GENERAL";
- case MATRIX_MASK_ZERO:
- return "ZERO";
- case MATRIX_MASK_IDENTITY:
- return "IDENTITY";
- }
- assert(false);
- return NULL;
-}
-
// Prints to |file| a message about the matmul shape. Useful as testcases
// otherwise only print the function name, and in the dynamic-shapes cases, that
// doesn't tell the actual shape.
@@ -1106,8 +1057,6 @@
}
// Special handler for function calls in a e2e matmul test trace.
-// Assumes that all calls are to functions that take 3 inputs (lhs, rhs, acc)
-// and return the result of a matmul (lhs*rhs+acc).
static iree_status_t replay_event_call_matmul(iree_trace_replay_t* replay,
yaml_document_t* document,
yaml_node_t* event_node) {
@@ -1126,69 +1075,8 @@
IREE_CHECK_OK(print_matmul_shape(stderr, device_inputs));
- // Perform the matmul test. So far we are using pseudorandom matrices (as
- // specified in the YAML trace and interpreted above in
- // |iree_trace_replay_event_call_prepare|). So this is a test on general
- // random matrices: great for test coverage (if this succeeds, any variant on
- // more special matrices would also succeed) but bad for debugging (if this
- // fails, having to debug that would involve staring at arrays of random
- // numbers). So for now we pass NULL as the |file| param, keeping errors
- // silent for now.
- matrix_mask_t none_masks[3] = {MATRIX_MASK_NONE, MATRIX_MASK_NONE,
- MATRIX_MASK_NONE};
- iree_status_t status = do_matmul_and_check_results(NULL, replay, function,
- none_masks, device_inputs);
- if (!iree_status_is_ok(status)) {
- // The matmul test failed. So whatever we do now is only for the sake of
- // generating the most undertandable possible error log. We are going to
- // retry the matmul but on more special, easy-to-understand matrices,
- // gradually increasing generality, and we will abort and log details on
- // the first error that we encounter.
- iree_string_builder_t sb;
- iree_string_builder_initialize(replay->host_allocator, &sb);
- matrix_mask_t all_debug_masks[6][3] = {
- // Try Zero * Zero + Zero. Expected result: Zero.
- {MATRIX_MASK_ZERO, MATRIX_MASK_ZERO, MATRIX_MASK_ZERO},
- // Try Identity * Identity + Zero. Expected result: Identity.
- {MATRIX_MASK_IDENTITY, MATRIX_MASK_IDENTITY, MATRIX_MASK_ZERO},
- // Try RandomLHS * Identity + Zero. Expected result: RandomLHS.
- {MATRIX_MASK_NONE, MATRIX_MASK_IDENTITY, MATRIX_MASK_ZERO},
- // Try Identity * RandomRHS + Zero. Expected result: RandomRHS.
- {MATRIX_MASK_IDENTITY, MATRIX_MASK_NONE, MATRIX_MASK_ZERO},
- // Try Identity * Identity + RandomAccum.
- // Expected result: Identity + RandomAccum.
- {MATRIX_MASK_IDENTITY, MATRIX_MASK_IDENTITY, MATRIX_MASK_NONE},
- // Finally run the general case again. If none of the above special
- // cases
- // failed, then that at least must fail, since we already ran that and
- // it
- // had failed.
- {MATRIX_MASK_NONE, MATRIX_MASK_NONE, MATRIX_MASK_NONE}};
- bool reproduced_failure = false;
- for (int i = 0; i < IREE_ARRAYSIZE(all_debug_masks); ++i) {
- matrix_mask_t* masks = all_debug_masks[i];
- iree_status_code_t rerun_status =
- iree_status_consume_code(do_matmul_and_check_results(
- stderr, replay, function, masks, device_inputs));
- bool good = iree_status_is_ok(rerun_status);
- reproduced_failure |= !good;
- iree_string_builder_append_format(
- &sb, "%s LHS:%-10s * RHS:%-10s + ACCUMULATOR:%-10s\n", emoji(good),
- matrix_form(masks[0]), matrix_form(masks[1]), matrix_form(masks[2]));
- if (!good) break;
- }
- if (!reproduced_failure) {
- iree_status_abort(iree_make_status(
- IREE_STATUS_INTERNAL,
- "Internal error: a matmul test failed, but subsequent reruns for "
- "logging purposes were not able to reproduce the failure."));
- }
- fprintf(stderr,
- "Summary of reruns, pinpointing how general matrices need to be to "
- "reproduce this failure:\n%s\n",
- iree_string_builder_buffer(&sb));
- iree_string_builder_deinitialize(&sb);
- }
+ iree_status_t status =
+ do_matmul_and_check_results(stderr, replay, function, device_inputs);
// Clean up.
iree_vm_list_release(device_inputs);