[Codegen][GPU] Add pass to expand multi_mma op shapes to intrinsic layout (#18139)

This PR adds a new pass that explicitly materializes the dimensions of
intrinsic layouts for `iree_gpu.multi_mma` ops. It does this by adding a
`tensor.expand_shape` on each operand, taking it from the
`OpaqueMmaLayout` shape to the `ConcreteMmaLayout` shape. This makes it
easy to extract the correct data from the tensors when the multi_mma op
is later distributed to lanes, since the expanded shape matches the
number of offsets and sizes needed for the per-lane slice.
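
For example, for the `MFMA_F32_16x16x16_F16` intrinsic, the pass rewrites the
inputs as shown below (adapted from the lit test added in this PR; the SSA
names are illustrative and the op attributes are elided as `{...}`):

```mlir
// Before: the inner tiles carry the opaque 16x16 intrinsic shape.
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {...}
  : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x16x16xf32>

// After: the inner tiles are expanded to the concrete intrinsic layout.
%lhs_exp = tensor.expand_shape %lhs [[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 4]
    : tensor<2x2x16x16xf16> into tensor<2x2x16x4x4xf16>
%rhs_exp = tensor.expand_shape %rhs [[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16]
    : tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf16>
%1 = iree_gpu.multi_mma %lhs_exp, %rhs_exp, %acc {...}
  : tensor<2x2x16x4x4xf16>, tensor<2x2x4x4x16xf16> into tensor<2x2x16x16xf32>
```

When `concretize-result` is enabled, the accumulator is expanded the same way
and the multi_mma result is collapsed back to the original shape with a
`tensor.collapse_shape` so that types still match with consumers.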

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index ab0d9e1..7ccfccb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -752,6 +753,86 @@
   return success();
 }
 
+LogicalResult MMAAttr::materializeOperandConcreteShape(
+    OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand,
+    std::optional<ArrayRef<int64_t>> permutation,
+    SmallVector<ReassociationIndices> &reassociations,
+    RankedTensorType &resultType) const {
+  OpaqueMmaLayout opaqueLayout =
+      getOpaqueMFMALayout(operand.getContext(), getIntrinsic().getValue());
+  // TODO(Max191): The `getConcreteMFMALayout` function creates some
+  //   `PerDimLayoutAttr` that are not used by this function. This means that
+  //   any pass that uses `materializeOperandConcreteShape` needs to be
+  //   dependent on the VectorExt dialect. Ideally, the `getConcreteMFMALayout`
+  //   function should be refactored so we can reuse the shape information of
+  //   the layout without needing to create any `PerDimLayoutAttr`.
+  ConcreteMmaLayout layout =
+      getConcreteMFMALayout(operand.getContext(), getIntrinsic().getValue());
+  SmallVector<ArrayRef<int64_t>> concreteSizes;
+  SmallVector<int64_t, 2> opaqueSizes;
+  switch (fragment) {
+  case IREE::GPU::MMAFragment::Lhs: {
+    concreteSizes.push_back(layout.aMLayout.getShapes());
+    concreteSizes.push_back(layout.aKLayout.getShapes());
+    opaqueSizes.push_back(opaqueLayout.mSize);
+    opaqueSizes.push_back(opaqueLayout.kSize);
+    break;
+  }
+  case IREE::GPU::MMAFragment::Rhs: {
+    concreteSizes.push_back(layout.bKLayout.getShapes());
+    concreteSizes.push_back(layout.bNLayout.getShapes());
+    opaqueSizes.push_back(opaqueLayout.kSize);
+    opaqueSizes.push_back(opaqueLayout.nSize);
+    break;
+  }
+  case IREE::GPU::MMAFragment::Acc: {
+    concreteSizes.push_back(layout.cMLayout.getShapes());
+    concreteSizes.push_back(layout.cNLayout.getShapes());
+    opaqueSizes.push_back(opaqueLayout.mSize);
+    opaqueSizes.push_back(opaqueLayout.nSize);
+    break;
+  }
+  }
+  if (permutation.has_value()) {
+    if (permutation.value().size() != opaqueSizes.size()) {
+      return failure();
+    }
+    applyPermutationToVector(concreteSizes, permutation.value());
+    applyPermutationToVector(opaqueSizes, permutation.value());
+  }
+
+  // Inner tile must have sizes matching the opaque layout.
+  auto operandType = llvm::cast<RankedTensorType>(operand.getType());
+  ArrayRef<int64_t> operandShape = operandType.getShape();
+  SmallVector<int64_t, 2> innerShape(operandShape.end() - opaqueSizes.size(),
+                                     operandShape.end());
+  if (!llvm::equal(opaqueSizes, innerShape)) {
+    return failure();
+  }
+
+  // Expand the shape of the inner tile to reflect the MMA thread layout.
+  SmallVector<int64_t, 4> resultShape(operandShape.begin(),
+                                      operandShape.end() - 2);
+  SmallVector<ReassociationIndices> reInds =
+      llvm::map_to_vector(llvm::seq<int64_t>(resultShape.size()),
+                          [](int64_t idx) -> ReassociationIndices {
+                            return ReassociationIndices({idx});
+                          });
+  int idx = reInds.size();
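+  // Each opaque inner dimension expands into a group of concrete intrinsic
+  // dimensions; record one reassociation group per opaque dimension.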
+  for (ArrayRef<int64_t> sizes : concreteSizes) {
+    resultShape.append(SmallVector<int64_t>(sizes));
+    reInds.push_back(
+        llvm::to_vector(llvm::seq<int64_t>(idx, idx + sizes.size())));
+    idx += sizes.size();
+  }
+
+  reassociations = reInds;
+  resultType = operandType.clone(resultShape);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MMA Schedule Attributes
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index e421f4e..6809bc6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -131,6 +131,7 @@
     "getSubgroupSize",
     "buildMmaOperation",
     "populateOperandOffsetsSizesStrides",
+    "materializeOperandConcreteShape",
   ]>
 ]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::GPU";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
index d706154..88c345d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.td
@@ -127,6 +127,26 @@
         return failure();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Constructs the reassociation indices and expanded result type
+        needed to materialize the shape of the given operand fragment to
+        its concrete intrinsic layout.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"materializeOperandConcreteShape",
+      /*args=*/(ins
+        "::mlir::OpBuilder&":$builder,
+        "::mlir::iree_compiler::IREE::GPU::MMAFragment":$fragment,
+        "::mlir::Value":$operand,
+        "std::optional<::llvm::ArrayRef<int64_t>>":$permutation,
+        "::llvm::SmallVector<::mlir::SmallVector<int64_t, 2>>&":$reassociations,
+        "::mlir::RankedTensorType&":$result_type
+      ),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return failure();
+      }]
+    >,
   ];
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
index 00af941..84a0303 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel
@@ -51,6 +51,7 @@
 iree_compiler_cc_library(
     name = "GPUTransforms",
     srcs = [
+        "ConcretizeMmaShapes.cpp",
         "DistributeMmaToLanes.cpp",
         "FuseAndHoistParallelLoops.cpp",
         "LowerIREEGPUOps.cpp",
@@ -69,6 +70,7 @@
         ":PassesIncGen",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
index 7d49b1d..e563b04 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt
@@ -45,6 +45,7 @@
     "Passes.h.inc"
     "Transforms.h"
   SRCS
+    "ConcretizeMmaShapes.cpp"
     "DistributeMmaToLanes.cpp"
     "FuseAndHoistParallelLoops.cpp"
     "LowerIREEGPUOps.cpp"
@@ -80,6 +81,7 @@
     MLIRVectorUtils
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
     iree::compiler::Codegen::Transforms
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
new file mode 100644
index 0000000..9910840
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp
@@ -0,0 +1,156 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::GPU {
+
+#define GEN_PASS_DEF_CONCRETIZEMMASHAPESPASS
+#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"
+
+namespace {
+struct ConcretizeMmaShapesPass final
+    : impl::ConcretizeMmaShapesPassBase<ConcretizeMmaShapesPass> {
+  using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+struct ConcretizeMmaOperandShape final : OpRewritePattern<MultiMmaOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ConcretizeMmaOperandShape(MLIRContext *context, MMAFragment fragment)
+      : OpRewritePattern<MultiMmaOp>(context), fragment(fragment) {}
+
+  LogicalResult matchAndRewrite(MultiMmaOp mmaOp,
+                                PatternRewriter &rewriter) const override {
+    if (!mmaOp.hasTensorSemantics()) {
+      return failure();
+    }
+
+    // Get the right operand and permutation for the `fragment`.
+    Value operand;
+    std::optional<ArrayRef<int64_t>> permutation;
+    switch (fragment) {
+    case MMAFragment::Lhs:
+      operand = mmaOp.getLhs();
+      permutation = mmaOp.getLhsPermutation();
+      break;
+    case MMAFragment::Rhs:
+      operand = mmaOp.getRhs();
+      permutation = mmaOp.getRhsPermutation();
+      break;
+    case MMAFragment::Acc:
+      operand = mmaOp.getAcc();
+      permutation = mmaOp.getAccPermutation();
+      break;
+    }
+
+    // Get the reassociation indices and result type of the expand_shape op.
+    MmaInterfaceAttr kind = mmaOp.getKind();
+    SmallVector<ReassociationIndices> reassociations;
+    RankedTensorType concreteType;
+    if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand,
+                                                    permutation, reassociations,
+                                                    concreteType))) {
+      return failure();
+    }
+
+    // Create the expand_shape.
+    Location loc = mmaOp->getLoc();
+    Value concreteOperand = rewriter
+                                .create<tensor::ExpandShapeOp>(
+                                    loc, concreteType, operand, reassociations)
+                                .getResult();
+
+    // Expand the permutation for the new inner dimensions of the expanded
+    // multi_mma operand.
+    auto expandPerm =
+        [&](std::optional<ArrayRef<int64_t>> perm, MMAFragment frag,
+            int64_t outerRank) -> std::optional<DenseI64ArrayAttr> {
+      if (!perm.has_value()) {
+        return std::nullopt;
+      }
+      if (frag != fragment) {
+        return rewriter.getDenseI64ArrayAttr(perm.value());
+      }
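+      // Re-index the inner reassociation groups so they are relative to the
+      // inner tile, then permute the groups and flatten into the new perm.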
+      SmallVector<ReassociationIndices> innerReInds(
+          reassociations.begin() + outerRank, reassociations.end());
+      for (auto &reInd : innerReInds) {
+        for (auto &idx : reInd) {
+          idx -= outerRank;
+        }
+      }
+      SmallVector<int64_t> expandedPerm;
+      for (auto reInd : applyPermutation(innerReInds, perm.value())) {
+        expandedPerm.append(reInd);
+      }
+      return rewriter.getDenseI64ArrayAttr(expandedPerm);
+    };
+    std::optional<DenseI64ArrayAttr> lhsPerm = expandPerm(
+        mmaOp.getLhsPermutation(), MMAFragment::Lhs, mmaOp.getLhsOuterRank());
+    std::optional<DenseI64ArrayAttr> rhsPerm = expandPerm(
+        mmaOp.getRhsPermutation(), MMAFragment::Rhs, mmaOp.getRhsOuterRank());
+    std::optional<DenseI64ArrayAttr> accPerm = expandPerm(
+        mmaOp.getAccPermutation(), MMAFragment::Acc, mmaOp.getAccOuterRank());
+
+    // Create the new multi_mma op with the concrete type.
+    auto concreteMmaOp = rewriter.create<MultiMmaOp>(
+        loc,
+        /*lhs=*/fragment == MMAFragment::Lhs ? concreteOperand : mmaOp.getLhs(),
+        /*rhs=*/fragment == MMAFragment::Rhs ? concreteOperand : mmaOp.getRhs(),
+        /*acc=*/fragment == MMAFragment::Acc ? concreteOperand : mmaOp.getAcc(),
+        mmaOp.getIndexingMaps(), mmaOp.getIteratorTypes(), mmaOp.getKind(),
+        lhsPerm, rhsPerm, accPerm);
+
+    if (auto config = getLoweringConfig(mmaOp)) {
+      setLoweringConfig(concreteMmaOp, config);
+    }
+
+    if (fragment != MMAFragment::Acc) {
+      rewriter.replaceOp(mmaOp, concreteMmaOp);
+      return success();
+    }
+
+    // For the Acc operand, the result needs to be collapsed back to the
+    // original type so that types match with consumers.
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        mmaOp, mmaOp.getAccType(), concreteMmaOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  MMAFragment fragment;
+};
+
+void ConcretizeMmaShapesPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  auto funcOp = getOperation();
+
+  RewritePatternSet patterns(context);
+  if (concretizeInputs) {
+    patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Lhs);
+    patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Rhs);
+  }
+  if (concretizeResult) {
+    patterns.insert<ConcretizeMmaOperandShape>(context, MMAFragment::Acc);
+  }
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
index 9f64563..6cc7f11 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -20,6 +20,24 @@
   ];
 }
 
+def ConcretizeMmaShapesPass :
+    InterfacePass<"iree-gpu-concretize-mma-shapes", "mlir::FunctionOpInterface"> {
+  let summary = "Expands the inner dimensions of iree_gpu.multi_mma ops to match the thread layout";
+  let dependentDialects = [
+    "::mlir::tensor::TensorDialect",
+    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
+    "::mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect",
+  ];
+  let options = [
+    Option<"concretizeInputs", "concretize-inputs",
+      "bool", /*default=*/"true",
+      "Expand the inner dimensions for the lhs and rhs operands of the multi_mma ops.">,
+    Option<"concretizeResult", "concretize-result",
+      "bool", /*default=*/"true",
+      "Expand the inner dimensions for the acc operand of the multi_mma ops.">,
+  ];
+}
+
 def FuseAndHoistParallelLoopsPass :
     InterfacePass<"iree-gpu-fuse-and-hoist-parallel-loops", "mlir::FunctionOpInterface"> {
   let summary = "Greedily fuses and hoists parallel loops.";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
index 5e9c0e7..8348d9f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel
@@ -18,6 +18,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "concretize_mma_shapes.mlir",
             "distribute_mma_to_lanes.mlir",
             "fuse_and_hoist_forall.mlir",
             "pack_to_intrinsics.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
index ef55e3d..a71fd9c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "concretize_mma_shapes.mlir"
     "distribute_mma_to_lanes.mlir"
     "fuse_and_hoist_forall.mlir"
     "pack_to_intrinsics.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
new file mode 100644
index 0000000..990bfea
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/concretize_mma_shapes.mlir
@@ -0,0 +1,110 @@
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-result=false}, canonicalize, cse))' --split-input-file | FileCheck %s -check-prefixes=CHECK,CHECK-INPUTS
+// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-concretize-mma-shapes{concretize-inputs=false}, canonicalize, cse))' --split-input-file | FileCheck %s -check-prefixes=CHECK,CHECK-RESULT
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_F32_16x16x16_F16(%lhs: tensor<2x2x16x16xf16>, %rhs: tensor<2x2x16x16xf16>, %acc: tensor<2x2x16x16xf32>) -> tensor<2x2x16x16xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, lowering_config = #config
+  } : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x16x16xf32>
+  return %0 : tensor<2x2x16x16xf32>
+}
+
+// CHECK-LABEL:       func @concretize_multi_mma_F32_16x16x16_F16
+// CHECK-SAME:          %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>
+// CHECK-SAME:          %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>
+// CHECK-SAME:          %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xf32>
+
+// CHECK-INPUTS-DAG:    %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 4] : tensor<2x2x16x16xf16> into tensor<2x2x16x4x4xf16>
+// CHECK-INPUTS-DAG:    %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf16>
+// CHECK-INPUTS:        %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME:     : tensor<2x2x16x4x4xf16>, tensor<2x2x4x4x16xf16> into tensor<2x2x16x16xf32>
+// CHECK-INPUTS:        return %[[MMA]]
+
+// CHECK-RESULT-DAG:    %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xf32> into tensor<2x2x4x4x16xf32>
+// CHECK-RESULT:        %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME:     : tensor<2x2x16x16xf16>, tensor<2x2x16x16xf16> into tensor<2x2x4x4x16xf32>
+// CHECK-RESULT:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3], [4]] : tensor<2x2x4x4x16xf32> into tensor<2x2x16x16xf32>
+// CHECK-RESULT:        return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_I32_16x16x32_I8(%lhs: tensor<2x2x16x32xi8>, %rhs: tensor<2x2x16x32xi8>, %acc: tensor<2x2x16x16xi32>) -> tensor<2x2x16x16xi32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>,
+    rhs_permutation = array<i64: 1, 0>, lowering_config = #config
+  } : tensor<2x2x16x32xi8>, tensor<2x2x16x32xi8> into tensor<2x2x16x16xi32>
+  return %0 : tensor<2x2x16x16xi32>
+}
+
+// CHECK-LABEL:       func @concretize_multi_mma_I32_16x16x32_I8
+// CHECK-SAME:          %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>
+// CHECK-SAME:          %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x32xi8>
+// CHECK-SAME:          %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xi32>
+
+// CHECK-INPUTS-DAG:    %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 8] : tensor<2x2x16x32xi8> into tensor<2x2x16x4x8xi8>
+// CHECK-INPUTS-DAG:    %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 16, 4, 8] : tensor<2x2x16x32xi8> into tensor<2x2x16x4x8xi8>
+// CHECK-INPUTS:        %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME:     rhs_permutation = array<i64: 1, 2, 0>
+// CHECK-INPUTS-SAME:     : tensor<2x2x16x4x8xi8>, tensor<2x2x16x4x8xi8> into tensor<2x2x16x16xi32>
+// CHECK-INPUTS:        return %[[MMA]]
+
+// CHECK-RESULT-DAG:    %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 4, 4, 16] : tensor<2x2x16x16xi32> into tensor<2x2x4x4x16xi32>
+// CHECK-RESULT:        %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME:     : tensor<2x2x16x32xi8>, tensor<2x2x16x32xi8> into tensor<2x2x4x4x16xi32>
+// CHECK-RESULT:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3], [4]] : tensor<2x2x4x4x16xi32> into tensor<2x2x16x16xi32>
+// CHECK-RESULT:        return %[[COLLAPSED]]
+
+// -----
+
+#contraction_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4], thread = [8, 4]}>
+func.func @concretize_multi_mma_F32_32x32x8_F16(%lhs: tensor<2x2x32x8xf16>, %rhs: tensor<2x2x8x32xf16>, %acc: tensor<2x2x32x32xf32>) -> tensor<2x2x32x32xf32> {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, lowering_config = #config
+  } : tensor<2x2x32x8xf16>, tensor<2x2x8x32xf16> into tensor<2x2x32x32xf32>
+  return %0 : tensor<2x2x32x32xf32>
+}
+
+// CHECK-LABEL:       func @concretize_multi_mma_F32_32x32x8_F16
+// CHECK-SAME:          %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x32x8xf16>
+// CHECK-SAME:          %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x8x32xf16>
+// CHECK-SAME:          %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x32x32xf32>
+
+// CHECK-INPUTS-DAG:    %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0], [1], [2], [3, 4]] output_shape [2, 2, 32, 2, 4] : tensor<2x2x32x8xf16> into tensor<2x2x32x2x4xf16>
+// CHECK-INPUTS-DAG:    %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[RHS]] {{\[}}[0], [1], [2, 3], [4]] output_shape [2, 2, 2, 4, 32] : tensor<2x2x8x32xf16> into tensor<2x2x2x4x32xf16>
+// CHECK-INPUTS:        %[[MMA:.+]] = iree_gpu.multi_mma %[[EXPANDED_LHS]], %[[EXPANDED_RHS]], %[[ACC]]
+// CHECK-INPUTS-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-INPUTS-SAME:     : tensor<2x2x32x2x4xf16>, tensor<2x2x2x4x32xf16> into tensor<2x2x32x32xf32>
+// CHECK-INPUTS:        return %[[MMA]]
+
+// CHECK-RESULT-DAG:    %[[EXPANDED_ACC:.+]] = tensor.expand_shape %[[ACC]] {{\[}}[0], [1], [2, 3, 4], [5]] output_shape [2, 2, 4, 2, 4, 32] : tensor<2x2x32x32xf32> into tensor<2x2x4x2x4x32xf32>
+// CHECK-RESULT:        %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[EXPANDED_ACC]]
+// CHECK-RESULT-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-RESULT-SAME:     : tensor<2x2x32x8xf16>, tensor<2x2x8x32xf16> into tensor<2x2x4x2x4x32xf32>
+// CHECK-RESULT:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMA]] {{\[}}[0], [1], [2, 3, 4], [5]] : tensor<2x2x4x2x4x32xf32> into tensor<2x2x32x32xf32>
+// CHECK-RESULT:        return %[[COLLAPSED]]