[matmul] Add transpose B matrix coverage for CDNA3 (#16558)
This commit adds transposed-B (transpose_rhs) matrix coverage to the matmul test
suite, so that such tests can be added for the CDNA3 MFMA CodeGen pipeline.
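For reference, with --transpose_rhs the generator emits an NxK RHS matrix and
linalg.matmul_transpose_b, and the runtime check indexes the RHS as rhs[n, k]
instead of rhs[k, n]. A minimal Python sketch of that reference computation
(illustration only, not part of the patch):

# Illustrative sketch, assuming plain nested Python lists; mirrors the
# REFERENCE_MATMUL indexing change in iree-e2e-matmul-test.cc below.
def reference_matmul(lhs, rhs, acc, m_size, k_size, n_size, transpose_rhs):
    result = [row[:] for row in acc]
    for m in range(m_size):
        for n in range(n_size):
            value = acc[m][n]
            for k in range(k_size):
                # A transposed RHS is stored as an n_size x k_size matrix.
                rhs_value = rhs[n][k] if transpose_rhs else rhs[k][n]
                value += lhs[m][k] * rhs_value
            result[m][n] = value
    return result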
ci-extra: test_gpu
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 99ffe99..587d373 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -430,9 +430,8 @@
###########################################################################
# Testing CDNA3 + matrix core path.
-# v_mfma_f32_16x16x16_f16
iree_generated_e2e_matmul_test(
- name = "e2e_matmul_rocm_f16_large_cdna3_matrixcore",
+ name = "e2e_matmul_rocm_f16_large_cdna3_mfma",
compiler_flags = [
"--iree-rocm-target-chip=gfx942",
],
@@ -456,6 +455,32 @@
test_runner = "//tools:iree-e2e-matmul-test",
)
+iree_generated_e2e_matmul_test(
+ name = "e2e_matmul_rocm_f16_large_cdna3_mfma_tb",
+ compiler_flags = [
+ "--iree-rocm-target-chip=gfx942",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=f16",
+ "--acc_type=f32",
+ "--transpose_rhs",
+ "--shapes=gpu_large_aligned",
+ "--compilation_info=LLVMGPUVectorDistribute",
+ ],
+ tags = [
+ "noasan",
+ "nomsan",
+ "notsan",
+ "noubsan",
+ "requires-gpu-cdna3",
+ ],
+ target_backends_and_drivers = [
+ ("rocm", "rocm"),
+ ],
+ test_runner = "//tools:iree-e2e-matmul-test",
+)
+
###########################################################################
##
## Vulkan backend
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 9939e35..6b9fdf9 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -978,7 +978,7 @@
iree_generated_e2e_matmul_test(
NAME
- e2e_matmul_rocm_f16_large_cdna3_matrixcore
+ e2e_matmul_rocm_f16_large_cdna3_mfma
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
@@ -1004,6 +1004,33 @@
iree_generated_e2e_matmul_test(
NAME
+ e2e_matmul_rocm_f16_large_cdna3_mfma_tb
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f16"
+ "--acc_type=f32"
+ "--transpose_rhs"
+ "--shapes=gpu_large_aligned"
+ "--compilation_info=LLVMGPUVectorDistribute"
+ TEST_RUNNER
+ iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "rocm"
+ COMPILER_FLAGS
+ "--iree-rocm-target-chip=gfx942"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_matmul_test(
+ NAME
e2e_matmul_vulkan_i8_large_valhall
GENERATOR
"generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index e295a3d..1b443a8 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -420,14 +420,24 @@
# Helper for generate_function. Generates TestInputMatricesShapes, i.e.
# converts from the runtime shape dimensions in TestShape and given dynamicity to
# the set of shapes to be used in a test function's input tensors.
-def generate_shapes(shape: TestShape, dynamicity: Dynamicity):
+def generate_shapes(shape: TestShape, transpose_rhs: bool, dynamicity: Dynamicity):
+ lhs_rows = shape_dim(shape.m, dynamicity)
+ lhs_cols = shape_dim(shape.k, dynamicity)
+ acc_rows = shape_dim(shape.m, dynamicity)
+ acc_cols = shape_dim(shape.n, dynamicity)
+ if transpose_rhs:
+ rhs_rows = shape_dim(shape.n, dynamicity)
+ rhs_cols = shape_dim(shape.k, dynamicity)
+ else:
+ rhs_rows = shape_dim(shape.k, dynamicity)
+ rhs_cols = shape_dim(shape.n, dynamicity)
shapes = TestInputMatricesShapes(
- lhs_rows=shape_dim(shape.m, dynamicity),
- lhs_cols=shape_dim(shape.k, dynamicity),
- rhs_rows=shape_dim(shape.k, dynamicity),
- rhs_cols=shape_dim(shape.n, dynamicity),
- acc_rows=shape_dim(shape.m, dynamicity),
- acc_cols=shape_dim(shape.n, dynamicity),
+ lhs_rows=lhs_rows,
+ lhs_cols=lhs_cols,
+ rhs_rows=rhs_rows,
+ rhs_cols=rhs_cols,
+ acc_rows=acc_rows,
+ acc_cols=acc_cols,
)
return shapes
@@ -443,12 +453,12 @@
):
input_t = lhs_rhs_type.value
acc_t = acc_type.value
- lhs_m = int_or_DYN(shapes.lhs_rows)
- lhs_k = int_or_DYN(shapes.lhs_cols)
- rhs_k = int_or_DYN(shapes.rhs_rows)
- rhs_n = int_or_DYN(shapes.rhs_cols)
- acc_m = int_or_DYN(shapes.acc_rows)
- acc_n = int_or_DYN(shapes.acc_cols)
+ lhs_r = int_or_DYN(shapes.lhs_rows)
+ lhs_c = int_or_DYN(shapes.lhs_cols)
+ rhs_r = int_or_DYN(shapes.rhs_rows)
+ rhs_c = int_or_DYN(shapes.rhs_cols)
+ acc_r = int_or_DYN(shapes.acc_rows)
+ acc_c = int_or_DYN(shapes.acc_cols)
info = ""
if compilation_info:
@@ -462,8 +472,8 @@
matmul_kind = "matmul_accumulate" if accumulate else "matmul"
return (
- f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_"
- + f"{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+ f"{matmul_kind}_{lhs_r}x{lhs_c}x{input_t}_times_"
+ + f"{rhs_r}x{rhs_c}x{input_t}_into_{acc_r}x{acc_c}x{acc_t}{info}"
)
@@ -477,28 +487,34 @@
# Generates a test function in the generated MLIR code.
-# The generated function will take the same arguments as linalg.matmul and
-# will just call linalg.matmul with them, returning its result.
+# The generated function will take the same arguments as the linalg.matmul
+# variant being tested and will just call that variant, returning its result.
def generate_function(
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shape: TestShape,
+ transpose_rhs: bool,
dynamicity: Dynamicity,
compilation_info: typing.Optional[CompilationInfo] = None,
):
- shapes = generate_shapes(shape, dynamicity)
+ shapes = generate_shapes(shape, transpose_rhs, dynamicity)
func_name = generate_function_name(
lhs_rhs_type, acc_type, shapes, shape.accumulate, compilation_info
)
- lhs_m = int_or_question_mark(shapes.lhs_rows)
- lhs_k = int_or_question_mark(shapes.lhs_cols)
- rhs_k = int_or_question_mark(shapes.rhs_rows)
- rhs_n = int_or_question_mark(shapes.rhs_cols)
- acc_m = int_or_question_mark(shapes.acc_rows)
- acc_n = int_or_question_mark(shapes.acc_cols)
- lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type.value}>"
- rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type.value}>"
- acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{acc_type.value}>"
+ lhs_r = int_or_question_mark(shapes.lhs_rows)
+ lhs_c = int_or_question_mark(shapes.lhs_cols)
+ rhs_r = int_or_question_mark(shapes.rhs_rows)
+ rhs_c = int_or_question_mark(shapes.rhs_cols)
+ acc_r = int_or_question_mark(shapes.acc_rows)
+ acc_c = int_or_question_mark(shapes.acc_cols)
+ lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+ rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
+ acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
+
+ if transpose_rhs:
+ op_name = "linalg.matmul_transpose_b"
+ else:
+ op_name = "linalg.matmul"
# Compilation info is optional; prints empty string by default.
func_definition = ""
@@ -537,13 +553,13 @@
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
- f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
else:
literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0"
- if acc_m == "?":
+ if acc_r == "?":
signature = f"({lhs_tensor_type}, {rhs_tensor_type}) -> {acc_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view) -> !hal.buffer_view"
func_definition = func_definition + (
@@ -555,7 +571,7 @@
f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
- f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
@@ -567,7 +583,7 @@
f" %init_acc = tensor.empty() : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
- f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
@@ -635,6 +651,7 @@
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shape: TestShape,
+ transpose_rhs: bool,
):
global call_id
func_name = f"{function.name}_{shape.m}_{shape.k}_{shape.n}"
@@ -652,8 +669,16 @@
" %device = hal.devices.get %device_index : !hal.device\n"
)
- op = op + generate_random_matrix("lhs", [shape.m, shape.k], lhs_rhs_type)
- op = op + generate_random_matrix("rhs", [shape.k, shape.n], lhs_rhs_type)
+ lhs_shape = [shape.m, shape.k]
+ if transpose_rhs:
+ rhs_shape = [shape.n, shape.k]
+ transpose_rhs = 1
+ else:
+ rhs_shape = [shape.k, shape.n]
+ transpose_rhs = 0
+
+ op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
+ op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
if shape.accumulate:
op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
# TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -674,7 +699,8 @@
f" %m = arith.constant {shape.m} : i64\n"
f" %k = arith.constant {shape.k} : i64\n"
f" %n = arith.constant {shape.n} : i64\n"
- f" call @matmul_test.check_matmul_results(%device, %m, %k, %n, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
+ f" %transpose_rhs = arith.constant {transpose_rhs} : i32\n"
+ f" call @matmul_test.check_matmul_results(%device, %m, %k, %n, %transpose_rhs, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, i32, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
)
op = op + " return\n"
@@ -688,6 +714,7 @@
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shapes_id: ShapesId,
+ transpose_rhs: bool,
compilation_info_id: CompilationInfoId,
):
functions = {}
@@ -699,7 +726,12 @@
for shape in get_test_shapes(shapes_id):
for dynamicity in get_dynamicities(shapes_id):
function = generate_function(
- lhs_rhs_type, acc_type, shape, dynamicity, compilation_info
+ lhs_rhs_type,
+ acc_type,
+ shape,
+ transpose_rhs,
+ dynamicity,
+ compilation_info,
)
# Different testcases may differ only by runtime parameters but
# share the same code. For example, dynamic-shapes testcases
@@ -708,7 +740,11 @@
# to calls, but unconditionally to function_definitions.
if function.name not in functions:
functions[function.name] = function
- calls.append(generate_call(function, lhs_rhs_type, acc_type, shape))
+ calls.append(
+ generate_call(
+ function, lhs_rhs_type, acc_type, shape, transpose_rhs
+ )
+ )
return (functions, calls)
@@ -750,6 +786,13 @@
required=True,
)
parser.add_argument(
+ "--transpose_rhs",
+ action="store_true",
+ help="Whether to transpose RHS",
+ default=False,
+ required=False,
+ )
+ parser.add_argument(
"--compilation_info",
type=str,
choices=[i.value for i in CompilationInfoId],
@@ -790,7 +833,7 @@
# Declare the custom module that generates arguments.
module_definition = module_definition + (
"func.func private @matmul_test.generate_random_matrix(%device: !hal.device, %dim0: i64, %dim1: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
- "func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
+ "func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %transpose_rhs: i32, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
"\n"
)
@@ -827,8 +870,9 @@
acc_type = infer_acc_type(lhs_rhs_type, acc_type)
shapes_id = ShapesId(args.shapes)
compilation_info_id = CompilationInfoId(args.compilation_info)
+
(functions, calls) = generate(
- lhs_rhs_type, acc_type, shapes_id, compilation_info_id
+ lhs_rhs_type, acc_type, shapes_id, args.transpose_rhs, compilation_info_id
)
write_code_file(functions, args.output_matmuls_mlir)
diff --git a/tools/iree-e2e-matmul-test.cc b/tools/iree-e2e-matmul-test.cc
index a4d5b11..e165830 100644
--- a/tools/iree-e2e-matmul-test.cc
+++ b/tools/iree-e2e-matmul-test.cc
@@ -199,20 +199,22 @@
return iree_ok_status();
}
-#define REFERENCE_MATMUL(LHSTYPE, RHSTYPE, RESTYPE, ACCTYPE) \
- static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE( \
- iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \
- iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \
- iree_hal_element_type_t acc_type, const LHSTYPE* lhs_data, \
- const RHSTYPE* rhs_data, const ACCTYPE* acc_data, RESTYPE* result_data, \
- iree_hal_dim_t m, iree_hal_dim_t n) { \
- ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0; \
- for (iree_hal_dim_t k = 0; k < k_size; ++k) { \
- LHSTYPE lhs_value = lhs_data[k + m * k_size]; \
- RHSTYPE rhs_value = rhs_data[n + k * n_size]; \
- acc += (ACCTYPE)lhs_value * (ACCTYPE)rhs_value; \
- } \
- result_data[n + m * n_size] = acc; \
+#define REFERENCE_MATMUL(LHSTYPE, RHSTYPE, RESTYPE, ACCTYPE) \
+ static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE( \
+ iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \
+ iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \
+ iree_hal_element_type_t acc_type, bool transpose_rhs, \
+ const LHSTYPE* lhs_data, const RHSTYPE* rhs_data, \
+ const ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m, \
+ iree_hal_dim_t n) { \
+ ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0; \
+ for (iree_hal_dim_t k = 0; k < k_size; ++k) { \
+ LHSTYPE lhs_value = lhs_data[k + m * k_size]; \
+ RHSTYPE rhs_value = \
+ transpose_rhs ? rhs_data[k + n * k_size] : rhs_data[n + k * n_size]; \
+ acc += (ACCTYPE)lhs_value * (ACCTYPE)rhs_value; \
+ } \
+ result_data[n + m * n_size] = acc; \
}
// Reference matmul instantiations from macro REFERENCE_MATMUL
@@ -235,13 +237,15 @@
static void reference_matmul_f16_f16_f16_f16(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
- const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
- iree_hal_dim_t m, iree_hal_dim_t n) {
+ iree_hal_element_type_t acc_type, bool transpose_rhs,
+ const uint16_t* lhs_data, const uint16_t* rhs_data,
+ const uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
+ iree_hal_dim_t n) {
float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
- iree_math_f16_to_f32(rhs_data[n + k * n_size]);
+ iree_math_f16_to_f32(rhs_data[rhs_index]);
}
result_data[n + m * n_size] = iree_math_f32_to_f16(acc);
}
@@ -251,13 +255,14 @@
static void reference_matmul_f16_f16_f32_f32(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
- const uint16_t* rhs_data, const float* acc_data, float* result_data,
- iree_hal_dim_t m, iree_hal_dim_t n) {
+ iree_hal_element_type_t acc_type, bool transpose_rhs,
+ const uint16_t* lhs_data, const uint16_t* rhs_data, const float* acc_data,
+ float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
float acc = acc_data ? acc_data[n + m * n_size] : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
- iree_math_f16_to_f32(rhs_data[n + k * n_size]);
+ iree_math_f16_to_f32(rhs_data[rhs_index]);
}
result_data[n + m * n_size] = acc;
}
@@ -267,13 +272,15 @@
static void reference_matmul_bf16_bf16_bf16_bf16(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
- const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
- iree_hal_dim_t m, iree_hal_dim_t n) {
+ iree_hal_element_type_t acc_type, bool transpose_rhs,
+ const uint16_t* lhs_data, const uint16_t* rhs_data,
+ const uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
+ iree_hal_dim_t n) {
float acc = acc_data ? iree_math_bf16_to_f32(acc_data[n + m * n_size]) : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
- iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
+ iree_math_bf16_to_f32(rhs_data[rhs_index]);
}
result_data[n + m * n_size] = iree_math_f32_to_bf16(acc);
}
@@ -283,13 +290,14 @@
static void reference_matmul_bf16_bf16_f32_f32(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
- const uint16_t* rhs_data, const float* acc_data, float* result_data,
- iree_hal_dim_t m, iree_hal_dim_t n) {
+ iree_hal_element_type_t acc_type, bool transpose_rhs,
+ const uint16_t* lhs_data, const uint16_t* rhs_data, const float* acc_data,
+ float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
float acc = acc_data ? acc_data[n + m * n_size] : 0.f;
for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+ int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
- iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
+ iree_math_bf16_to_f32(rhs_data[rhs_index]);
}
result_data[n + m * n_size] = acc;
}
@@ -299,55 +307,56 @@
static iree_status_t reference_matmul_element(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, void* lhs_data, void* rhs_data,
- void* acc_data, void* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
+ iree_hal_element_type_t acc_type, bool transpose_rhs, void* lhs_data,
+ void* rhs_data, void* acc_data, void* result_data, iree_hal_dim_t m,
+ iree_hal_dim_t n) {
if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
reference_matmul_float_float_float_float(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const float*)lhs_data, (const float*)rhs_data, (const float*)acc_data,
(float*)result_data, m, n);
} else if (iree_hal_element_type_is_integer(lhs_type, 8) &&
iree_hal_element_type_is_integer(rhs_type, 8) &&
iree_hal_element_type_is_integer(acc_type, 32)) {
reference_matmul_int8_t_int8_t_int32_t_int32_t(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const int8_t*)lhs_data, (const int8_t*)rhs_data,
(const int32_t*)acc_data, (int32_t*)result_data, m, n);
} else if (iree_hal_element_type_is_integer(lhs_type, 32) &&
iree_hal_element_type_is_integer(rhs_type, 32) &&
iree_hal_element_type_is_integer(acc_type, 32)) {
reference_matmul_int32_t_int32_t_int32_t_int32_t(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const int32_t*)lhs_data, (const int32_t*)rhs_data,
(const int32_t*)acc_data, (int32_t*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
reference_matmul_f16_f16_f16_f16(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
(const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
reference_matmul_f16_f16_f32_f32(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
(const float*)acc_data, (float*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
acc_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
reference_matmul_bf16_bf16_bf16_bf16(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
(const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
reference_matmul_bf16_bf16_f32_f32(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
(const float*)acc_data, (float*)result_data, m, n);
} else {
@@ -361,9 +370,10 @@
static iree_status_t reference_matmul(
iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
- iree_hal_element_type_t acc_type, iree_byte_span_t lhs_contents,
- iree_byte_span_t rhs_contents, iree_byte_span_t acc_contents,
- iree_byte_span_t result_contents, int compute_every) {
+ iree_hal_element_type_t acc_type, bool transpose_rhs,
+ iree_byte_span_t lhs_contents, iree_byte_span_t rhs_contents,
+ iree_byte_span_t acc_contents, iree_byte_span_t result_contents,
+ int compute_every) {
IREE_TRACE_ZONE_BEGIN(z0);
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, m_size);
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, k_size);
@@ -375,7 +385,7 @@
if (++count < compute_every) continue;
count = 0;
IREE_RETURN_IF_ERROR(reference_matmul_element(
- m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+ m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
lhs_contents.data, rhs_contents.data, acc_contents.data,
result_contents.data, m, n));
}
@@ -398,6 +408,7 @@
iree_hal_element_type_t rhs_type;
iree_hal_element_type_t acc_type;
iree_hal_element_type_t result_type;
+ bool transpose_rhs;
iree_byte_span_t lhs_contents;
iree_byte_span_t rhs_contents;
iree_byte_span_t acc_contents;
@@ -409,7 +420,7 @@
static iree_status_t matmul_results_initialize(
iree_hal_device_t* device, iree_hal_dim_t m_size, iree_hal_dim_t k_size,
- iree_hal_dim_t n_size, iree_hal_buffer_view_t* lhs,
+ iree_hal_dim_t n_size, uint32_t transpose_rhs, iree_hal_buffer_view_t* lhs,
iree_hal_buffer_view_t* rhs, iree_hal_buffer_view_t* acc,
iree_hal_buffer_view_t* result, iree_allocator_t host_allocator,
matmul_results_t* out_results) {
@@ -427,6 +438,8 @@
out_results->acc_type = iree_hal_buffer_view_element_type(result);
out_results->result_type = iree_hal_buffer_view_element_type(result);
+ out_results->transpose_rhs = transpose_rhs != 0;
+
iree_hal_buffer_t* lhs_buffer = iree_hal_buffer_view_buffer(lhs);
iree_hal_buffer_t* rhs_buffer = iree_hal_buffer_view_buffer(rhs);
iree_hal_buffer_t* acc_buffer = acc ? iree_hal_buffer_view_buffer(acc) : NULL;
@@ -776,11 +789,11 @@
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, reference_matmul(results->m, results->k, results->n,
- results->lhs_type, results->rhs_type,
- results->acc_type, results->lhs_contents,
- results->rhs_contents, results->acc_contents,
- results->expected_contents, check_every));
+ z0, reference_matmul(
+ results->m, results->k, results->n, results->lhs_type,
+ results->rhs_type, results->acc_type, results->transpose_rhs,
+ results->lhs_contents, results->rhs_contents,
+ results->acc_contents, results->expected_contents, check_every));
int count = 0;
for (iree_hal_dim_t m = 0; m < results->m; ++m) {
@@ -1009,15 +1022,15 @@
Status CheckMatmulResults(
const vm::ref<iree_hal_device_t> device, int64_t m, int64_t k, int64_t n,
- const vm::ref<iree_hal_buffer_view_t> lhs,
+ int32_t transpose_rhs, const vm::ref<iree_hal_buffer_view_t> lhs,
const vm::ref<iree_hal_buffer_view_t> rhs,
const vm::ref<iree_hal_buffer_view_t> acc,
const vm::ref<iree_hal_buffer_view_t> actual_result) {
matmul_results_t results = {};
IREE_RETURN_IF_ERROR(matmul_results_initialize(
device.get(), (iree_hal_dim_t)m, (iree_hal_dim_t)k, (iree_hal_dim_t)n,
- lhs.get(), rhs.get(), acc.get(), actual_result.get(), host_allocator_,
- &results));
+ transpose_rhs, lhs.get(), rhs.get(), acc.get(), actual_result.get(),
+ host_allocator_, &results));
iree_status_t status = check_matmul_results(stderr, &results);
matmul_results_deinitialize(&results);
return status;