GPU data tiling: query the target's list of MMA intrinsics. Add FP8 test. (#18862)
The current code maintained its own list of MFMA intrinsics that we could use,
then checked that list against the target. Flipping this around, we can
simply query the list from the target.
The only subtlety is that the target may support multiple intrinsics for
a given combination of element types, in which case we have to choose
one.
This PR also changes `std::optional<Attr>` to just `Attr`: since a
default-constructed `Attr` is already null-ish, there is no need for a
second null value.
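
For illustration only (not part of the diff), here is a minimal sketch of the property this relies on, assuming the usual MLIR attribute semantics: a default-constructed attribute wraps a null storage pointer and converts to `false` in a boolean context, so `{}` can stand in for `std::nullopt`. The function name `nullishAttrSketch` is made up for this example.

```cpp
#include <cassert>

#include "mlir/IR/Attributes.h"

// Sketch: mlir::Attribute (and attributes derived from it, such as
// DataTiledMMAAttr) is a thin wrapper around a storage pointer, so a
// default-constructed instance is "null-ish" and tests as false.
static void nullishAttrSketch() {
  mlir::Attribute attr;  // default-constructed: no storage attached
  assert(!attr);         // evaluates to false, like a std::nullopt check
  if (!attr) {
    // Callers can bail out here, e.g. `return failure();`.
  }
}
```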
The heuristic added in this PR is designed to match the existing choices
so that the tests don't need to change; these existing choices also
maximize performance on some microbenchmarks, but we know that they may be
counterproductive in real scenarios where power is the bottleneck.
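
As a condensed sketch of that selection (mirroring the `chooseIntrinsicMMAAttr` helper in the diff below; the accessor names are the ones used there, and `pickIntrinsicSketch` is a hypothetical name for illustration):

```cpp
// Sketch: walk the target's own MMA list, keep intrinsics whose (A, B, C)
// element types match this matmul, and among the matches prefer the one
// with the largest K.
static IREE::GPU::MMAAttr pickIntrinsicSketch(mlir::TypeRange eTypes,
                                              IREE::GPU::TargetWgpAttr wgp) {
  IREE::GPU::MMAAttr candidate;
  for (IREE::GPU::MMAAttr mma : wgp.getMma()) {
    auto [a, b, c] = mma.getABCElementTypes();
    if (a != eTypes[0] || b != eTypes[1] || c != eTypes[2])
      continue;  // element types don't match this matmul
    if (candidate && candidate.getKSize() > mma.getKSize())
      continue;  // keep the intrinsic with the largest K seen so far
    candidate = mma;
  }
  return candidate;  // null-ish if the target has no matching intrinsic
}
```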
The test gains an FP8 testcase, and some function names are simplified
(the old names had become a lie in some testcases).
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
Co-authored-by: Quinn Dawkins <quinn.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index de80cc9..778cd08 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -41,66 +41,61 @@
#define GEN_PASS_DEF_GPUMATERIALIZEHOSTENCODINGPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
-static bool hasIntrinsic(IREE::GPU::TargetAttr target,
- IREE::GPU::MMAIntrinsic intrinsic) {
- for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
- if (mma.getIntrinsic().getValue() == intrinsic) {
- return true;
+static IREE::GPU::MMAAttr chooseIntrinsicMMAAttr(TypeRange eTypes,
+ IREE::GPU::TargetWgpAttr wgp) {
+ IREE::GPU::MMAAttr candidateMma;
+ for (IREE::GPU::MMAAttr mma : wgp.getMma()) {
+ // Filter out intrinsics that don't match the element types of this matmul.
+ auto [et0, et1, et2] = mma.getABCElementTypes();
+ if (et0 != eTypes[0] || et1 != eTypes[1] || et2 != eTypes[2]) {
+ continue;
}
+ // If multiple intrinsics are available for the given element types, we have
+ // to make a choice. On CDNA3, there may be an intrinsic with larger M/N and
+ // smaller K, which would optimize power, and an intrinsic with larger K,
+ // which would optimize performance when power is not the bottleneck.
+ // Currently we just choose the intrinsic maximizing K, but that can be
+ // revisited later.
+ if (candidateMma && candidateMma.getKSize() > mma.getKSize()) {
+ continue;
+ }
+ candidateMma = mma;
}
- return false;
+ return candidateMma;
}
-static std::optional<IREE::GPU::DataTiledMMAAttr>
+static IREE::GPU::DataTiledMMAAttr
chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
IREE::Encoding::EncodingAttr encoding) {
using namespace IREE::GPU;
if (!target) {
- return std::nullopt;
+ return {};
}
MLIRContext *ctx = target.getContext();
+ IREE::GPU::TargetWgpAttr wgp = target.getWgp();
+ if (!wgp.getMaxLoadInstructionBits() || !wgp.getVgprSpaceBits() ||
+ !wgp.getSimdsPerWgp()) {
+ // Missing workgroup parameters: data tiling not supported on this target.
+ return {};
+ }
//
// Step 1: select a MMAIntrinsic.
//
- const MMAIntrinsic candidateIntrinsics[] = {
- MMAIntrinsic::MFMA_F32_16x16x4_F32,
- MMAIntrinsic::MFMA_F32_16x16x16_F16,
- MMAIntrinsic::MFMA_I32_16x16x32_I8,
- };
- std::optional<MMAIntrinsic> intrinsic;
- for (MMAIntrinsic candidateIntrinsic : candidateIntrinsics) {
- if (!hasIntrinsic(target, candidateIntrinsic)) {
- continue;
- }
- auto [et0, et1, et2] =
- MMAAttr::get(ctx, candidateIntrinsic).getABCElementTypes();
- if (et0 != eTypes[0] || et1 != eTypes[1] || et2 != eTypes[2]) {
- continue;
- }
- intrinsic = candidateIntrinsic;
- break;
- }
- if (!intrinsic) {
- return std::nullopt;
+ MMAAttr intrinsicMma = chooseIntrinsicMMAAttr(eTypes, wgp);
+ if (!intrinsicMma) {
+ return {};
}
//
// Step 2: Select the unrolling factors for the generic case where there is no
// narrow dimension.
//
- IREE::GPU::TargetWgpAttr wgp = target.getWgp();
- if (!wgp.getMaxLoadInstructionBits() || !wgp.getVgprSpaceBits() ||
- !wgp.getSimdsPerWgp()) {
- // Missing workgroup parameters: data tiling not supported on this target.
- return std::nullopt;
- }
auto sizeInBits = [](VectorType type) -> int {
return type.getElementTypeBitWidth() * type.getNumElements();
};
- MMAAttr intrinsicMma = MMAAttr::get(ctx, *intrinsic);
auto [intrinsicA, intrinsicB, intrinsicC] = intrinsicMma.getABCVectorTypes();
// The unrollK factor serves to allow loads from the A and B matrices to use
// the target ISA's vector loads. For instance, if the ISA has 128-bit loads
@@ -111,7 +106,7 @@
if (*wgp.getMaxLoadInstructionBits() % intrinsicLoadBits != 0) {
// Never seen that case: the ISA does not have a suitable load instruction
// to feed that intrinsic?!
- return std::nullopt;
+ return {};
}
const int unrollK = *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits;
@@ -161,7 +156,7 @@
// and `totalUnrollN`, under the constraints:
// 1. totalUnrollM * totalUnrollN <= x * x
// * Reason: by construction of x, any larger area would exceed the
- // wgp.getVgprSpaceBits() budget)
+ // wgp.getVgprSpaceBits() budget.
// 2. totalUnrollM and totalUnrollN are powers of 2.
// * Reason: that is a self-imposed constraint for now to avoid prematurely
// entering excessive fine-tuning of unrolling factors. Also, since below
@@ -243,7 +238,7 @@
} else {
gpuTargetAttr = getCLGPUTarget(tensorType.getContext());
}
- std::optional<IREE::GPU::DataTiledMMAAttr> mma = chooseDataTiledMMAAttr(
+ IREE::GPU::DataTiledMMAAttr mma = chooseDataTiledMMAAttr(
encoding.getElementTypesArray(), gpuTargetAttr, encoding);
if (!mma) {
return failure();
@@ -253,11 +248,11 @@
// based on its operand index in the matmul.
auto rank = tensorType.getRank();
TileMxNxK innerTile;
- std::tie(innerTile.M, innerTile.N, innerTile.K) = mma->getMNKShape();
+ std::tie(innerTile.M, innerTile.N, innerTile.K) = mma.getMNKShape();
auto encodingInfo = getEncodingInfoForMatmul(encoding, rank, innerTile);
auto fragment =
static_cast<IREE::GPU::MMAFragment>(encoding.getOperandIndex().getInt());
- encodingInfo.swizzle = getSwizzle(*mma, fragment);
+ encodingInfo.swizzle = getSwizzle(mma, fragment);
return encodingInfo;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
index 1c3a058..e20fc3b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -323,7 +323,7 @@
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() {
+func.func @matmul_lowering_MFMA_F32_16x16x4_F32() {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -356,7 +356,7 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK: func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32
+// CHECK: func.func @matmul_lowering_MFMA_F32_16x16x4_F32
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
@@ -382,7 +382,7 @@
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
-func.func @batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() {
+func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32() {
%c0 = arith.constant 0 : index
%B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index
%M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index
@@ -416,7 +416,7 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK: func.func @batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32
+// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
@@ -579,7 +579,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8() {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -612,7 +612,7 @@
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
@@ -666,7 +666,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_64} {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_64} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -697,7 +697,7 @@
return
}
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4>
@@ -739,7 +739,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_256} {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_256() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_256} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -770,7 +770,7 @@
return
}
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_256
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
@@ -812,7 +812,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1() attributes {hal.executable.target = #target_gfx942_except_simds_per_wgp_1} {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1() attributes {hal.executable.target = #target_gfx942_except_simds_per_wgp_1} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -843,7 +843,7 @@
return
}
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 8, unroll_k = 2>
@@ -885,7 +885,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_8192} {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_8192} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -916,7 +916,7 @@
return
}
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 4, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
@@ -958,7 +958,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_4096} {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_4096} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -989,7 +989,7 @@
return
}
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 4, unroll_n_to_subgroups = 4, unroll_k = 2>
@@ -1031,7 +1031,7 @@
#hal.pipeline.binding<storage_buffer>
]>
-func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_32768} {
+func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_32768} {
%c0 = arith.constant 0 : index
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
@@ -1062,9 +1062,71 @@
return
}
-// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768
+// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768
// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 4, unroll_n_to_subgroups = 4, unroll_k = 2>
// -----
+
+//---------------------------------------------------------------------------
+// 4. Additional element types, testing only the multi_mma, not set_encoding.
+//---------------------------------------------------------------------------
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f8E4M3FNUZ, f8E4M3FNUZ, f32], user_indexing_maps = [#map, #map1, #map2]>
+#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f8E4M3FNUZ, f8E4M3FNUZ, f32], user_indexing_maps = [#map, #map1, #map2]>
+#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f8E4M3FNUZ, f8E4M3FNUZ, f32], user_indexing_maps = [#map, #map1, #map2]>
+#pipeline_layout_4 = #hal.pipeline.layout<constants = 4, bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() {
+ %c0 = arith.constant 0 : index
+ %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index
+ %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index
+ %N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index
+ %K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xf8E4M3FNUZ, #encoding_lhs>>{%B, %M, %K}
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xf8E4M3FNUZ, #encoding_rhs>>{%B, %K, %N}
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0)
+ : !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xf8E4M3FNUZ, #encoding_lhs>>{%B, %M, %K}
+ -> tensor<?x?x?xf8E4M3FNUZ, #encoding_lhs>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readonly:tensor<?x?x?xf8E4M3FNUZ, #encoding_rhs>>{%B, %K, %N}
+ -> tensor<?x?x?xf8E4M3FNUZ, #encoding_rhs>
+ %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+ : !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+ -> tensor<?x?x?xf32, #encoding_result>
+ %6 = linalg.batch_matmul
+ ins(%3, %4 : tensor<?x?x?xf8E4M3FNUZ, #encoding_lhs>,
+ tensor<?x?x?xf8E4M3FNUZ, #encoding_rhs>)
+ outs(%5 : tensor<?x?x?xf32, #encoding_result>)
+ -> tensor<?x?x?xf32, #encoding_result>
+ flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1]
+ : tensor<?x?x?xf32, #encoding_result>
+ -> !flow.dispatch.tensor<readwrite:tensor<?x?x?xf32, #encoding_result>>{%B, %M, %N}
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
+// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
+// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x?x8x4x16x2x8xf8E4M3FNUZ>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x?x4x2x4x16x2x8xf8E4M3FNUZ>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x?x8x4x2x4x16x4xf32>
+// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
+// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]