[LLVMGPU] Support CastTypeToFitMMA on TransformDialect script. (#17884)
Previously CastTypeToFitMMA relied on the `mma_schedule` attribute on
the function's translationInfo to obtain information about
`iree.amdgpu.mma` (the selected intrinsic).
While this is fine for the C++ pipeline, the IR generated from a
TransformDialect script does not have such information. Instead, IR
generated by a TD script typically annotates the
`iree.amdgpu.mma` (selected intrinsic) attribute directly on the
vector.contract ops.
This is a crucial part of enabling a performant version of the latest
attention compilation pipeline (with online attention + transpose
fusion), which is based on TD scripts.
---------
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 84725a0..621430b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -23,11 +23,8 @@
namespace {
-struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
- UpcastContractOutput(MLIRContext *context,
- IREE::GPU::MmaInterfaceAttr intrinsic,
- PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit), intrinsic(intrinsic) {}
+struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -40,6 +37,12 @@
auto srcAType = contractOp.getLhsType();
auto srcBType = contractOp.getRhsType();
+ auto intrinsic = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+ "iree.amdgpu.mma");
+ if (!intrinsic) {
+ return rewriter.notifyMatchFailure(
+ contractOp, "could not find iree.amdgpu.mma attribute on contract");
+ }
auto [dstAElemType, dstBElemType, dstCElemType] =
intrinsic.getABCElementTypes();
@@ -67,9 +70,6 @@
newContractOp);
return success();
}
-
-private:
- IREE::GPU::MmaInterfaceAttr intrinsic;
};
struct LLVMGPUCastTypeToFitMMAPass
@@ -89,17 +89,24 @@
func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
if (!scheduleAttr) {
DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
- scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
- configDict.get(scheduleAttrName));
+ if (configDict) {
+ scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+ configDict.get(scheduleAttrName));
+ }
}
- if (!scheduleAttr) {
- func.emitError() << "missing mma_schedule\n";
- return signalPassFailure();
+
+ // Import mma type from dispatch schedule attribute if present.
+ if (scheduleAttr) {
+ func.walk([&](vector::ContractionOp contract) {
+ if (!contract->hasAttr("iree.amdgpu.mma")) {
+ contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
+ }
+ });
}
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- patterns.add<UpcastContractOutput>(context, scheduleAttr.getIntrinsic());
+ patterns.add<UpcastContractOutput>(context);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index 1f07bc7..21eb880 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -86,3 +86,31 @@
// CHECK-SAME: %[[A]], %[[B]], %[[EXT]] : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
// CHECK: return %[[TRUNC]] : vector<48x32xf16>
+
+// -----
+
+// This tests that cast_type_to_fit_mma works on the IR structure coming out of transform_dialect.
+
+// IR generated by transform_dialect is different from the IR in the C++ pipeline:
+// it will not have an mma_schedule on the function attributes, but will instead have
+// the "iree.amdgpu.mma" attribute directly on the vector.contract op.
+
+func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {translation_info = #iree_codegen.translation_info<None workgroup_size = [64, 1, 1] subgroup_size = 64>} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init
+ {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}
+ : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+ return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @transform_dialect_mfma_matmul_96x64x16
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: %[[EXT:.+]] = arith.extf %[[INIT]] : vector<96x64xf16> to vector<96x64xf32>
+// CHECK: %[[MM:.+]] = vector.contract
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+// CHECK-SAME: %[[A]], %[[B]], %[[EXT]] : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
+// CHECK: return %[[TRUNC]] : vector<96x64xf16>