[LLVMGPU] Embed mma_intrinsic in to_layout and infer contraction's intrinsic from it. (#18842)

To enable faster flash attention, we'd like to be able to force
different vector widths => we'd like different contraction to
potentially have different intrinsics. This PR introduces a way to set
intrinsic information for individual contraction, and have it preserved
until vector distribution.

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
index 496ee0f..4e40cd8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
@@ -40,7 +40,10 @@
   let arguments = (ins
     AnyShaped:$input,
     VectorLayoutInterface:$layout,
-    DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion
+    DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion,
+    // TODO: Solve cmake IREEGPU and VectorExt cyclic dependency to
+    // change mma_Kind type to be of MMAInterfaceAttr.
+    OptionalAttr<AnyAttr>:$mma_kind
   );
   let results = (outs
     AnyShaped:$output
@@ -48,13 +51,20 @@
   let builders = [
     OpBuilder<(ins "Value":$input,
                    "VectorLayoutInterface":$layout,
+                   "Attribute":$mma_kind_attr,
                    CArg<"bool", "false">:$shared_memory_conversion), [{
+      UnitAttr defaultSharedMemoryConversion;
       if (shared_memory_conversion) {
-        build($_builder, $_state, input.getType(), input, layout, UnitAttr::get(input.getContext()));
-      } else{
-        build($_builder, $_state, input.getType(), input, layout);
+        defaultSharedMemoryConversion = UnitAttr::get(input.getContext());
       }
-    }]>
+      build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, mma_kind_attr);
+    }]>,
+  OpBuilder<(ins "Value":$input,
+                "VectorLayoutInterface":$layout), [{
+      UnitAttr defaultSharedMemoryConversion;
+      Attribute emptyIntrinsic;
+      build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, emptyIntrinsic);
+    }]>,
   ];
   let extraClassDeclaration = [{
     bool hasTensorSemantics() {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
index 64edf00..fe1f425 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
@@ -61,9 +61,7 @@
     Value rankReducedValue = rankReducingExtract.value();
     auto newToLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
         loc, rankReducedValue.getType(), rankReducedValue, newLayout,
-        toLayoutOp.getSharedMemoryConversion());
-    newToLayoutOp->setDiscardableAttrs(
-        toLayoutOp->getDiscardableAttrDictionary());
+        toLayoutOp.getSharedMemoryConversion(), toLayoutOp.getMmaKindAttr());
 
     // Expand to preserve output shape using insert_slice.
     // Here, since the shape comes from the result of a to_layout op, it will
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
index e2c5c0c..7a1696e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
@@ -48,7 +48,7 @@
 
     // Create the toLayout operation but with vector types instead.
     auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-        loc, newInput, toLayoutOp.getLayout(),
+        loc, newInput, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(),
         toLayoutOp.getSharedMemoryConversion());
 
     // Create the write back to a tensor.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 359c6ff..26fb949 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -6,8 +6,10 @@
 
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -81,6 +83,34 @@
   }
 };
 
+static void inferMmaKind(vector::ContractionOp contract) {
+  SetVector<Operation *> slice;
+  getForwardSlice(contract.getResult(), &slice);
+
+  // Operations in slice are ordered in topological order, so the first
+  // to_layout operation we encounter is setting the layout.
+  IREE::VectorExt::ToLayoutOp toLayout;
+  for (Operation *op : slice) {
+    auto candidate = dyn_cast<IREE::VectorExt::ToLayoutOp>(op);
+    if (candidate) {
+      toLayout = candidate;
+      break;
+    }
+  }
+
+  if (!toLayout) {
+    return;
+  }
+
+  auto intrinsic =
+      dyn_cast_or_null<IREE::GPU::MmaInterfaceAttr>(toLayout.getMmaKindAttr());
+  if (!intrinsic) {
+    return;
+  }
+
+  contract->setAttr("iree.amdgpu.mma", intrinsic);
+}
+
 struct LLVMGPUCastTypeToFitMMAPass final
     : impl::LLVMGPUCastTypeToFitMMAPassBase<LLVMGPUCastTypeToFitMMAPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -91,26 +121,15 @@
   void runOnOperation() override {
     auto func = getOperation();
 
-    llvm::StringLiteral scheduleAttrName =
-        IREE::GPU::MMAScheduleAttr::getMnemonic();
-    auto scheduleAttr =
-        func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
-    if (!scheduleAttr) {
-      DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-      if (configDict) {
-        scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-            configDict.get(scheduleAttrName));
+    // Set MMA type from config embedded in toLayoutOp of contraction.
+    func.walk([&](vector::ContractionOp contract) {
+      inferMmaKind(contract);
+      if (!contract->hasAttr("iree.amdgpu.mma")) {
+        func.emitOpError("Failed to detect valid to_layout consumer of "
+                         "vector.contract to infer MMA kind.");
+        return signalPassFailure();
       }
-    }
-
-    // Import mma type from dispatch schedule attribute if present.
-    if (scheduleAttr) {
-      func.walk([&](vector::ContractionOp contract) {
-        if (!contract->hasAttr("iree.amdgpu.mma")) {
-          contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
-        }
-      });
-    }
+    });
 
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 3f84454..22b570b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -56,12 +56,12 @@
 
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(contract);
-  auto layoutedLhs =
-      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, lhs, aLayout);
-  auto layoutedRhs =
-      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, rhs, bLayout);
-  auto layoutedAcc =
-      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, acc, cLayout);
+  auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+      loc, lhs, aLayout, schedule.getIntrinsic());
+  auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+      loc, rhs, bLayout, schedule.getIntrinsic());
+  auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+      loc, acc, cLayout, schedule.getIntrinsic());
 
   // Promote matmul lhs and rhs.
   // TODO: We should read this from the lowering_config on the operation.
@@ -82,7 +82,7 @@
   // Set layout for result.
   rewriter.setInsertionPointAfter(contract);
   auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, contract->getResult(0), cLayout);
+      loc, contract->getResult(0), cLayout, schedule.getIntrinsic());
   rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
                                 toLayout);
 
@@ -140,11 +140,11 @@
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(conv);
   auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, lhs.getType(), lhs, aLayout);
+      loc, lhs, aLayout, schedule.getIntrinsic());
   auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, rhs.getType(), rhs, bLayout);
+      loc, rhs, bLayout, schedule.getIntrinsic());
   auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, acc.getType(), acc, cLayout);
+      loc, acc, cLayout, schedule.getIntrinsic());
 
   // Promote matmul lhs and rhs.
   // TODO: We should read this from the lowering_config on the operation.
@@ -160,7 +160,7 @@
   // Set layout for result.
   rewriter.setInsertionPointAfter(conv);
   auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, conv->getResult(0).getType(), conv->getResult(0), cLayout);
+      loc, conv->getResult(0), cLayout, schedule.getIntrinsic());
   rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
                                 toLayout);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index 4f97eb6..7da3e14 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -9,7 +9,11 @@
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
       %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
-  return %0 : vector<96x64xf16>
+    %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+                                      outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+                                      subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+                                      {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf16>
+  return %1 : vector<96x64xf16>
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm
@@ -21,7 +25,6 @@
 //  CHECK-SAME:     %[[A]], %[[B]], %[[EXT]]
 //  CHECK-SAME:     vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
 //       CHECK:   %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
-//       CHECK:   return %[[TRUNC]] : vector<96x64xf16>
 
 // -----
 
@@ -34,7 +37,11 @@
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
       %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf16>
-  return %0 : vector<96x64xf16>
+    %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+                                      outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+                                      subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+                                      {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf16>
+  return %1 : vector<96x64xf16>
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmt
@@ -55,7 +62,11 @@
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
       %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
-  return %0 : vector<96x64xf64>
+    %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+                                      outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+                                      subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+                                      {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf64>
+  return %1 : vector<96x64xf64>
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm_cannot_downcast
@@ -75,7 +86,11 @@
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
       %lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf16>
-  return %0 : vector<48x32xf16>
+    %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+                                      outer_tile = [8, 1], thread_tile = [2, 16], element_tile = [1, 1],
+                                      subgroup_strides = [0, 0], thread_strides = [16, 1]>)
+                                      {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>} : vector<48x32xf16>
+  return %1 : vector<48x32xf16>
 }
 
 // CHECK-LABEL: func.func @wmma_matmul_48x32x32_mm
@@ -87,7 +102,36 @@
 //  CHECK-SAME:     %[[A]], %[[B]], %[[EXT]]
 //  CHECK-SAME:     vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
 //       CHECK:   %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
-//       CHECK:   return %[[TRUNC]] : vector<48x32xf16>
+
+// -----
+
+// This tests cast_type_to_fit_mma works on contract where intrinsic is set by to_layout.
+// "iree.amdgpu.mma" will be generated from the "intrinsic" attribute of to_layout.
+// this also shows that we can overwrite default intrinsics if explicitly set.
+
+func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
+      subgroup_m_count = 1, subgroup_n_count = 1>,
+    workgroup_size = [64, 1, 1]} {
+    %0 = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+      %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+    %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [6, 4],
+                                      outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 4],
+                                      subgroup_strides = [0, 0], thread_strides = [1, 16]>)
+                                      {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} : vector<96x64xf16>
+  return %1 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @to_layout_config_matmul_96x64x16_mm
+//  CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+//       CHECK:   arith.extf
+//       CHECK:   vector.contract
+//  CHECK-SAME:     {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+//  CHECK-SAME:     : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+//       CHECK:   arith.truncf
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index b9ecc4a..8b761f8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -42,9 +42,9 @@
 
 // CHECK-LABEL: func.func @matmul_96x64x16_mfma
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
@@ -93,9 +93,9 @@
 
 // CHECK-LABEL: func.func @matmul_96x64x16_wmma
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
@@ -144,9 +144,9 @@
 
 // CHECK-LABEL: func.func @matmul_128x64x16_multi_subgroup
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
@@ -195,9 +195,9 @@
 // CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 2, 1], batch_tile = [4, 1, 4, 1], outer_tile = [1, 1, 1, 1], thread_tile = [1, 4, 1, 16], element_tile = [1, 4, 1, 1], subgroup_strides = [2, 0, 1, 0], thread_strides = [0, 16, 0, 1]>
 // CHECK-LABEL: func.func @packed_matmul_128x128x128
 
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]