[LLVMGPU] Add Virtual MFMA layout that maximizes load through adjusted K-width (#18930)

The main use case for the virtual intrinsics is to change the layout of
intrinsics in the K-dimension, so that we can coalesce reads from shared
memory into registers.

Currently, the "native" intrinsics enforce the "native" layout
(i.e. read 4 elements per thread for MFMA_F32_16x16x16). However, since
the K-dim is a reduction dimension and therefore associative, we can
read the data in a non-"native"/"correct" but "faster"/"more elements
per read" way. As long as we match the K-dim on both the lhs and the
rhs, we will still get correct results (i.e. read 8 contiguous elements
per thread from shared memory along dimension K, and then slice them
into two MFMA_F32_16x16x16s).

As an IR example, if we want to do a 16x16x32 (MxNxK) matmul with
MFMA_F32_16x16x16_F16 intrinsics, on lane 0 we used to have something
like:

```
lhs_0 = read(lhs_shared_mem[0:4])
rhs_0 = read(rhs_shared_mem[0:4])
mma_0 = vector.contract(lhs_0, rhs_0)

(offset of 16 since MFMA_F32_16x16x16_F16 has an intrinsic K size of 16)
lhs_1 = read(lhs_shared_mem[16 + 0: 16 + 4])
rhs_1 = read(rhs_shared_mem[16 + 0 : 16 + 4])
mma_1 = vector.contract(lhs_1, rhs_1, mma_0)
```

With this optimization, we turn it into something like:

```
lhs_reg = read(lhs_shared_mem[0:8])
rhs_reg = read(rhs_shared_mem[0:8])

lhs_0 = slice(lhs_reg, [0 : 4])
rhs_0 = slice(rhs_reg, [0 : 4])
mma_0 = vector.contract(lhs_0, rhs_0)

lhs_1 = slice(lhs_reg, [4 : 8])
rhs_1 = slice(rhs_reg, [4 : 8])
mma_1 = vector.contract(lhs_1, rhs_1, mma_0)
```
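
For reference, after vector distribution this decomposition materializes as two `amdgpu.mfma` ops fed by slices of the single wide read. The sketch below is illustrative only (SSA names like `%lhs_reg` and `%acc` are hypothetical; exact distributed shapes come from the layouts in the lit tests added in this PR):

```
// One 8xf16 read per operand on this lane, sliced into two 4xf16 halves
// that feed two native 16x16x16 MFMAs chained through the accumulator.
%lhs_0 = vector.extract_strided_slice %lhs_reg {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%rhs_0 = vector.extract_strided_slice %rhs_reg {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%mma_0 = amdgpu.mfma %lhs_0 * %rhs_0 + %acc {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
%lhs_1 = vector.extract_strided_slice %lhs_reg {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%rhs_1 = vector.extract_strided_slice %rhs_reg {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%mma_1 = amdgpu.mfma %lhs_1 * %rhs_1 + %mma_0 {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
```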

Currently, we plumb this in as MMA intrinsic enums for two variants
with unrolled k == 2 on the F16 intrinsics (per discussion with
@qedawkins and @Groverkss), as that is the easiest and least tangly way
to integrate/plumb it through, although in the future we can expose
this attribute as a k-width for maximum generality.
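
The new enums can then be selected explicitly through an `mma_schedule`, for example (abbreviated from the lit test added in this PR):

```
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
  workgroup_size = [256, 1, 1] subgroup_size = 64,
  {mma_schedule = #iree_gpu.mma_schedule<
     intrinsic = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>,
     subgroup_m_count = 2, subgroup_n_count = 2>}>
```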

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index 489b8a0..db39c0b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -590,3 +590,96 @@
 //       CHECK:   %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
 //       CHECK:   %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
 //       CHECK:   return {{.*}} %[[R_SIMD]]
+
+// -----
+
+// Non-native MFMA_F32_32x32x16_F16, i.e. CDNA3 V_MFMA_F32_32x32x8_F16 with unrolled_k = 2.
+// This non-native layout maximizes reads from shared memory to registers.
+
+#map1 = affine_map<(m, n, k) -> (m, k)>
+#map2 = affine_map<(m, n, k) -> (k, n)>
+#map3 = affine_map<(m, n, k) -> (m, n)>
+
+// A: shape = 32x16, layout = layoutA
+#layout_a = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile    = [1, 1],
+  outer_tile        = [1, 1],
+  thread_tile       = [32, 2],
+  element_tile     = [1, 8],
+
+  subgroup_strides        = [1, 1],
+  thread_strides          = [1, 32]
+>
+
+// B: shape = 16x32, layout = layoutB
+#layout_b = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile    = [1, 1],
+  outer_tile        = [1, 1],
+  thread_tile       = [2, 32],
+  element_tile     = [8, 1],
+
+  subgroup_strides        = [1, 1],
+  thread_strides          = [32, 1]
+>
+
+// C: shape = 32x32, layout = layoutC
+#layout_c = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile    = [1, 1],
+  outer_tile        = [4, 1],
+  thread_tile       = [2, 32],
+  element_tile     = [4, 1],
+
+  subgroup_strides        = [1, 1],
+  thread_strides          = [32, 1]
+>
+
+func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
+  %A = iree_vector_ext.to_layout %a to layout(#layout_a) : vector<32x16xf16>
+  %B = iree_vector_ext.to_layout %b to layout(#layout_b) : vector<16x32xf16>
+  %C = iree_vector_ext.to_layout %c to layout(#layout_c) : vector<32x32xf32>
+
+  %output = vector.contract {
+    indexing_maps = [#map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>,
+    iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
+  } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>
+
+  %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>
+  return %O : vector<32x32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// Notable things to look out for:
+// 1. We read 8xf16 instead of 4xf16 for the lhs and rhs operands.
+// 2. We slice the 8xf16 into two 4xf16 slices per operand, for use in 2 MMAs.
+// 3. The result of the first mma becomes the second mma's accumulator.
+
+// CHECK-LABEL: func @contract_to_vmfma_32x32x16_mm
+// CHECK:       %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
+// CHECK:       %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
+// CHECK:       %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK:       %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
+// CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
+// CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK:       %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
+// CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
+// CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK:       %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_1]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK:       %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
+// CHECK:       %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x4x1x4x1xf32> -> vector<32x32xf32>
+// CHECK:       return {{.*}} %[[R_SIMD]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 93a2ca7..e53f915 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -263,6 +263,14 @@
   case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
     return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
   }
+  // V(Virtual)MFMA instructions, which interleave 2 MFMA instructions
+  // along the K dimension.
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16: {
+    return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
+  }
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
+    return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
+  }
   }
   llvm_unreachable("unhandled mfma layout type");
   return OpaqueMmaLayout{};
@@ -412,12 +420,14 @@
   }
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
     auto aType = VectorType::get({8}, getAType());
     auto bType = VectorType::get({8}, getBType());
     auto cType = VectorType::get({4}, getCType());
     return std::make_tuple(aType, bType, cType);
   }
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
     auto aType = VectorType::get({8}, getAType());
     auto bType = VectorType::get({8}, getBType());
@@ -461,7 +471,9 @@
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8:
   case MMAIntrinsic::WMMA_F16_16x16x16_F16:
   case MMAIntrinsic::WMMA_F32_16x16x16_F16:
@@ -484,7 +496,9 @@
   case MMAIntrinsic::MFMA_I32_32x32x8_I8:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
     return 64;
   }
@@ -549,6 +563,7 @@
       return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
               /*element=*/{4, 1}};
     }
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
   case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
   case MMAIntrinsic::MFMA_I32_16x16x32_I8:
@@ -563,6 +578,7 @@
       return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
               /*element=*/{4, 1}};
     }
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
   case MMAIntrinsic::MFMA_I32_32x32x16_I8:
     switch (fragment) {
     case MMAFragment::Lhs:
@@ -616,6 +632,19 @@
   return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
 }
 
+// Get the virtual intrinsics that are based on/composed of the queried op.
+SmallVector<MMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
+  switch (getIntrinsic().getValue()) {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+    return {MMAIntrinsic::VMFMA_F32_16x16x32_F16};
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+    return {MMAIntrinsic::VMFMA_F32_32x32x16_F16};
+  default:
+    return {};
+  }
+  return {};
+}
+
 // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
 // type.
 FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -643,6 +672,37 @@
                                 rhs, acc)
         .getResult();
   }
+  case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
+  case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
+    // Generate MFMAs by unrolling along the K dimension.
+    const int64_t unrollKFactor = 2;
+    auto [m, n, k] = getMNKShape();
+    // Compute actual/native intrinsic's K size.
+    int64_t nativeKSize = k / unrollKFactor;
+
+    auto [aType, bType, cType] = getABCVectorTypes();
+    if (aType.getShape()[0] != bType.getShape()[0]) {
+      // Currently we only support the case where lhs and rhs
+      // have the same vectorWidth.
+      return failure();
+    }
+    int64_t vectorWidth = aType.getShape()[0] / unrollKFactor;
+    for (int i = 0; i < unrollKFactor; i++) {
+      int64_t offset = vectorWidth * i;
+      Value sliced_lhs = builder.create<vector::ExtractStridedSliceOp>(
+          loc, lhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
+          ArrayRef<int64_t>{1});
+      Value sliced_rhs = builder.create<vector::ExtractStridedSliceOp>(
+          loc, rhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
+          ArrayRef<int64_t>{1});
+      acc = builder
+                .create<amdgpu::MFMAOp>(loc, resultType, m, n, nativeKSize,
+                                        getBlockSize(), sliced_lhs, sliced_rhs,
+                                        acc)
+                .getResult();
+    }
+    return acc;
+  }
   case MMAIntrinsic::MFMA_I32_16x16x16_I8:
   case MMAIntrinsic::MFMA_F32_16x16x16_F16:
   case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index d04e9fe..bbb7962 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -216,6 +216,8 @@
     MMASingleSubgroupLayout getASingleSubgroupLayout() const;
     MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
     MMASingleSubgroupLayout getCSingleSubgroupLayout() const;
+
+    SmallVector<MMAIntrinsic> getVirtualIntrinsics() const;
   }];
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 9d4ac2e..1afdf0d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -98,7 +98,13 @@
   let genSpecializedAttr = 0;
 }
 
-// Format: <kind>_<output-type>_<M>x<N>x<K>_<input-type>
+// Format: <virtual><kind>_<output-type>_<M>x<N>x<K>_<input-type>
+//
+// "virtual": Prefixes intrinsic with "V" to represent Non native-MFMA
+//             emulating a larger MMA with smaller ones. This is useful
+//             to interleave reads in K-dim, S.T we can have wider reads
+//             or align layouts between matmuls.
+//
 // Values: 0xABCD where:
 // * A = vendor:
 //   * 0 = AMD
@@ -121,6 +127,8 @@
 def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
 def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
 def MFMA_F32_32x32x8_F16  : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
+def VMFMA_F32_16x16x32_F16  : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0x0912>;
+def VMFMA_F32_32x32x16_F16  : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 0x0913>;
 def MFMA_F32_16x16x16_BF16  : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
 def MFMA_F32_32x32x8_BF16  : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
 def MFMA_F32_16x16x32_F8E5M2FNUZ  : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
@@ -145,6 +153,8 @@
       MFMA_F32_16x16x4_F32,
       MFMA_F32_16x16x16_F16,
       MFMA_F32_32x32x8_F16,
+      VMFMA_F32_16x16x32_F16,
+      VMFMA_F32_32x32x16_F16,
       MFMA_F32_16x16x16_BF16,
       MFMA_F32_32x32x8_BF16,
       MFMA_F32_16x16x32_F8E4M3FNUZ,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ede2d0b..b4567e3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -309,15 +309,32 @@
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType,  rhsElemType,  initElemType};
 
-  SmallVector<GPUMatmulShapeType> intrinsics;
-  intrinsics.reserve(target.getWgp().getMma().size());
-  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
     auto [mSize, nSize, kSize] = mma.getMNKShape();
     auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
+  SmallVector<GPUMatmulShapeType> intrinsics;
+  intrinsics.reserve(target.getWgp().getMma().size());
+  SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+  MLIRContext *context = op.getContext();
+  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Store info on virtual intrinsics based on the current mma, if any.
+    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+         mma.getVirtualIntrinsics()) {
+      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+    }
   }
+
   if (intrinsics.empty())
     return failure();
 
@@ -379,7 +396,6 @@
     reductionTileSizes[filterDim] = 1;
   }
 
-  MLIRContext *context = op.getContext();
   Builder b(context);
   SmallVector<NamedAttribute, 2> attrs;
   attrs.emplace_back(StringAttr::get(context, "workgroup"),
@@ -395,8 +411,8 @@
   // for later access in the pipeline.
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index],
-      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+      context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+      schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
 
@@ -506,15 +522,32 @@
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType,  rhsElemType,  initElemType};
 
-  SmallVector<GPUMatmulShapeType> intrinsics;
-  intrinsics.reserve(target.getWgp().getMma().size());
-  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
     auto [mSize, nSize, kSize] = mma.getMNKShape();
     auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
+  SmallVector<GPUMatmulShapeType> intrinsics;
+  intrinsics.reserve(target.getWgp().getMma().size());
+  SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+  MLIRContext *context = op.getContext();
+  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Store info on virtual intrinsics based on the current mma, if any.
+    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+         mma.getVirtualIntrinsics()) {
+      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+    }
   }
+
   if (intrinsics.empty())
     return failure();
 
@@ -627,7 +660,6 @@
   LLVM_DEBUG(debugPrintContractionInfo("Reduction tile sizes", op.getNumLoops(),
                                        *contractionDims, reductionTileSizes));
 
-  MLIRContext *context = op.getContext();
   Builder b(context);
   SmallVector<NamedAttribute, 2> attrs;
   attrs.emplace_back(StringAttr::get(context, "workgroup"),
@@ -643,8 +675,8 @@
   // for later access in the pipeline.
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index],
-      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+      context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+      schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
 
@@ -709,15 +741,32 @@
   Value kMatrix = op.getKey();
   Value vMatrix = op.getValue();
 
-  SmallVector<GPUMatmulShapeType> intrinsics;
-  intrinsics.reserve(target.getWgp().getMma().size());
-  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
     auto [mSize, nSize, kSize] = mma.getMNKShape();
     auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
+  SmallVector<GPUMatmulShapeType> intrinsics;
+  intrinsics.reserve(target.getWgp().getMma().size());
+  SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+  MLIRContext *context = op.getContext();
+  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Store info on virtual intrinsics based on the current mma, if any.
+    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+         mma.getVirtualIntrinsics()) {
+      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+    }
   }
+
   if (intrinsics.empty())
     return failure();
 
@@ -826,7 +875,6 @@
 
   reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize;
 
-  MLIRContext *context = op.getContext();
   SmallVector<NamedAttribute, 2> attrs;
   attrs.emplace_back(StringAttr::get(context, "workgroup"),
                      b.getI64ArrayAttr(workgroupTileSizes));
@@ -878,8 +926,8 @@
   // for later access in the pipeline.
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
-      context, target.getWgp().getMma()[schedule->index],
-      schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+      context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+      schedule->nSubgroupCounts[0]);
   pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
                              scheduleAttr);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 7e1ab62..cedec2d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -655,6 +655,69 @@
 
 // -----
 
+// This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs.
+
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @virtual_intrinsic_256x256x256_f16_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @virtual_intrinsic_256x256x256_f16_f32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+  builtin.module {
+    func.func @virtual_intrinsic_256x256x256_f16_f32() attributes {translation_info = #translation} {
+      %cst = arith.constant 0.000000e+00 : f32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %5 = tensor.empty() : tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f16_f32
+// CHECK:   scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x4x1x4x1xf32>)
+
+// Validate that the VMFMA is decomposed into a coalesced read and 2 MFMAs:
+
+// CHECK:       %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
+// CHECK:       %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
+// CHECK:       %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK:       %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
+// CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
+// CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK:       %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:       %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
+// CHECK-SAME:     {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none
+// CHECK-SAME:     : vector<4xf16>, vector<4xf16>, vector<16xf32>
+
+// Ensure the right number of instructions is generated.
+
+// CHECK-COUNT-14: vector.extract_strided_slice
+// CHECK-NEXT: amdgpu.mfma
+// CHECK:     scf.yield
+
+// -----
+
 #config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
 
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index cd6f8eb..dd387f3 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -347,6 +347,10 @@
             MMASchedule("MFMA_I32_32x32x16_I8", 2, 2, 1, 1, 2),
             MMASchedule("MFMA_I32_32x32x16_I8", 4, 1, 1, 2, 2),
             MMASchedule("MFMA_I32_32x32x16_I8", 4, 2, 2, 2, 2),
+            MMASchedule("VMFMA_F32_16x16x32_F16", 1, 1, 1, 1, 1),
+            MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
+            MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
+            MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
         ]
     elif intrinsic == "WMMA":
         schedules = [
@@ -393,13 +397,17 @@
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
             wg_tile_k = schedule.k_tile_count * 8
         elif (
-            schedule.intrinsic == "MFMA_I32_16x16x32_I8"
+            schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
+            or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
             or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
         ):
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
             wg_tile_k = schedule.k_tile_count * 32
-        elif schedule.intrinsic == "MFMA_I32_32x32x16_I8":
+        elif (
+            schedule.intrinsic == "VMFMA_F32_32x32x16_F16"
+            or schedule.intrinsic == "MFMA_I32_32x32x16_I8"
+        ):
             wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
             wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
             wg_tile_k = schedule.k_tile_count * 16