[matmul] Add transpose B matrix coverage for CDNA3 (#16558)

This commit adds transposed-B (RHS) matrix coverage to the matmul test suite,
so that such tests can be added for the CDNA3 MFMA CodeGen pipeline.
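
With `--transpose_rhs`, the generator emits `linalg.matmul_transpose_b` and
gives the RHS an NxK shape instead of KxN, while the runtime check indexes the
RHS as `rhs_data[k + n * k_size]`. A minimal sketch of the generated op, with
hypothetical shapes chosen only for illustration (not taken from this change):

    // Illustrative only: M=256, K=128, N=512 with an NxK (transposed) RHS.
    %result = linalg.matmul_transpose_b
        ins(%lhs, %rhs : tensor<256x128xf16>, tensor<512x128xf16>)
        outs(%acc : tensor<256x512xf32>) -> tensor<256x512xf32>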

ci-extra: test_gpu
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index 99ffe99..587d373 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -430,9 +430,8 @@
 ###########################################################################
 
 # Testing CDNA3 + matrix core path.
-# v_mfma_f32_16x16x16_f16
 iree_generated_e2e_matmul_test(
-    name = "e2e_matmul_rocm_f16_large_cdna3_matrixcore",
+    name = "e2e_matmul_rocm_f16_large_cdna3_mfma",
     compiler_flags = [
         "--iree-rocm-target-chip=gfx942",
     ],
@@ -456,6 +455,32 @@
     test_runner = "//tools:iree-e2e-matmul-test",
 )
 
+iree_generated_e2e_matmul_test(
+    name = "e2e_matmul_rocm_f16_large_cdna3_mfma_tb",
+    compiler_flags = [
+        "--iree-rocm-target-chip=gfx942",
+    ],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=f16",
+        "--acc_type=f32",
+        "--transpose_rhs",
+        "--shapes=gpu_large_aligned",
+        "--compilation_info=LLVMGPUVectorDistribute",
+    ],
+    tags = [
+        "noasan",
+        "nomsan",
+        "notsan",
+        "noubsan",
+        "requires-gpu-cdna3",
+    ],
+    target_backends_and_drivers = [
+        ("rocm", "rocm"),
+    ],
+    test_runner = "//tools:iree-e2e-matmul-test",
+)
+
 ###########################################################################
 ##
 ## Vulkan backend
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 9939e35..6b9fdf9 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -978,7 +978,7 @@
 
 iree_generated_e2e_matmul_test(
   NAME
-    e2e_matmul_rocm_f16_large_cdna3_matrixcore
+    e2e_matmul_rocm_f16_large_cdna3_mfma
   GENERATOR
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
@@ -1004,6 +1004,33 @@
 
 iree_generated_e2e_matmul_test(
   NAME
+    e2e_matmul_rocm_f16_large_cdna3_mfma_tb
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f16"
+    "--acc_type=f32"
+    "--transpose_rhs"
+    "--shapes=gpu_large_aligned"
+    "--compilation_info=LLVMGPUVectorDistribute"
+  TEST_RUNNER
+    iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "rocm"
+  COMPILER_FLAGS
+    "--iree-rocm-target-chip=gfx942"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_matmul_test(
+  NAME
     e2e_matmul_vulkan_i8_large_valhall
   GENERATOR
     "generate_e2e_matmul_tests.py"
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index e295a3d..1b443a8 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -420,14 +420,24 @@
 # Helper for generate_function. Generates TestInputMatricesShapes, i.e.
 # converts from the runtime shape dimensions in TestShape and given dynamicity to
 # the set of shapes to be used in a test function's input tensors.
-def generate_shapes(shape: TestShape, dynamicity: Dynamicity):
+def generate_shapes(shape: TestShape, transpose_rhs: bool, dynamicity: Dynamicity):
+    lhs_rows = shape_dim(shape.m, dynamicity)
+    lhs_cols = shape_dim(shape.k, dynamicity)
+    acc_rows = shape_dim(shape.m, dynamicity)
+    acc_cols = shape_dim(shape.n, dynamicity)
+    if transpose_rhs:
+        rhs_rows = shape_dim(shape.n, dynamicity)
+        rhs_cols = shape_dim(shape.k, dynamicity)
+    else:
+        rhs_rows = shape_dim(shape.k, dynamicity)
+        rhs_cols = shape_dim(shape.n, dynamicity)
     shapes = TestInputMatricesShapes(
-        lhs_rows=shape_dim(shape.m, dynamicity),
-        lhs_cols=shape_dim(shape.k, dynamicity),
-        rhs_rows=shape_dim(shape.k, dynamicity),
-        rhs_cols=shape_dim(shape.n, dynamicity),
-        acc_rows=shape_dim(shape.m, dynamicity),
-        acc_cols=shape_dim(shape.n, dynamicity),
+        lhs_rows=lhs_rows,
+        lhs_cols=lhs_cols,
+        rhs_rows=rhs_rows,
+        rhs_cols=rhs_cols,
+        acc_rows=acc_rows,
+        acc_cols=acc_cols,
     )
     return shapes
 
@@ -443,12 +453,12 @@
 ):
     input_t = lhs_rhs_type.value
     acc_t = acc_type.value
-    lhs_m = int_or_DYN(shapes.lhs_rows)
-    lhs_k = int_or_DYN(shapes.lhs_cols)
-    rhs_k = int_or_DYN(shapes.rhs_rows)
-    rhs_n = int_or_DYN(shapes.rhs_cols)
-    acc_m = int_or_DYN(shapes.acc_rows)
-    acc_n = int_or_DYN(shapes.acc_cols)
+    lhs_r = int_or_DYN(shapes.lhs_rows)
+    lhs_c = int_or_DYN(shapes.lhs_cols)
+    rhs_r = int_or_DYN(shapes.rhs_rows)
+    rhs_c = int_or_DYN(shapes.rhs_cols)
+    acc_r = int_or_DYN(shapes.acc_rows)
+    acc_c = int_or_DYN(shapes.acc_cols)
 
     info = ""
     if compilation_info:
@@ -462,8 +472,8 @@
 
     matmul_kind = "matmul_accumulate" if accumulate else "matmul"
     return (
-        f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_"
-        + f"{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
+        f"{matmul_kind}_{lhs_r}x{lhs_c}x{input_t}_times_"
+        + f"{rhs_r}x{rhs_c}x{input_t}_into_{acc_r}x{acc_c}x{acc_t}{info}"
     )
 
 
@@ -477,28 +487,34 @@
 
 
 # Generates a test function in the generated MLIR code.
-# The generated function will take the same arguments as linalg.matmul and
-# will just call linalg.matmul with them, returning its result.
+# The generated function will take the same arguments as the linalg.matmul
+# variant being tested and will just call that op with them, returning its result.
 def generate_function(
     lhs_rhs_type: MatrixElemTypeId,
     acc_type: MatrixElemTypeId,
     shape: TestShape,
+    transpose_rhs: bool,
     dynamicity: Dynamicity,
     compilation_info: typing.Optional[CompilationInfo] = None,
 ):
-    shapes = generate_shapes(shape, dynamicity)
+    shapes = generate_shapes(shape, transpose_rhs, dynamicity)
     func_name = generate_function_name(
         lhs_rhs_type, acc_type, shapes, shape.accumulate, compilation_info
     )
-    lhs_m = int_or_question_mark(shapes.lhs_rows)
-    lhs_k = int_or_question_mark(shapes.lhs_cols)
-    rhs_k = int_or_question_mark(shapes.rhs_rows)
-    rhs_n = int_or_question_mark(shapes.rhs_cols)
-    acc_m = int_or_question_mark(shapes.acc_rows)
-    acc_n = int_or_question_mark(shapes.acc_cols)
-    lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type.value}>"
-    rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type.value}>"
-    acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{acc_type.value}>"
+    lhs_r = int_or_question_mark(shapes.lhs_rows)
+    lhs_c = int_or_question_mark(shapes.lhs_cols)
+    rhs_r = int_or_question_mark(shapes.rhs_rows)
+    rhs_c = int_or_question_mark(shapes.rhs_cols)
+    acc_r = int_or_question_mark(shapes.acc_rows)
+    acc_c = int_or_question_mark(shapes.acc_cols)
+    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
+    acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
+
+    if transpose_rhs:
+        op_name = "linalg.matmul_transpose_b"
+    else:
+        op_name = "linalg.matmul"
 
     # Compilation info is optional; prints empty string by default.
     func_definition = ""
@@ -537,13 +553,13 @@
         import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
         func_definition = func_definition + (
             f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
-            f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+            f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
             f"  return %result: {acc_tensor_type}\n"
             f"}}\n"
         )
     else:
         literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0"
-        if acc_m == "?":
+        if acc_r == "?":
             signature = f"({lhs_tensor_type}, {rhs_tensor_type}) -> {acc_tensor_type}"
             import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view) -> !hal.buffer_view"
             func_definition = func_definition + (
@@ -555,7 +571,7 @@
                 f"  %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
                 f"  %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
                 f"  %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
-                f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+                f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
                 f"  return %result: {acc_tensor_type}\n"
                 f"}}\n"
             )
@@ -567,7 +583,7 @@
                 f"  %init_acc = tensor.empty() : {acc_tensor_type}\n"
                 f"  %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
                 f"  %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
-                f"  %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+                f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
                 f"  return %result: {acc_tensor_type}\n"
                 f"}}\n"
             )
@@ -635,6 +651,7 @@
     lhs_rhs_type: MatrixElemTypeId,
     acc_type: MatrixElemTypeId,
     shape: TestShape,
+    transpose_rhs: bool,
 ):
     global call_id
     func_name = f"{function.name}_{shape.m}_{shape.k}_{shape.n}"
@@ -652,8 +669,16 @@
         "  %device = hal.devices.get %device_index : !hal.device\n"
     )
 
-    op = op + generate_random_matrix("lhs", [shape.m, shape.k], lhs_rhs_type)
-    op = op + generate_random_matrix("rhs", [shape.k, shape.n], lhs_rhs_type)
+    lhs_shape = [shape.m, shape.k]
+    if transpose_rhs:
+        rhs_shape = [shape.n, shape.k]
+        transpose_rhs = 1
+    else:
+        rhs_shape = [shape.k, shape.n]
+        transpose_rhs = 0
+
+    op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
+    op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
     if shape.accumulate:
         op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
         # TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -674,7 +699,8 @@
         f"  %m = arith.constant {shape.m} : i64\n"
         f"  %k = arith.constant {shape.k} : i64\n"
         f"  %n = arith.constant {shape.n} : i64\n"
-        f"  call @matmul_test.check_matmul_results(%device, %m, %k, %n, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
+        f"  %transpose_rhs = arith.constant {transpose_rhs} : i32\n"
+        f"  call @matmul_test.check_matmul_results(%device, %m, %k, %n, %transpose_rhs, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, i32, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
     )
 
     op = op + "  return\n"
@@ -688,6 +714,7 @@
     lhs_rhs_type: MatrixElemTypeId,
     acc_type: MatrixElemTypeId,
     shapes_id: ShapesId,
+    transpose_rhs: bool,
     compilation_info_id: CompilationInfoId,
 ):
     functions = {}
@@ -699,7 +726,12 @@
         for shape in get_test_shapes(shapes_id):
             for dynamicity in get_dynamicities(shapes_id):
                 function = generate_function(
-                    lhs_rhs_type, acc_type, shape, dynamicity, compilation_info
+                    lhs_rhs_type,
+                    acc_type,
+                    shape,
+                    transpose_rhs,
+                    dynamicity,
+                    compilation_info,
                 )
                 # Different testcases may differ only by runtime parameters but
                 # share the same code. For example, dynamic-shapes testcases
@@ -708,7 +740,11 @@
                 # to calls, but unconditionally to function_definitions.
                 if function.name not in functions:
                     functions[function.name] = function
-                calls.append(generate_call(function, lhs_rhs_type, acc_type, shape))
+                calls.append(
+                    generate_call(
+                        function, lhs_rhs_type, acc_type, shape, transpose_rhs
+                    )
+                )
 
     return (functions, calls)
 
@@ -750,6 +786,13 @@
         required=True,
     )
     parser.add_argument(
+        "--transpose_rhs",
+        action="store_true",
+        help="Whether to transpose RHS",
+        default=False,
+        required=False,
+    )
+    parser.add_argument(
         "--compilation_info",
         type=str,
         choices=[i.value for i in CompilationInfoId],
@@ -790,7 +833,7 @@
     # Declare the custom module that generates arguments.
     module_definition = module_definition + (
         "func.func private @matmul_test.generate_random_matrix(%device: !hal.device, %dim0: i64, %dim1: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
-        "func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
+        "func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %transpose_rhs: i32, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
         "\n"
     )
 
@@ -827,8 +870,9 @@
     acc_type = infer_acc_type(lhs_rhs_type, acc_type)
     shapes_id = ShapesId(args.shapes)
     compilation_info_id = CompilationInfoId(args.compilation_info)
+
     (functions, calls) = generate(
-        lhs_rhs_type, acc_type, shapes_id, compilation_info_id
+        lhs_rhs_type, acc_type, shapes_id, args.transpose_rhs, compilation_info_id
     )
 
     write_code_file(functions, args.output_matmuls_mlir)
diff --git a/tools/iree-e2e-matmul-test.cc b/tools/iree-e2e-matmul-test.cc
index a4d5b11..e165830 100644
--- a/tools/iree-e2e-matmul-test.cc
+++ b/tools/iree-e2e-matmul-test.cc
@@ -199,20 +199,22 @@
   return iree_ok_status();
 }
 
-#define REFERENCE_MATMUL(LHSTYPE, RHSTYPE, RESTYPE, ACCTYPE)                  \
-  static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE( \
-      iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,    \
-      iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,     \
-      iree_hal_element_type_t acc_type, const LHSTYPE* lhs_data,              \
-      const RHSTYPE* rhs_data, const ACCTYPE* acc_data, RESTYPE* result_data, \
-      iree_hal_dim_t m, iree_hal_dim_t n) {                                   \
-    ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0;                    \
-    for (iree_hal_dim_t k = 0; k < k_size; ++k) {                             \
-      LHSTYPE lhs_value = lhs_data[k + m * k_size];                           \
-      RHSTYPE rhs_value = rhs_data[n + k * n_size];                           \
-      acc += (ACCTYPE)lhs_value * (ACCTYPE)rhs_value;                         \
-    }                                                                         \
-    result_data[n + m * n_size] = acc;                                        \
+#define REFERENCE_MATMUL(LHSTYPE, RHSTYPE, RESTYPE, ACCTYPE)                   \
+  static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE(  \
+      iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,     \
+      iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,      \
+      iree_hal_element_type_t acc_type, bool transpose_rhs,                    \
+      const LHSTYPE* lhs_data, const RHSTYPE* rhs_data,                        \
+      const ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m,         \
+      iree_hal_dim_t n) {                                                      \
+    ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0;                     \
+    for (iree_hal_dim_t k = 0; k < k_size; ++k) {                              \
+      LHSTYPE lhs_value = lhs_data[k + m * k_size];                            \
+      RHSTYPE rhs_value =                                                      \
+          transpose_rhs ? rhs_data[k + n * k_size] : rhs_data[n + k * n_size]; \
+      acc += (ACCTYPE)lhs_value * (ACCTYPE)rhs_value;                          \
+    }                                                                          \
+    result_data[n + m * n_size] = acc;                                         \
   }
 
 // Reference mamtul instantiations from macro REFERENCE_MATMUL
@@ -235,13 +237,15 @@
 static void reference_matmul_f16_f16_f16_f16(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
-    const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
-    iree_hal_dim_t m, iree_hal_dim_t n) {
+    iree_hal_element_type_t acc_type, bool transpose_rhs,
+    const uint16_t* lhs_data, const uint16_t* rhs_data,
+    const uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
+    iree_hal_dim_t n) {
   float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0.f;
   for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
     acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
-           iree_math_f16_to_f32(rhs_data[n + k * n_size]);
+           iree_math_f16_to_f32(rhs_data[rhs_index]);
   }
   result_data[n + m * n_size] = iree_math_f32_to_f16(acc);
 }
@@ -251,13 +255,14 @@
 static void reference_matmul_f16_f16_f32_f32(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
-    const uint16_t* rhs_data, const float* acc_data, float* result_data,
-    iree_hal_dim_t m, iree_hal_dim_t n) {
+    iree_hal_element_type_t acc_type, bool transpose_rhs,
+    const uint16_t* lhs_data, const uint16_t* rhs_data, const float* acc_data,
+    float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
   float acc = acc_data ? acc_data[n + m * n_size] : 0.f;
   for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
     acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) *
-           iree_math_f16_to_f32(rhs_data[n + k * n_size]);
+           iree_math_f16_to_f32(rhs_data[rhs_index]);
   }
   result_data[n + m * n_size] = acc;
 }
@@ -267,13 +272,15 @@
 static void reference_matmul_bf16_bf16_bf16_bf16(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
-    const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data,
-    iree_hal_dim_t m, iree_hal_dim_t n) {
+    iree_hal_element_type_t acc_type, bool transpose_rhs,
+    const uint16_t* lhs_data, const uint16_t* rhs_data,
+    const uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m,
+    iree_hal_dim_t n) {
   float acc = acc_data ? iree_math_bf16_to_f32(acc_data[n + m * n_size]) : 0.f;
   for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
     acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
-           iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
+           iree_math_bf16_to_f32(rhs_data[rhs_index]);
   }
   result_data[n + m * n_size] = iree_math_f32_to_bf16(acc);
 }
@@ -283,13 +290,14 @@
 static void reference_matmul_bf16_bf16_f32_f32(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, const uint16_t* lhs_data,
-    const uint16_t* rhs_data, const float* acc_data, float* result_data,
-    iree_hal_dim_t m, iree_hal_dim_t n) {
+    iree_hal_element_type_t acc_type, bool transpose_rhs,
+    const uint16_t* lhs_data, const uint16_t* rhs_data, const float* acc_data,
+    float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
   float acc = acc_data ? acc_data[n + m * n_size] : 0.f;
   for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size;
     acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) *
-           iree_math_bf16_to_f32(rhs_data[n + k * n_size]);
+           iree_math_bf16_to_f32(rhs_data[rhs_index]);
   }
   result_data[n + m * n_size] = acc;
 }
@@ -299,55 +307,56 @@
 static iree_status_t reference_matmul_element(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, void* lhs_data, void* rhs_data,
-    void* acc_data, void* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
+    iree_hal_element_type_t acc_type, bool transpose_rhs, void* lhs_data,
+    void* rhs_data, void* acc_data, void* result_data, iree_hal_dim_t m,
+    iree_hal_dim_t n) {
   if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
       rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
       acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     reference_matmul_float_float_float_float(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const float*)lhs_data, (const float*)rhs_data, (const float*)acc_data,
         (float*)result_data, m, n);
   } else if (iree_hal_element_type_is_integer(lhs_type, 8) &&
              iree_hal_element_type_is_integer(rhs_type, 8) &&
              iree_hal_element_type_is_integer(acc_type, 32)) {
     reference_matmul_int8_t_int8_t_int32_t_int32_t(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const int8_t*)lhs_data, (const int8_t*)rhs_data,
         (const int32_t*)acc_data, (int32_t*)result_data, m, n);
   } else if (iree_hal_element_type_is_integer(lhs_type, 32) &&
              iree_hal_element_type_is_integer(rhs_type, 32) &&
              iree_hal_element_type_is_integer(acc_type, 32)) {
     reference_matmul_int32_t_int32_t_int32_t_int32_t(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const int32_t*)lhs_data, (const int32_t*)rhs_data,
         (const int32_t*)acc_data, (int32_t*)result_data, m, n);
   } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
     reference_matmul_f16_f16_f16_f16(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
         (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
   } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
              acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     reference_matmul_f16_f16_f32_f32(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
         (const float*)acc_data, (float*)result_data, m, n);
   } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
              rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
              acc_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
     reference_matmul_bf16_bf16_bf16_bf16(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
         (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
   } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
              rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
              acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     reference_matmul_bf16_bf16_f32_f32(
-        m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
         (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
         (const float*)acc_data, (float*)result_data, m, n);
   } else {
@@ -361,9 +370,10 @@
 static iree_status_t reference_matmul(
     iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
     iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
-    iree_hal_element_type_t acc_type, iree_byte_span_t lhs_contents,
-    iree_byte_span_t rhs_contents, iree_byte_span_t acc_contents,
-    iree_byte_span_t result_contents, int compute_every) {
+    iree_hal_element_type_t acc_type, bool transpose_rhs,
+    iree_byte_span_t lhs_contents, iree_byte_span_t rhs_contents,
+    iree_byte_span_t acc_contents, iree_byte_span_t result_contents,
+    int compute_every) {
   IREE_TRACE_ZONE_BEGIN(z0);
   IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, m_size);
   IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, k_size);
@@ -375,7 +385,7 @@
       if (++count < compute_every) continue;
       count = 0;
       IREE_RETURN_IF_ERROR(reference_matmul_element(
-          m_size, k_size, n_size, lhs_type, rhs_type, acc_type,
+          m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
           lhs_contents.data, rhs_contents.data, acc_contents.data,
           result_contents.data, m, n));
     }
@@ -398,6 +408,7 @@
   iree_hal_element_type_t rhs_type;
   iree_hal_element_type_t acc_type;
   iree_hal_element_type_t result_type;
+  bool transpose_rhs;
   iree_byte_span_t lhs_contents;
   iree_byte_span_t rhs_contents;
   iree_byte_span_t acc_contents;
@@ -409,7 +420,7 @@
 
 static iree_status_t matmul_results_initialize(
     iree_hal_device_t* device, iree_hal_dim_t m_size, iree_hal_dim_t k_size,
-    iree_hal_dim_t n_size, iree_hal_buffer_view_t* lhs,
+    iree_hal_dim_t n_size, uint32_t transpose_rhs, iree_hal_buffer_view_t* lhs,
     iree_hal_buffer_view_t* rhs, iree_hal_buffer_view_t* acc,
     iree_hal_buffer_view_t* result, iree_allocator_t host_allocator,
     matmul_results_t* out_results) {
@@ -427,6 +438,8 @@
   out_results->acc_type = iree_hal_buffer_view_element_type(result);
   out_results->result_type = iree_hal_buffer_view_element_type(result);
 
+  out_results->transpose_rhs = transpose_rhs != 0;
+
   iree_hal_buffer_t* lhs_buffer = iree_hal_buffer_view_buffer(lhs);
   iree_hal_buffer_t* rhs_buffer = iree_hal_buffer_view_buffer(rhs);
   iree_hal_buffer_t* acc_buffer = acc ? iree_hal_buffer_view_buffer(acc) : NULL;
@@ -776,11 +789,11 @@
   IREE_TRACE_ZONE_BEGIN(z0);
 
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, reference_matmul(results->m, results->k, results->n,
-                           results->lhs_type, results->rhs_type,
-                           results->acc_type, results->lhs_contents,
-                           results->rhs_contents, results->acc_contents,
-                           results->expected_contents, check_every));
+      z0, reference_matmul(
+              results->m, results->k, results->n, results->lhs_type,
+              results->rhs_type, results->acc_type, results->transpose_rhs,
+              results->lhs_contents, results->rhs_contents,
+              results->acc_contents, results->expected_contents, check_every));
 
   int count = 0;
   for (iree_hal_dim_t m = 0; m < results->m; ++m) {
@@ -1009,15 +1022,15 @@
 
   Status CheckMatmulResults(
       const vm::ref<iree_hal_device_t> device, int64_t m, int64_t k, int64_t n,
-      const vm::ref<iree_hal_buffer_view_t> lhs,
+      int32_t transpose_rhs, const vm::ref<iree_hal_buffer_view_t> lhs,
       const vm::ref<iree_hal_buffer_view_t> rhs,
       const vm::ref<iree_hal_buffer_view_t> acc,
       const vm::ref<iree_hal_buffer_view_t> actual_result) {
     matmul_results_t results = {};
     IREE_RETURN_IF_ERROR(matmul_results_initialize(
         device.get(), (iree_hal_dim_t)m, (iree_hal_dim_t)k, (iree_hal_dim_t)n,
-        lhs.get(), rhs.get(), acc.get(), actual_result.get(), host_allocator_,
-        &results));
+        transpose_rhs, lhs.get(), rhs.get(), acc.get(), actual_result.get(),
+        host_allocator_, &results));
     iree_status_t status = check_matmul_results(stderr, &results);
     matmul_results_deinitialize(&results);
     return status;