[LLVMGPU] Support CastTypeToFitMMA on TransformDialect scripts. (#17884)

Previously, CastTypeToFitMMA relied on the `mma_schedule` attribute on
the function's translationInfo to determine which `iree.amdgpu.mma`
intrinsic was selected.
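
For context, a minimal sketch of that function-level annotation; the schedule parameters shown here are illustrative, not taken from this patch:

```mlir
// Sketch only: mma_schedule parameters below are illustrative.
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
    workgroup_size = [64, 1, 1] subgroup_size = 64,
    {mma_schedule = #iree_gpu.mma_schedule<
        intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
        subgroup_m_count = 1, subgroup_n_count = 1>}>
```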

While this works for the C++ pipeline, IR generated from a
TransformDialect script does not carry that information. Instead, IR
generated by a TD script typically annotates the selected
`iree.amdgpu.mma` intrinsic directly on the vector.contract ops.
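
For example, the new test below carries the annotation directly on the op:

```mlir
%0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
    %lhs, %rhs, %init
    {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}
    : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
```

The pattern now reads the intrinsic from this per-op attribute; when a function-level `mma_schedule` is present (the C++ pipeline case), the pass first copies its intrinsic onto any contract that is not yet annotated.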

This is a crucial part of enabling the latest performant attention
compilation pipeline (with online attention + transpose fusion), which
is based on TD scripts.

---------

Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 84725a0..621430b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -23,11 +23,8 @@
 
 namespace {
 
-struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
-  UpcastContractOutput(MLIRContext *context,
-                       IREE::GPU::MmaInterfaceAttr intrinsic,
-                       PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit), intrinsic(intrinsic) {}
+struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
@@ -40,6 +37,12 @@
     auto srcAType = contractOp.getLhsType();
     auto srcBType = contractOp.getRhsType();
 
+    auto intrinsic = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+        "iree.amdgpu.mma");
+    if (!intrinsic) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "could not find iree.amdgpu.mma attribute on contract");
+    }
     auto [dstAElemType, dstBElemType, dstCElemType] =
         intrinsic.getABCElementTypes();
 
@@ -67,9 +70,6 @@
                                                  newContractOp);
     return success();
   }
-
-private:
-  IREE::GPU::MmaInterfaceAttr intrinsic;
 };
 
 struct LLVMGPUCastTypeToFitMMAPass
@@ -89,17 +89,24 @@
         func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
     if (!scheduleAttr) {
       DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-      scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-          configDict.get(scheduleAttrName));
+      if (configDict) {
+        scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+            configDict.get(scheduleAttrName));
+      }
     }
-    if (!scheduleAttr) {
-      func.emitError() << "missing mma_schedule\n";
-      return signalPassFailure();
+
+    // Import mma type from dispatch schedule attribute if present.
+    if (scheduleAttr) {
+      func.walk([&](vector::ContractionOp contract) {
+        if (!contract->hasAttr("iree.amdgpu.mma")) {
+          contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
+        }
+      });
     }
 
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<UpcastContractOutput>(context, scheduleAttr.getIntrinsic());
+    patterns.add<UpcastContractOutput>(context);
 
     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
       return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index 1f07bc7..21eb880 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -86,3 +86,31 @@
 //  CHECK-SAME:     %[[A]], %[[B]], %[[EXT]] : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
 //       CHECK:   %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
 //       CHECK:   return %[[TRUNC]] : vector<48x32xf16>
+
+// -----
+
+// Tests that cast_type_to_fit_mma works on the IR structure coming out of transform_dialect.
+
+// IR generated by transform_dialect differs from that of the C++ pipeline: it will not have an
+// mma_schedule on the function attributes, but instead carries the "iree.amdgpu.mma"
+// attribute directly on the vector.contract ops.
+
+func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {translation_info = #iree_codegen.translation_info<None workgroup_size = [64, 1, 1] subgroup_size = 64>} {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %lhs, %rhs, %init
+    {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}
+    : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+  return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @transform_dialect_mfma_matmul_96x64x16
+//  CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+//       CHECK:   %[[EXT:.+]] = arith.extf %[[INIT]] : vector<96x64xf16> to vector<96x64xf32>
+//       CHECK:   %[[MM:.+]] = vector.contract
+//  CHECK-SAME:       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+//  CHECK-SAME:     %[[A]], %[[B]], %[[EXT]] : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+//       CHECK:   %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
+//       CHECK:   return %[[TRUNC]] : vector<96x64xf16>