[rocdl] Allow upcasting accumulator to use matrix core (#16527)

This commit adds a pass that upcasts the vector.contract accumulator so it
can target matrix core instructions, and changes the kernel configuration to
enable it.
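
For example, with f16 inputs, an f16 accumulator, and an MFMA intrinsic whose
accumulator type is f32 (such as F16_16x16x16_F32), the new pass extends the
accumulator before the contraction and truncates the result back afterwards.
A minimal sketch mirroring the new cast_type_to_fit_mma.mlir test, with the
indexing maps and iterator types elided (they are carried over unchanged):

  %0 = vector.contract {...} %lhs, %rhs, %acc
         : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>

becomes

  %ext = arith.extf %acc : vector<96x64xf16> to vector<96x64xf32>
  %mm  = vector.contract {...} %lhs, %rhs, %ext
           : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
  %0   = arith.truncf %mm : vector<96x64xf32> to vector<96x64xf16>
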
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index b09a817..7cecbc9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/BuiltinTypes.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-heuristics"
 
@@ -19,12 +20,20 @@
 std::optional<GPUMMASchedule>
 deduceMMASchedule(const GPUMatmulShapeType &problem,
                   ArrayRef<GPUMatmulShapeType> intrinsics,
-                  const GPUMMAHeuristicSeeds &seeds) {
+                  const GPUMMAHeuristicSeeds &seeds, bool canUpcastAcc) {
   for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
-    if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType ||
-        problem.cType != intrinsic.cType) {
+    if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
       continue; // Cannot use this intrinsic for mismatched types
     }
+    if (problem.cType != intrinsic.cType) {
+      auto isFpCase =
+          isa<FloatType>(problem.cType) && isa<FloatType>(intrinsic.cType);
+      auto isUpcast = problem.cType.getIntOrFloatBitWidth() <
+                      intrinsic.cType.getIntOrFloatBitWidth();
+      if (!(canUpcastAcc && isFpCase && isUpcast)) {
+        continue; // Cannot use this intrinsic if not upcasting
+      }
+    }
 
     if (problem.mSize % intrinsic.mSize != 0 ||
         problem.nSize % intrinsic.nSize != 0 ||
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index 6bf4ff8..f70a428 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -49,6 +49,6 @@
 std::optional<GPUMMASchedule>
 deduceMMASchedule(const GPUMatmulShapeType &problem,
                   ArrayRef<GPUMatmulShapeType> intrinsics,
-                  const GPUMMAHeuristicSeeds &seeds);
+                  const GPUMMAHeuristicSeeds &seeds, bool canUpcastAcc = false);
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 447a592..f739255 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -94,6 +94,7 @@
         "ExtractAddressComputationGPUPass.cpp",
         "KernelConfig.cpp",
         "LLVMGPUCastAddressSpaceFunction.cpp",
+        "LLVMGPUCastTypeToFitMMA.cpp",
         "LLVMGPULowerExecutableTarget.cpp",
         "LLVMGPUPackSharedMemoryAlloc.cpp",
         "LLVMGPUSelectLoweringStrategy.cpp",
@@ -133,6 +134,7 @@
         "//compiler/src/iree/compiler/Codegen/TransformStrategies/GPU",
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
+        "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 517b554..3a6fe83 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -79,6 +79,7 @@
     "ExtractAddressComputationGPUPass.cpp"
     "KernelConfig.cpp"
     "LLVMGPUCastAddressSpaceFunction.cpp"
+    "LLVMGPUCastTypeToFitMMA.cpp"
     "LLVMGPULowerExecutableTarget.cpp"
     "LLVMGPUPackSharedMemoryAlloc.cpp"
     "LLVMGPUSelectLoweringStrategy.cpp"
@@ -174,6 +175,7 @@
     iree::compiler::Codegen::TransformStrategies::GPU
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
+    iree::compiler::Codegen::Utils::VectorOpUtils
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::LinalgExt::IR
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 04328b8..bb1ace1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -347,9 +347,15 @@
                              /*bestMNTileCountPerSubgroup=*/8,
                              /*bestKTileCountPerSubgroup=*/2};
 
+  // First try to find a schedule with an exactly matching intrinsic.
   std::optional<GPUMMASchedule> schedule =
       deduceMMASchedule(problem, intrinsics, seeds);
   if (!schedule) {
+    // Then try again, allowing the accumulator to be upcast.
+    schedule =
+        deduceMMASchedule(problem, intrinsics, seeds, /*canUpcastAcc=*/true);
+  }
+  if (!schedule) {
     return failure();
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
new file mode 100644
index 0000000..8ab4ef4
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -0,0 +1,128 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
+  UpcastContractOutput(MLIRContext *context, IREE::GPU::MmaAttr intrinsic,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), intrinsic(intrinsic) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorContractOpInfo opInfo(contractOp);
+    if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
+      return rewriter.notifyMatchFailure(contractOp, "unhandled contract kind");
+    }
+
+    auto srcCType = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!srcCType) {
+      return rewriter.notifyMatchFailure(contractOp, "unhandled scalar case");
+    }
+    auto srcAType = contractOp.getLhsType();
+    auto srcBType = contractOp.getRhsType();
+
+    auto [dstAElemType, dstBElemType, dstCElemType] =
+        intrinsic.getABCElementTypes();
+    auto [dstM, dstN, dstK] = intrinsic.getMNKShape();
+
+    auto srcCElemFType = dyn_cast<FloatType>(srcCType.getElementType());
+    auto dstCElemFType = dyn_cast<FloatType>(dstCElemType);
+    if (!srcCElemFType || !dstCElemFType ||
+        srcCElemFType.getWidth() >= dstCElemFType.getWidth()) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "unhandled non-floating point or non-upcasting case");
+    }
+
+    if (srcAType.getElementType() != dstAElemType ||
+        srcBType.getElementType() != dstBElemType) {
+      return rewriter.notifyMatchFailure(contractOp, "a/b type mismatch");
+    }
+
+    auto [srcCMIndex, srcCNIndex] = *opInfo.getResultMNIndex();
+    auto [srcAKIndex, srcBKIndex] = *opInfo.getOperandKIndex();
+    int64_t srcM = srcCType.getShape()[srcCMIndex];
+    int64_t srcN = srcCType.getShape()[srcCNIndex];
+    int64_t srcK = srcAType.getShape()[srcAKIndex];
+
+    if (srcM % dstM != 0 || srcN % dstN != 0 || srcK % dstK != 0) {
+      return rewriter.notifyMatchFailure(contractOp, "shape cannot divide");
+    }
+
+    Location loc = contractOp.getLoc();
+    auto dstCType = srcCType.clone(dstCElemFType);
+    auto extOp =
+        rewriter.create<arith::ExtFOp>(loc, dstCType, contractOp.getAcc());
+    auto newContractOp = rewriter.create<vector::ContractionOp>(
+        loc, contractOp.getLhs(), contractOp.getRhs(), extOp,
+        contractOp.getIndexingMaps(), contractOp.getIteratorTypes());
+    rewriter.replaceOpWithNewOp<arith::TruncFOp>(contractOp, srcCType,
+                                                 newContractOp);
+    return success();
+  }
+
+private:
+  IREE::GPU::MmaAttr intrinsic;
+};
+
+struct LLVMGPUCastTypeToFitMMAPass
+    : public LLVMGPUCastTypeToFitMMABase<LLVMGPUCastTypeToFitMMAPass> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+    registry.insert<arith::ArithDialect>();
+  }
+
+  void runOnOperation() override {
+    auto func = getOperation();
+
+    llvm::StringLiteral scheduleAttrName =
+        IREE::GPU::MMAScheduleAttr::getMnemonic();
+    auto scheduleAttr =
+        func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
+    if (!scheduleAttr) {
+      DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
+      scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+          configDict.get(scheduleAttrName));
+    }
+    if (!scheduleAttr) {
+      func.emitError() << "missing mma_schedule\n";
+      return signalPassFailure();
+    }
+
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<UpcastContractOutput>(context, scheduleAttr.getIntrinsic());
+
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createLLVMGPUCastTypeToFitMMAPass() {
+  return std::make_unique<LLVMGPUCastTypeToFitMMAPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 255115a..549055c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -512,6 +512,8 @@
       createHoistStaticallyBoundAllocationsPass());
 
   // Vector SIMD -> Vector SIMT
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createLLVMGPUCastTypeToFitMMAPass());
   nestedModulePM.addNestedPass<func::FuncOp>(createLLVMGPUVectorDistribute());
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 587dfc5..7f84b6e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -79,6 +79,11 @@
 std::unique_ptr<OperationPass<ModuleOp>>
 createLLVMGPUCastAddressSpaceFunction();
 
+/// Perform type extension/truncation over vector.contract types to target GPU
+/// MMA intrinsics.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMGPUCastTypeToFitMMAPass();
+
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMGPUDistribute();
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index dc6b452..d0cb05b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -49,10 +49,17 @@
 
 def LLVMGPUCastAddressSpaceFunction :
     Pass<"iree-llvmgpu-cast-address-space-function", "ModuleOp"> {
-  let summary = "Pass to cast";
+  let summary = "Cast address space to generic in CallOp and FuncOp";
   let constructor = "mlir::iree_compiler::createLLVMGPUCastAddressSpaceFunction()";
 }
 
+def LLVMGPUCastTypeToFitMMA : InterfacePass<"iree-llvmgpu-cast-type-to-fit-mma",
+                                            "mlir::FunctionOpInterface"> {
+  let summary = "Perform type extension/truncation over vector.contract types "
+                "to target GPU MMA intrinsics";
+  let constructor = "mlir::iree_compiler::createLLVMGPUCastTypeToFitMMAPass()";
+}
+
 def LLVMGPULowerExecutableTarget :
     Pass<"iree-llvmgpu-lower-executable-target", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
   let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 7f7aba6..7e576e5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -30,6 +30,7 @@
             "distribute_to_thread.mlir",
             "elementwise_pipeline.mlir",
             "cast_address_space_function.mlir",
+            "cast_type_to_fit_mma.mlir",
             "config_matvec.mlir",
             "extract_address_computation_gpu.mlir",
             "gpu_set_num_workgroups.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 6ecee2e..1920881 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -19,6 +19,7 @@
     "attention.mlir"
     "attention_mfma.mlir"
     "cast_address_space_function.mlir"
+    "cast_type_to_fit_mma.mlir"
     "config_matvec.mlir"
     "conv_pipeline_test.mlir"
     "convert_to_nvvm.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index a8a1665..5abc043 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -11,19 +11,19 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
-hal.executable @matmul_256x256x256 {
+hal.executable @matmul_256x256x256_f16_f32 {
 hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
       target_arch = "gfx940",
       mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
                         #iree_gpu.mfma_layout<F16_32x32x8_F32>]
   }>) {
-  hal.executable.export @matmul_256x256x256 layout(#pipeline_layout) {
+  hal.executable.export @matmul_256x256x256_f16_f32 layout(#pipeline_layout) {
     ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
       %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
       hal.return %x, %y, %z : index, index, index
     }
   builtin.module {
-    func.func @matmul_256x256x256() {
+    func.func @matmul_256x256x256_f16_f32() {
       %cst = arith.constant 0.000000e+00 : f32
       %c0 = arith.constant 0 : index
       %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
@@ -47,15 +47,70 @@
 //  CHECK-SAME:   mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
 //  CHECK-SAME:     subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 2>
 
-// CHECK-LABEL: hal.executable.export public @matmul_256x256x256
+// CHECK-LABEL: hal.executable.export public @matmul_256x256x256_f16_f32
 //  CHECK-SAME:    subgroup_size = 64
 //  CHECK-SAME:    translation_info = #[[$TRANSLATION]]
 //  CHECK-SAME:    workgroup_size = [128 : index, 2 : index, 1 : index]
 
-//    CHECK-LABEL: func.func @matmul_256x256x256
+//    CHECK-LABEL: func.func @matmul_256x256x256_f16_f32
 //          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}}) -> (vector<2x4x1x1x1x4xf32>)
 // Each subgroup handles 2 * 4 tiles, and for each tile we accumulate 2 times
 // along the K dimension. So in total 16 mfma ops.
 // CHECK-COUNT-16:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 //          CHECK:     scf.yield %{{.+}} : vector<2x4x1x1x1x4xf32>
 //  CHECK-COUNT-8:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @matmul_256x256x256_f16_f16 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
+      target_arch = "gfx940",
+      mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
+                        #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+  }>) {
+  hal.executable.export @matmul_256x256x256_f16_f16 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+  builtin.module {
+    func.func @matmul_256x256x256_f16_f16() {
+      %cst = arith.constant 0.000000e+00 : f16
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %5 = tensor.empty() : tensor<256x256xf16>
+      %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+      %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+      return
+    }
+  }
+}
+}
+
+//       CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+//  CHECK-SAME:   mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
+//  CHECK-SAME:     subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 2>
+
+// CHECK-LABEL: hal.executable.export public @matmul_256x256x256_f16_f16
+//  CHECK-SAME:    subgroup_size = 64
+//  CHECK-SAME:    translation_info = #[[$TRANSLATION]]
+//  CHECK-SAME:    workgroup_size = [128 : index, 2 : index, 1 : index]
+
+//    CHECK-LABEL: func.func @matmul_256x256x256_f16_f16
+//          CHECK:   scf.for {{.*}} = %c0 to %c256 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf16>)
+//          CHECK:     arith.extf %[[ARG]] : vector<2x4x1x1x1x4xf16> to vector<2x4x1x1x1x4xf32>
+// CHECK-COUNT-16:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+//          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.+}} : vector<2x4x1x1x1x4xf32> to vector<2x4x1x1x1x4xf16>
+//          CHECK:     scf.yield %[[TRUNC]] : vector<2x4x1x1x1x4xf16>
+//  CHECK-COUNT-8:   vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
new file mode 100644
index 0000000..525985f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -0,0 +1,84 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-cast-type-to-fit-mma))' -mlir-print-local-scope %s | FileCheck %s
+
+func.func @matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+    workgroup_size = [64, 1, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+  return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mm
+//  CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+//       CHECK:   %[[EXT:.+]] = arith.extf %[[INIT]] : vector<96x64xf16> to vector<96x64xf32>
+//       CHECK:   %[[MM:.+]] = vector.contract
+//  CHECK-SAME:       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+//  CHECK-SAME:     %[[A]], %[[B]], %[[EXT]] : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+//       CHECK:   %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
+//       CHECK:   return %[[TRUNC]] : vector<96x64xf16>
+
+// -----
+
+func.func @matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+    workgroup_size = [64, 1, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf16>
+  return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mmt
+//  CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<64x16xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+//       CHECK:   arith.extf
+//       CHECK:   vector.contract
+//  CHECK-SAME:     : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf32>
+//       CHECK:   arith.truncf
+
+// -----
+
+func.func @matmul_96x64x16_mm_cannot_divide(%lhs: vector<95x16xf16>, %rhs: vector<16x64xf16>, %init: vector<95x64xf16>) -> vector<95x64xf16> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+    workgroup_size = [64, 1, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
+  return %0 : vector<95x64xf16>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mm_cannot_divide
+//   CHECK-NOT:   arith.extf
+//       CHECK:   vector.contract
+//  CHECK-SAME:     %{{.+}}, %{{.+}}, %{{.+}} : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
+//   CHECK-NOT:   arith.truncf
+
+// -----
+
+func.func @matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+      subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+    workgroup_size = [64, 1, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
+  return %0 : vector<96x64xf64>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mm_cannot_downcast
+//   CHECK-NOT:   arith.extf
+//       CHECK:   vector.contract
+//  CHECK-SAME:     %{{.+}}, %{{.+}}, %{{.+}} : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
+//   CHECK-NOT:   arith.truncf