[rocdl] Allow upcasting the accumulator to use matrix cores (#16527)
This commit adds a pass that upcasts the vector.contract accumulator so it can
target matrix core instructions, and changes the kernel configuration to enable it.
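For example, for an f16 x f16 -> f16 matmul contraction where the chosen MFMA
intrinsic accumulates in f32, the pass extends the accumulator before the
contraction and truncates the result back afterwards. Illustrative sketch
mirroring the added cast_type_to_fit_mma.mlir test (SSA value names are
placeholders):

  %ext = arith.extf %init : vector<96x64xf16> to vector<96x64xf32>
  %mm = vector.contract {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                           affine_map<(d0, d1, d2) -> (d2, d1)>,
                           affine_map<(d0, d1, d2) -> (d0, d1)>],
          iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
          %lhs, %rhs, %ext : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
  %out = arith.truncf %mm : vector<96x64xf32> to vector<96x64xf16>

The heuristic change in deduceMMASchedule only permits this fallback when both
accumulator element types are floating point and the intrinsic's accumulator is
wider than the problem's.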
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index b09a817..7cecbc9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -9,6 +9,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/BuiltinTypes.h"
#define DEBUG_TYPE "iree-codegen-gpu-heuristics"
@@ -19,12 +20,20 @@
std::optional<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
- const GPUMMAHeuristicSeeds &seeds) {
+ const GPUMMAHeuristicSeeds &seeds, bool canUpcastAcc) {
for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
- if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType ||
- problem.cType != intrinsic.cType) {
+ if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
continue; // Cannot use this intrinsic for mismatched types
}
+ if (problem.cType != intrinsic.cType) {
+ auto isFpCase =
+ isa<FloatType>(problem.cType) && isa<FloatType>(intrinsic.cType);
+ auto isUpcast = problem.cType.getIntOrFloatBitWidth() <
+ intrinsic.cType.getIntOrFloatBitWidth();
+ if (!(canUpcastAcc && isFpCase && isUpcast)) {
+ continue; // Cannot use this intrinsic if not upcasting
+ }
+ }
if (problem.mSize % intrinsic.mSize != 0 ||
problem.nSize % intrinsic.nSize != 0 ||
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index 6bf4ff8..f70a428 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -49,6 +49,6 @@
std::optional<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
- const GPUMMAHeuristicSeeds &seeds);
+ const GPUMMAHeuristicSeeds &seeds, bool canUpcastAcc = false);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 447a592..f739255 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -94,6 +94,7 @@
"ExtractAddressComputationGPUPass.cpp",
"KernelConfig.cpp",
"LLVMGPUCastAddressSpaceFunction.cpp",
+ "LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUSelectLoweringStrategy.cpp",
@@ -133,6 +134,7 @@
"//compiler/src/iree/compiler/Codegen/TransformStrategies/GPU",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
+ "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 517b554..3a6fe83 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -79,6 +79,7 @@
"ExtractAddressComputationGPUPass.cpp"
"KernelConfig.cpp"
"LLVMGPUCastAddressSpaceFunction.cpp"
+ "LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUSelectLoweringStrategy.cpp"
@@ -174,6 +175,7 @@
iree::compiler::Codegen::TransformStrategies::GPU
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
+ iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 04328b8..bb1ace1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -347,9 +347,15 @@
/*bestMNTileCountPerSubgroup=*/8,
/*bestKTileCountPerSubgroup=*/2};
+ // First try to find a schedule with an exactly matching intrinsic.
std::optional<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds);
if (!schedule) {
+ // Then try again, allowing upcasting of the accumulator.
+ schedule =
+ deduceMMASchedule(problem, intrinsics, seeds, /*canUpcastAcc=*/true);
+ }
+ if (!schedule) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
new file mode 100644
index 0000000..8ab4ef4
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -0,0 +1,128 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
+ UpcastContractOutput(MLIRContext *context, IREE::GPU::MmaAttr intrinsic,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), intrinsic(intrinsic) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ VectorContractOpInfo opInfo(contractOp);
+ if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
+ return rewriter.notifyMatchFailure(contractOp, "unhandled contract kind");
+ }
+
+ auto srcCType = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!srcCType) {
+ return rewriter.notifyMatchFailure(contractOp, "unhandled scalar case");
+ }
+ auto srcAType = contractOp.getLhsType();
+ auto srcBType = contractOp.getRhsType();
+
+ auto [dstAElemType, dstBElemType, dstCElemType] =
+ intrinsic.getABCElementTypes();
+ auto [dstM, dstN, dstK] = intrinsic.getMNKShape();
+
+ auto srcCElemFType = dyn_cast<FloatType>(srcCType.getElementType());
+ auto dstCElemFType = dyn_cast<FloatType>(dstCElemType);
+ if (!srcCElemFType || !dstCElemFType ||
+ srcCElemFType.getWidth() >= dstCElemFType.getWidth()) {
+ return rewriter.notifyMatchFailure(
+ contractOp, "unhandled non-floating point or non-upcasting case");
+ }
+
+ if (srcAType.getElementType() != dstAElemType ||
+ srcBType.getElementType() != dstBElemType) {
+ return rewriter.notifyMatchFailure(contractOp, "a/b type mismatch");
+ }
+
+ auto [srcCMIndex, srcCNIndex] = *opInfo.getResultMNIndex();
+ auto [srcAKIndex, srcBKIndex] = *opInfo.getOperandKIndex();
+ int64_t srcM = srcCType.getShape()[srcCMIndex];
+ int64_t srcN = srcCType.getShape()[srcCNIndex];
+ int64_t srcK = srcAType.getShape()[srcAKIndex];
+
+ if (srcM % dstM != 0 || srcN % dstN != 0 || srcK % dstK != 0) {
+ return rewriter.notifyMatchFailure(contractOp, "shape cannot divide");
+ }
+
+ Location loc = contractOp.getLoc();
+ auto dstCType = srcCType.clone(dstCElemFType);
+ auto extOp =
+ rewriter.create<arith::ExtFOp>(loc, dstCType, contractOp.getAcc());
+ auto newContractOp = rewriter.create<vector::ContractionOp>(
+ loc, contractOp.getLhs(), contractOp.getRhs(), extOp,
+ contractOp.getIndexingMaps(), contractOp.getIteratorTypes());
+ rewriter.replaceOpWithNewOp<arith::TruncFOp>(contractOp, srcCType,
+ newContractOp);
+ return success();
+ }
+
+private:
+ IREE::GPU::MmaAttr intrinsic;
+};
+
+struct LLVMGPUCastTypeToFitMMAPass
+ : public LLVMGPUCastTypeToFitMMABase<LLVMGPUCastTypeToFitMMAPass> {
+public:
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ registry.insert<arith::ArithDialect>();
+ }
+
+ void runOnOperation() override {
+ auto func = getOperation();
+
+ llvm::StringLiteral scheduleAttrName =
+ IREE::GPU::MMAScheduleAttr::getMnemonic();
+ auto scheduleAttr =
+ func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
+ if (!scheduleAttr) {
+ DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
+ scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+ configDict.get(scheduleAttrName));
+ }
+ if (!scheduleAttr) {
+ func.emitError() << "missing mma_schedule\n";
+ return signalPassFailure();
+ }
+
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<UpcastContractOutput>(context, scheduleAttr.getIntrinsic());
+
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createLLVMGPUCastTypeToFitMMAPass() {
+ return std::make_unique<LLVMGPUCastTypeToFitMMAPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 255115a..549055c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -512,6 +512,8 @@
createHoistStaticallyBoundAllocationsPass());
// Vector SIMD -> Vector SIMT
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createLLVMGPUCastTypeToFitMMAPass());
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMGPUVectorDistribute());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index 587dfc5..7f84b6e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -79,6 +79,11 @@
std::unique_ptr<OperationPass<ModuleOp>>
createLLVMGPUCastAddressSpaceFunction();
+/// Perform type extension/truncation over vector.contract types to target GPU
+/// MMA intrinsics.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMGPUCastTypeToFitMMAPass();
+
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUDistribute();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index dc6b452..d0cb05b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -49,10 +49,17 @@
def LLVMGPUCastAddressSpaceFunction :
Pass<"iree-llvmgpu-cast-address-space-function", "ModuleOp"> {
- let summary = "Pass to cast";
+ let summary = "Cast address space to generic in CallOp and FuncOp";
let constructor = "mlir::iree_compiler::createLLVMGPUCastAddressSpaceFunction()";
}
+def LLVMGPUCastTypeToFitMMA : InterfacePass<"iree-llvmgpu-cast-type-to-fit-mma",
+ "mlir::FunctionOpInterface"> {
+ let summary = "Perform type extension/truncation over vector.contract types "
+ "to target GPU MMA intrinsics";
+ let constructor = "mlir::iree_compiler::createLLVMGPUCastTypeToFitMMAPass()";
+}
+
def LLVMGPULowerExecutableTarget :
Pass<"iree-llvmgpu-lower-executable-target", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 7f7aba6..7e576e5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -30,6 +30,7 @@
"distribute_to_thread.mlir",
"elementwise_pipeline.mlir",
"cast_address_space_function.mlir",
+ "cast_type_to_fit_mma.mlir",
"config_matvec.mlir",
"extract_address_computation_gpu.mlir",
"gpu_set_num_workgroups.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 6ecee2e..1920881 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -19,6 +19,7 @@
"attention.mlir"
"attention_mfma.mlir"
"cast_address_space_function.mlir"
+ "cast_type_to_fit_mma.mlir"
"config_matvec.mlir"
"conv_pipeline_test.mlir"
"convert_to_nvvm.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
index a8a1665..5abc043 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
@@ -11,19 +11,19 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
-hal.executable @matmul_256x256x256 {
+hal.executable @matmul_256x256x256_f16_f32 {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
target_arch = "gfx940",
mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
#iree_gpu.mfma_layout<F16_32x32x8_F32>]
}>) {
- hal.executable.export @matmul_256x256x256 layout(#pipeline_layout) {
+ hal.executable.export @matmul_256x256x256_f16_f32 layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
- func.func @matmul_256x256x256() {
+ func.func @matmul_256x256x256_f16_f32() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
@@ -47,15 +47,70 @@
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 2>
-// CHECK-LABEL: hal.executable.export public @matmul_256x256x256
+// CHECK-LABEL: hal.executable.export public @matmul_256x256x256_f16_f32
// CHECK-SAME: subgroup_size = 64
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]
-// CHECK-LABEL: func.func @matmul_256x256x256
+// CHECK-LABEL: func.func @matmul_256x256x256_f16_f32
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}}) -> (vector<2x4x1x1x1x4xf32>)
// Each subgroup handles 2 * 4 tiles, and for each tile we accumulate 2 times
// along the K dimension. So in total 16 mfma ops.
// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield %{{.+}} : vector<2x4x1x1x1x4xf32>
// CHECK-COUNT-8: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+hal.executable @matmul_256x256x256_f16_f16 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
+ target_arch = "gfx940",
+ mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>,
+ #iree_gpu.mfma_layout<F16_32x32x8_F32>]
+ }>) {
+ hal.executable.export @matmul_256x256x256_f16_f16 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @matmul_256x256x256_f16_f16() {
+ %cst = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+ %5 = tensor.empty() : tensor<256x256xf16>
+ %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+ %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+ return
+ }
+ }
+}
+}
+
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 2>
+
+// CHECK-LABEL: hal.executable.export public @matmul_256x256x256_f16_f16
+// CHECK-SAME: subgroup_size = 64
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]
+
+// CHECK-LABEL: func.func @matmul_256x256x256_f16_f16
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf16>)
+// CHECK: arith.extf %[[ARG]] : vector<2x4x1x1x1x4xf16> to vector<2x4x1x1x1x4xf32>
+// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.+}} : vector<2x4x1x1x1x4xf32> to vector<2x4x1x1x1x4xf16>
+// CHECK: scf.yield %[[TRUNC]] : vector<2x4x1x1x1x4xf16>
+// CHECK-COUNT-8: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
new file mode 100644
index 0000000..525985f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -0,0 +1,84 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-cast-type-to-fit-mma))' -mlir-print-local-scope %s | FileCheck %s
+
+func.func @matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+ subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+ workgroup_size = [64, 1, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+ return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mm
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: %[[EXT:.+]] = arith.extf %[[INIT]] : vector<96x64xf16> to vector<96x64xf32>
+// CHECK: %[[MM:.+]] = vector.contract
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+// CHECK-SAME: %[[A]], %[[B]], %[[EXT]] : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
+// CHECK: return %[[TRUNC]] : vector<96x64xf16>
+
+// -----
+
+func.func @matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+ subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+ workgroup_size = [64, 1, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf16>
+ return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mmt
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<64x16xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: arith.extf
+// CHECK: vector.contract
+// CHECK-SAME: : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf32>
+// CHECK: arith.truncf
+
+// -----
+
+func.func @matmul_96x64x16_mm_cannot_divide(%lhs: vector<95x16xf16>, %rhs: vector<16x64xf16>, %init: vector<95x64xf16>) -> vector<95x64xf16> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+ subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+ workgroup_size = [64, 1, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
+ return %0 : vector<95x64xf16>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mm_cannot_divide
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK-SAME: %{{.+}}, %{{.+}}, %{{.+}} : vector<95x16xf16>, vector<16x64xf16> into vector<95x64xf16>
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf64>) -> vector<96x64xf64> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>,
+ subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>,
+ workgroup_size = [64, 1, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
+ return %0 : vector<96x64xf64>
+}
+
+// CHECK-LABEL: func.func @matmul_96x64x16_mm_cannot_downcast
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK-SAME: %{{.+}}, %{{.+}}, %{{.+}} : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
+// CHECK-NOT: arith.truncf