[Codegen][GPU] Add pass to expand multi_mma op shapes to intrinsic layout (#18139)
This PR adds a new pass to explicitly materialize the dimensions of
intrinsic layouts for `iree_gpu.multi_mma` ops. The pass adds a
`tensor.expand_shape` on each operand to go from the `OpaqueMmaLayout`
shape to the `ConcreteMmaLayout` shape; for the accumulator, the result
is collapsed back to the original type so consumer types are unchanged.
This makes it easy to extract the correct data from the tensors when the
multi_mma op is later distributed to lanes, since the shape will match
the number of offsets and sizes needed for the slice.
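
For example, for the `MFMA_F32_16x16x16_F16` intrinsic (abridged from the
new lit test below), concretizing the inputs rewrites

```mlir
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {...}
  : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x16x16xf32>
```

into

```mlir
%expanded_lhs = tensor.expand_shape %lhs [[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 4]
    : tensor<2x2x16x16xf16> into tensor<2x2x16x4x4xf16>
%expanded_rhs = tensor.expand_shape %rhs [[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16]
    : tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf16>
%0 = iree_gpu.multi_mma %expanded_lhs, %expanded_rhs, %acc {...}
  : tensor<2x2x16x4x4xf16>, tensor<2x2x4x4x16xf16> into tensor<2x2x16x16xf32>
```

When concretizing the accumulator, the expanded result is collapsed back
with a `tensor.collapse_shape` so the op's consumers see the original type.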
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index ab0d9e1..7ccfccb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -752,6 +753,84 @@
return success();
}
+LogicalResult MMAAttr::materializeOperandConcreteShape(
+ OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand,
+ std::optional<ArrayRef<int64_t>> permutation,
+ SmallVector<ReassociationIndices> &reassociations,
+ RankedTensorType &resultType) const {
+ OpaqueMmaLayout opaqueLayout =
+ getOpaqueMFMALayout(operand.getContext(), getIntrinsic().getValue());
+ // TODO(Max191): The `getConcreteMFMALayout` function creates some
+ // `PerDimLayoutAttr` that are not used by this function. This means that
+ // any pass that uses `materializeOperandConcreteShape` needs to be
+ // dependent on the VectorExt dialect. Ideally, the `getConcreteMFMALayout`
+ // function should be refactored so we can reuse the shape information of
+ // the layout without needing to create any `PerDimLayoutAttr`.
+ ConcreteMmaLayout layout =
+ getConcreteMFMALayout(operand.getContext(), getIntrinsic().getValue());
+ SmallVector<ArrayRef<int64_t>> concreteSizes;
+ SmallVector<int64_t, 2> opaqueSizes;
+ switch (fragment) {
+ case IREE::GPU::MMAFragment::Lhs: {
+ concreteSizes.push_back(layout.aMLayout.getShapes());
+ concreteSizes.push_back(layout.aKLayout.getShapes());
+ opaqueSizes.push_back(opaqueLayout.mSize);
+ opaqueSizes.push_back(opaqueLayout.kSize);
+ break;
+ }
+ case IREE::GPU::MMAFragment::Rhs: {
+ concreteSizes.push_back(layout.bKLayout.getShapes());
+ concreteSizes.push_back(layout.bNLayout.getShapes());
+ opaqueSizes.push_back(opaqueLayout.kSize);
+ opaqueSizes.push_back(opaqueLayout.nSize);
+ break;
+ }
+ case IREE::GPU::MMAFragment::Acc: {
+ concreteSizes.push_back(layout.cMLayout.getShapes());
+ concreteSizes.push_back(layout.cNLayout.getShapes());
+ opaqueSizes.push_back(opaqueLayout.mSize);
+ opaqueSizes.push_back(opaqueLayout.nSize);
+ break;
+ }
+ }
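+  // Apply the operand permutation, if any, so that both the concrete and
+  // opaque sizes match the actual order of the operand's inner dimensions.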
+ if (permutation.has_value()) {
+ if (permutation.value().size() != opaqueSizes.size()) {
+ return failure();
+ }
+ applyPermutationToVector(concreteSizes, permutation.value());
+ applyPermutationToVector(opaqueSizes, permutation.value());
+ }
+
+ // Inner tile must have sizes matching the opaque layout.
+ auto operandType = llvm::cast<RankedTensorType>(operand.getType());
+ ArrayRef<int64_t> operandShape = operandType.getShape();
+ SmallVector<int64_t, 2> innerShape(operandShape.end() - opaqueSizes.size(),
+ operandShape.end());
+ if (!llvm::equal(opaqueSizes, innerShape)) {
+ return failure();
+ }
+
+ // Expand the shape of the inner tile to reflect the MMA thread layout.
+  SmallVector<int64_t, 4> resultShape(
+      operandShape.begin(), operandShape.end() - opaqueSizes.size());
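+  // Outer (unrolled) dimensions are kept as-is with identity reassociation
+  // groups; only the inner intrinsic tile dimensions are expanded.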
+ SmallVector<ReassociationIndices> reInds =
+ llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
+ [](int64_t idx) -> ReassociationIndices {
+ return ReassociationIndices({idx});
+ });
+ int idx = reInds.size();
+ for (ArrayRef<int64_t> sizes : concreteSizes) {
+ resultShape.append(SmallVector<int64_t>(sizes));
+ reInds.push_back(
+ llvm::to_vector(llvm::seq<int64_t>(idx, idx + sizes.size())));
+ idx += sizes.size();
+ }
+
+ reassociations = reInds;
+ resultType = operandType.clone(resultShape);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// MMA Schedule Attributes
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index e421f4e..6809bc6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -131,6 +131,7 @@
"getSubgroupSize",
"buildMmaOperation",
"populateOperandOffsetsSizesStrides",
+ "materializeOperandConcreteShape",
]>
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
index d706154..88c345d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
@@ -127,6 +127,26 @@
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+      Materializes the reassociation indices and expanded result type
+      needed to expand the given operand fragment from its opaque MMA
+      layout shape to its concrete intrinsic layout shape.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"materializeOperandConcreteShape",
+ /*args=*/(ins
+ "::mlir::OpBuilder&":$builder,
+ "::mlir::iree_compiler::IREE::GPU::MMAFragment":$fragment,
+ "::mlir::Value":$operand,
+ "std::optional<::llvm::ArrayRef<int64_t>>":$permutation,
+ "::llvm::SmallVector<::mlir::SmallVector<int64_t, 2>>&":$reassociations,
+ "::mlir::RankedTensorType&":$result_type
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
];
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index 00af941..84a0303 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -51,6 +51,7 @@
iree_compiler_cc_library(
name = "GPUTransforms",
srcs = [
+ "ConcretizeMmaShapes.cpp",
"DistributeMmaToLanes.cpp",
"FuseAndHoistParallelLoops.cpp",
"LowerIREEGPUOps.cpp",
@@ -69,6 +70,7 @@
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 7d49b1d..e563b04 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -45,6 +45,7 @@
"Passes.h.inc"
"Transforms.h"
SRCS
+ "ConcretizeMmaShapes.cpp"
"DistributeMmaToLanes.cpp"
"FuseAndHoistParallelLoops.cpp"
"LowerIREEGPUOps.cpp"
@@ -80,6 +81,7 @@
MLIRVectorUtils
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
iree::compiler::Codegen::Transforms
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
new file mode 100644
index 0000000..9910840
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
@@ -0,0 +1,154 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_CONCRETIZEMMASHAPESPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct ConcretizeMmaShapesPass final
+ : impl::ConcretizeMmaShapesPassBase<ConcretizeMmaShapesPass> {
+ using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase;
+ void runOnOperation() override;
+};
+} // namespace
+
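+/// Pattern that expands one operand (Lhs, Rhs, or Acc) of a multi_mma op
+/// from its opaque intrinsic shape to the concrete intrinsic layout shape,
+/// updating the op's operand permutations to match.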
+struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
+
+ ConcretizeMmaOperandShape(MLIRContext *context, MMAFragment fragment)
+ : OpRewritePattern<MultiMmaOp>(context), fragment(fragment) {}
+
+ LogicalResult matchAndRewrite(MultiMmaOp mmaOp,
+ PatternRewriter &rewriter) const override {
+ if (!mmaOp.hasTensorSemantics()) {
+ return failure();
+ }
+
+ // Get the right operand and permutation for the `fragment`.
+ Value operand;
+ std::optional<ArrayRef<int64_t>> permutation;
+ switch (fragment) {
+ case MMAFragment::Lhs:
+ operand = mmaOp.getLhs();
+ permutation = mmaOp.getLhsPermutation();
+ break;
+ case MMAFragment::Rhs:
+ operand = mmaOp.getRhs();
+ permutation = mmaOp.getRhsPermutation();
+ break;
+ case MMAFragment::Acc:
+ operand = mmaOp.getAcc();
+ permutation = mmaOp.getAccPermutation();
+ break;
+ }
+
+ // Get the reassociation indices and result type of the expand_shape op.
+ MmaInterfaceAttr kind = mmaOp.getKind();
+ SmallVector<ReassociationIndices> reassociations;
+ RankedTensorType concreteType;
+ if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand,
+ permutation, reassociations,
+ concreteType))) {
+ return failure();
+ }
+
+ // Create the expand_shape.
+ Location loc = mmaOp->getLoc();
+ Value concreteOperand = rewriter
+ .create<tensor::ExpandShapeOp>(
+ loc, concreteType, operand, reassociations)
+ .getResult();
+
+ // Expand the permutation for the new inner dimensions of the expanded
+ // multi_mma operand.
+ auto expandPerm =
+ [&](std::optional<ArrayRef<int64_t>> perm, MMAFragment frag,
+ int64_t outerRank) -> std::optional<DenseI64ArrayAttr> {
+ if (!perm.has_value()) {
+ return std::nullopt;
+ }
+ if (frag != fragment) {
+ return rewriter.getDenseI64ArrayAttr(perm.value());
+ }
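+    // Rebase the reassociation groups of the inner (intrinsic) dimensions
+    // so their indices start at zero.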
+ SmallVector<ReassociationIndices> innerReInds(
+ reassociations.begin() + outerRank, reassociations.end());
+ for (auto &reInd : innerReInds) {
+ for (auto &idx : reInd) {
+ idx -= outerRank;
+ }
+ }
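+    // Permute the groups according to the original inner-dim permutation,
+    // then flatten them to form the expanded permutation.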
+ SmallVector<int64_t> expandedPerm;
+ for (auto reInd : applyPermutation(innerReInds, perm.value())) {
+ expandedPerm.append(reInd);
+ }
+ return rewriter.getDenseI64ArrayAttr(expandedPerm);
+ };
+ std::optional<DenseI64ArrayAttr> lhsPerm = expandPerm(
+ mmaOp.getLhsPermutation(), MMAFragment::Lhs, mmaOp.getLhsOuterRank());
+ std::optional<DenseI64ArrayAttr> rhsPerm = expandPerm(
+ mmaOp.getRhsPermutation(), MMAFragment::Rhs, mmaOp.getRhsOuterRank());
+ std::optional<DenseI64ArrayAttr> accPerm = expandPerm(
+ mmaOp.getAccPermutation(), MMAFragment::Acc, mmaOp.getAccOuterRank());
+
+ // Create the new multi_mma op with the concrete type.
+ auto concreteMmaOp = rewriter.create<MultiMmaOp>(
+ loc,
+ /*lhs=*/fragment == MMAFragment::Lhs ? concreteOperand : mmaOp.getLhs(),
+ /*rhs=*/fragment == MMAFragment::Rhs ? concreteOperand : mmaOp.getRhs(),
+ /*acc=*/fragment == MMAFragment::Acc ? concreteOperand : mmaOp.getAcc(),
+ mmaOp.getIndexingMaps(), mmaOp.getIteratorTypes(), mmaOp.getKind(),
+ lhsPerm, rhsPerm, accPerm);
+
+ if (auto config = getLoweringConfig(mmaOp)) {
+ setLoweringConfig(concreteMmaOp, config);
+ }
+
+ if (fragment != MMAFragment::Acc) {
+ rewriter.replaceOp(mmaOp, concreteMmaOp);
+ return success();
+ }
+
+ // For the Acc operand, the result needs to be collapsed back to the
+ // original type so that types match with consumers.
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ mmaOp, mmaOp.getAccType(), concreteMmaOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ MMAFragment fragment;
+};
+
+void ConcretizeMmaShapesPass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ auto funcOp = getOperation();
+
+ RewritePatternSet patterns(context);
+ if (concretizeInputs) {
+ patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Lhs);
+ patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Rhs);
+ }
+ if (concretizeResult) {
+ patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Acc);
+ }
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
index 9f64563..6cc7f11 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -20,6 +20,24 @@
];
}
+def ConcretizeMmaShapesPass :
+ InterfacePass<"iree-gpu-concretize-mma-shapes", "mlir::FunctionOpInterface"> {
+ let summary = "Expands the inner dimensions of iree_gpu.multi_mma ops to match the thread layout";
+ let dependentDialects = [
+ "::mlir::tensor::TensorDialect",
+ "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
+ "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect",
+ ];
+ let options = [
+ Option<"concretizeInputs", "concretize-inputs",
+ "bool", /*default=*/"true",
+ "Expand the inner dimensions for the lhs and rhs operands of the multi_mma ops.">,
+ Option<"concretizeResult", "concretize-result",
+ "bool", /*default=*/"true",
+ "Expand the inner dimensions for the acc operand of the multi_mma ops.">,
+ ];
+}
+
def FuseAndHoistParallelLoopsPass :
InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
let summary = "Greedily fuses and hoists parallel loops.";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
index 5e9c0e7..8348d9f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -18,6 +18,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "concretize_mma_shapes.mlir",
"distribute_mma_to_lanes.mlir",
"fuse_and_hoist_forall.mlir",
"pack_to_intrinsics.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
index ef55e3d..a71fd9c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "concretize_mma_shapes.mlir"
"distribute_mma_to_lanes.mlir"
"fuse_and_hoist_forall.mlir"
"pack_to_intrinsics.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
new file mode 100644
index 0000000..990bfea
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
@@ -0,0 +1,110 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-result=false}, canonicalize, cse))' --split-input-file | FileCheck %s -check-prefixes=CHECK,CHECK-INPUTS
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-inputs=false}, canonicalize, cse))' --split-input-file | FileCheck %s -check-prefixes=CHECK,CHECK-RESULT
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_F32_16x16x16_F16(%lhs: tensor<2x2x16x16xf16>, %rhs: tensor<2x2x16x16xf16>, %acc: tensor<2x2x16x16xf32>) -> tensor<2x2x16x16xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, lowering_config = #config
+ } : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x16x16xf32>
+ return %0 : tensor<2x2x16x16xf32>
+}
+
+// CHECK-LABEL: func @concretize_multi_mma_F32_16x16x16_F16
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>
+// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xf32>
+
+// CHECK-INPUTS-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 4] : tensor<2x2x16x16xf16> into tensor<2x2x16x4x4xf16>
+// CHECK-INPUTS-DAG: %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf16>
+// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME: : tensor<2x2x16x4x4xf16>, tensor<2x2x4x4x16xf16> into tensor<2x2x16x16xf32>
+// CHECK-INPUTS: return %[[MMA]]
+
+// CHECK-RESULT-DAG: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xf32> into tensor<2x2x4x4x16xf32>
+// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME: : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf32>
+// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3], [4]] : tensor<2x2x4x4x16xf32> into tensor<2x2x16x16xf32>
+// CHECK-RESULT: return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_I32_16x16x32_I8(%lhs: tensor<2x2x16x32xi8>, %rhs: tensor<2x2x16x32xi8>, %acc: tensor<2x2x16x16xi32>) -> tensor<2x2x16x16xi32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+ rhs_permutation = array<i64: 1, 0>, lowering_config = #config
+ } : tensor<2x2x16x32xi8>, tensor<2x2x16x32xi8> into tensor<2x2x16x16xi32>
+ return %0 : tensor<2x2x16x16xi32>
+}
+
+// CHECK-LABEL: func @concretize_multi_mma_I32_16x16x32_I8
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>
+// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xi32>
+
+// CHECK-INPUTS-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 8] : tensor<2x2x16x32xi8> into tensor<2x2x16x4x8xi8>
+// CHECK-INPUTS-DAG: %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 8] : tensor<2x2x16x32xi8> into tensor<2x2x16x4x8xi8>
+// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME: rhs_permutation = array<i64: 1, 2, 0>
+// CHECK-INPUTS-SAME: : tensor<2x2x16x4x8xi8>, tensor<2x2x16x4x8xi8> into tensor<2x2x16x16xi32>
+// CHECK-INPUTS: return %[[MMA]]
+
+// CHECK-RESULT-DAG: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xi32> into tensor<2x2x4x4x16xi32>
+// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME: : tensor<2x2x16x32xi8>, tensor<2x2x16x32xi8> into tensor<2x2x4x4x16xi32>
+// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3], [4]] : tensor<2x2x4x4x16xi32> into tensor<2x2x16x16xi32>
+// CHECK-RESULT: return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_F32_32x32x8_F16(%lhs: tensor<2x2x32x8xf16>, %rhs: tensor<2x2x8x32xf16>, %acc: tensor<2x2x32x32xf32>) -> tensor<2x2x32x32xf32> {
+ %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+ indexing_maps = #contraction_accesses,
+ iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+ kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, lowering_config = #config
+ } : tensor<2x2x32x8xf16>, tensor<2x2x8x32xf16> into tensor<2x2x32x32xf32>
+ return %0 : tensor<2x2x32x32xf32>
+}
+
+// CHECK-LABEL: func @concretize_multi_mma_F32_32x32x8_F16
+// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x32x8xf16>
+// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x8x32xf16>
+// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x32x32xf32>
+
+// CHECK-INPUTS-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 32, 2, 4] : tensor<2x2x32x8xf16> into tensor<2x2x32x2x4xf16>
+// CHECK-INPUTS-DAG: %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 2, 4, 32] : tensor<2x2x8x32xf16> into tensor<2x2x2x4x32xf16>
+// CHECK-INPUTS: %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME: : tensor<2x2x32x2x4xf16>, tensor<2x2x2x4x32xf16> into tensor<2x2x32x32xf32>
+// CHECK-INPUTS: return %[[MMA]]
+
+// CHECK-RESULT-DAG: %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3, 4], [5]] output_shape [2, 2, 4, 2, 4, 32] : tensor<2x2x32x32xf32> into tensor<2x2x4x2x4x32xf32>
+// CHECK-RESULT: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME: lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME: : tensor<2x2x32x8xf16>, tensor<2x2x8x32xf16> into tensor<2x2x4x2x4x32xf32>
+// CHECK-RESULT: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3, 4], [5]] : tensor<2x2x4x2x4x32xf32> into tensor<2x2x32x32xf32>
+// CHECK-RESULT: return %[[COLLAPSED]]