[LLVMGPU] Embed mma_intrinsic in to_layout and infer contraction's intrinsic from it. (#18842)
To enable faster flash attention, we'd like to be able to force
different vector widths, which means different contractions may need
different MMA intrinsics. This PR introduces a way to set intrinsic
information on an individual contraction and have it preserved until
vector distribution.
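
Concretely, the intrinsic rides on the `iree_vector_ext.to_layout` op that consumes the contraction's result, and `LLVMGPUCastTypeToFitMMA` copies it onto the contraction as `iree.amdgpu.mma`. A minimal sketch of the new IR, adapted from the `cast_type_to_fit_mma` test added below (`%lhs`, `%rhs`, `%init` are the surrounding function's f16 operands; the layout values are illustrative):

```mlir
%0 = vector.contract {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d2, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"],
       kind = #vector.kind<add>}
     %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
// The to_layout consumer carries the requested intrinsic.
%1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<
       subgroup_tile = [1, 1], batch_tile = [6, 4], outer_tile = [1, 1],
       thread_tile = [16, 4], element_tile = [1, 4],
       subgroup_strides = [0, 0], thread_strides = [1, 16]>)
     {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} : vector<96x64xf16>
// After LLVMGPUCastTypeToFitMMA the contraction is annotated with
//   iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
```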
---------
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
index 496ee0f..4e40cd8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
@@ -40,7 +40,10 @@
let arguments = (ins
AnyShaped:$input,
VectorLayoutInterface:$layout,
- DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion
+ DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion,
+ // TODO: Resolve the CMake cyclic dependency between IREEGPU and
+ // VectorExt so that mma_kind's type can be MMAInterfaceAttr.
+ OptionalAttr<AnyAttr>:$mma_kind
);
let results = (outs
AnyShaped:$output
@@ -48,13 +51,20 @@
let builders = [
OpBuilder<(ins "Value":$input,
"VectorLayoutInterface":$layout,
+ "Attribute":$mma_kind_attr,
CArg<"bool", "false">:$shared_memory_conversion), [{
+ UnitAttr defaultSharedMemoryConversion;
if (shared_memory_conversion) {
- build($_builder, $_state, input.getType(), input, layout, UnitAttr::get(input.getContext()));
- } else{
- build($_builder, $_state, input.getType(), input, layout);
+ defaultSharedMemoryConversion = UnitAttr::get(input.getContext());
}
- }]>
+ build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, mma_kind_attr);
+ }]>,
+ OpBuilder<(ins "Value":$input,
+ "VectorLayoutInterface":$layout), [{
+ UnitAttr defaultSharedMemoryConversion;
+ Attribute emptyIntrinsic;
+ build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, emptyIntrinsic);
+ }]>,
];
let extraClassDeclaration = [{
bool hasTensorSemantics() {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
index 64edf00..fe1f425 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
@@ -61,9 +61,7 @@
Value rankReducedValue = rankReducingExtract.value();
auto newToLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, rankReducedValue.getType(), rankReducedValue, newLayout,
- toLayoutOp.getSharedMemoryConversion());
- newToLayoutOp->setDiscardableAttrs(
- toLayoutOp->getDiscardableAttrDictionary());
+ toLayoutOp.getSharedMemoryConversion(), toLayoutOp.getMmaKindAttr());
// Expand to preserve output shape using insert_slice.
// Here, since the shape comes from the result of a to_layout op, it will
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
index e2c5c0c..7a1696e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
@@ -48,7 +48,7 @@
// Create the toLayout operation but with vector types instead.
auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, newInput, toLayoutOp.getLayout(),
+ loc, newInput, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(),
toLayoutOp.getSharedMemoryConversion());
// Create the write back to a tensor.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 359c6ff..26fb949 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -6,8 +6,10 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -81,6 +83,34 @@
}
};
+static void inferMmaKind(vector::ContractionOp contract) {
+ SetVector<Operation *> slice;
+ getForwardSlice(contract.getResult(), &slice);
+
+ // The forward slice is returned in topological order, so the first
+ // to_layout operation we encounter is the one that sets the layout.
+ IREE::VectorExt::ToLayoutOp toLayout;
+ for (Operation *op : slice) {
+ auto candidate = dyn_cast<IREE::VectorExt::ToLayoutOp>(op);
+ if (candidate) {
+ toLayout = candidate;
+ break;
+ }
+ }
+
+ if (!toLayout) {
+ return;
+ }
+
+ auto intrinsic =
+ dyn_cast_or_null<IREE::GPU::MmaInterfaceAttr>(toLayout.getMmaKindAttr());
+ if (!intrinsic) {
+ return;
+ }
+
+ contract->setAttr("iree.amdgpu.mma", intrinsic);
+}
+
struct LLVMGPUCastTypeToFitMMAPass final
: impl::LLVMGPUCastTypeToFitMMAPassBase<LLVMGPUCastTypeToFitMMAPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -91,26 +121,15 @@
void runOnOperation() override {
auto func = getOperation();
- llvm::StringLiteral scheduleAttrName =
- IREE::GPU::MMAScheduleAttr::getMnemonic();
- auto scheduleAttr =
- func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
- if (!scheduleAttr) {
- DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
- if (configDict) {
- scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
- configDict.get(scheduleAttrName));
+ // Set the MMA intrinsic from the mma_kind attribute on the contraction's to_layout consumer.
+ func.walk([&](vector::ContractionOp contract) {
+ inferMmaKind(contract);
+ if (!contract->hasAttr("iree.amdgpu.mma")) {
+ func.emitOpError("Failed to detect valid to_layout consumer of "
+ "vector.contract to infer MMA kind.");
+ return signalPassFailure();
}
- }
-
- // Import mma type from dispatch schedule attribute if present.
- if (scheduleAttr) {
- func.walk([&](vector::ContractionOp contract) {
- if (!contract->hasAttr("iree.amdgpu.mma")) {
- contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
- }
- });
- }
+ });
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
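
With the function-level `mma_schedule` fallback removed, a contraction whose forward slice contains no `to_layout` carrying `mma_kind` now fails this pass instead of silently inheriting the schedule's intrinsic. A minimal sketch (not part of this PR's tests; shapes are illustrative) of IR that would trip the new diagnostic:

```mlir
func.func @missing_to_layout(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>,
                             %init: vector<96x64xf16>) -> vector<96x64xf16> {
  // No iree_vector_ext.to_layout consumer provides an mma_kind, so
  // LLVMGPUCastTypeToFitMMA emits "Failed to detect valid to_layout consumer
  // of vector.contract to infer MMA kind." and signals pass failure.
  %0 = vector.contract {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
  return %0 : vector<96x64xf16>
}
```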
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 3f84454..22b570b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -56,12 +56,12 @@
// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(contract);
- auto layoutedLhs =
- rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, lhs, aLayout);
- auto layoutedRhs =
- rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, rhs, bLayout);
- auto layoutedAcc =
- rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, acc, cLayout);
+ auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, lhs, aLayout, schedule.getIntrinsic());
+ auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, rhs, bLayout, schedule.getIntrinsic());
+ auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+ loc, acc, cLayout, schedule.getIntrinsic());
// Promote matmul lhs and rhs.
// TODO: We should read this from the lowering_config on the operation.
@@ -82,7 +82,7 @@
// Set layout for result.
rewriter.setInsertionPointAfter(contract);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, contract->getResult(0), cLayout);
+ loc, contract->getResult(0), cLayout, schedule.getIntrinsic());
rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
toLayout);
@@ -140,11 +140,11 @@
// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(conv);
auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, lhs.getType(), lhs, aLayout);
+ loc, lhs, aLayout, schedule.getIntrinsic());
auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, rhs.getType(), rhs, bLayout);
+ loc, rhs, bLayout, schedule.getIntrinsic());
auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, acc.getType(), acc, cLayout);
+ loc, acc, cLayout, schedule.getIntrinsic());
// Promote matmul lhs and rhs.
// TODO: We should read this from the lowering_config on the operation.
@@ -160,7 +160,7 @@
// Set layout for result.
rewriter.setInsertionPointAfter(conv);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
- loc, conv->getResult(0).getType(), conv->getResult(0), cLayout);
+ loc, conv->getResult(0), cLayout, schedule.getIntrinsic());
rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
toLayout);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index 4f97eb6..7da3e14 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -9,7 +9,11 @@
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
- return %0 : vector<96x64xf16>
+ %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+ outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+ subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+ {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf16>
+ return %1 : vector<96x64xf16>
}
// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm
@@ -21,7 +25,6 @@
// CHECK-SAME: %[[A]], %[[B]], %[[EXT]]
// CHECK-SAME: vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
-// CHECK: return %[[TRUNC]] : vector<96x64xf16>
// -----
@@ -34,7 +37,11 @@
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf16>
- return %0 : vector<96x64xf16>
+ %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+ outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+ subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+ {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf16>
+ return %1 : vector<96x64xf16>
}
// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmt
@@ -55,7 +62,11 @@
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
- return %0 : vector<96x64xf64>
+ %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+ outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+ subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+ {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf64>
+ return %1 : vector<96x64xf64>
}
// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm_cannot_downcast
@@ -75,7 +86,11 @@
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf16>
- return %0 : vector<48x32xf16>
+ %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+ outer_tile = [8, 1], thread_tile = [2, 16], element_tile = [1, 1],
+ subgroup_strides = [0, 0], thread_strides = [16, 1]>)
+ {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>} : vector<48x32xf16>
+ return %1 : vector<48x32xf16>
}
// CHECK-LABEL: func.func @wmma_matmul_48x32x32_mm
@@ -87,7 +102,36 @@
// CHECK-SAME: %[[A]], %[[B]], %[[EXT]]
// CHECK-SAME: vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
-// CHECK: return %[[TRUNC]] : vector<48x32xf16>
+
+// -----
+
+// This tests that cast_type_to_fit_mma works on contractions whose intrinsic is set by to_layout.
+// "iree.amdgpu.mma" is generated from the "mma_kind" attribute on to_layout.
+// This also shows that an explicitly set mma_kind overrides the default intrinsic from the function's mma_schedule.
+
+func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+ mma_schedule = #iree_gpu.mma_schedule<
+ intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
+ subgroup_m_count = 1, subgroup_n_count = 1>,
+ workgroup_size = [64, 1, 1]} {
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+ %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [6, 4],
+ outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 4],
+ subgroup_strides = [0, 0], thread_strides = [1, 16]>)
+ {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} : vector<96x64xf16>
+ return %1 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @to_layout_config_matmul_96x64x16_mm
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: arith.extf
+// CHECK: vector.contract
+// CHECK-SAME: {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK-SAME: : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+// CHECK: arith.truncf
// -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index b9ecc4a..8b761f8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -42,9 +42,9 @@
// CHECK-LABEL: func.func @matmul_96x64x16_mfma
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
@@ -93,9 +93,9 @@
// CHECK-LABEL: func.func @matmul_96x64x16_wmma
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
@@ -144,9 +144,9 @@
// CHECK-LABEL: func.func @matmul_128x64x16_multi_subgroup
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
@@ -195,9 +195,9 @@
// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 2, 1], batch_tile = [4, 1, 4, 1], outer_tile = [1, 1, 1, 1], thread_tile = [1, 4, 1, 16], element_tile = [1, 4, 1, 1], subgroup_strides = [2, 0, 1, 0], thread_strides = [0, 16, 0, 1]>
// CHECK-LABEL: func.func @packed_matmul_128x128x128
-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]