[LLVMGPU] Add Virtual MFMA layout that maximizes load through adjusted K-width (#18930)
The main use case for the virtual intrinsics is to change the layout of an
intrinsic along the K dimension, so that we can coalesce reads from shared
memory into registers.
Currently, the "native" intrinsics enforce the "native" layout (i.e. read 4
elements per thread for MFMA_F32_16x16x16). However, since the K dimension is
a reduction dimension and the reduction is associative, we can read the data
in a non-"native" but wider (more elements per read) way and still get
correct results, as long as the K elements are matched between lhs and rhs
(i.e. read 8 contiguous elements per thread from shared memory along the K
dimension, then slice them into two MFMA_F32_16x16x16 operations).
For example, if we want to do a 16x16x32 (MxNxK) matmul with
MFMA_F32_16x16x16_F16 intrinsics, lane 0 used to look something like:
```
lhs_0 = read(lhs_shared_mem[0:4])
rhs_0 = read(rhs_shared_mem[0:4])
mma_0 = vector.contract(lhs_0, rhs_0)
// offset of 16 since MFMA_F32_16x16x16_F16 has an intrinsic K size of 16
lhs_1 = read(lhs_shared_mem[16 + 0: 16 + 4])
rhs_1 = read(rhs_shared_mem[16 + 0 : 16 + 4])
mma_1 = vector.contract(lhs_1, rhs_1, mma_0)
```
With this optimization, this turns into something like:
```
lhs_reg = read(lhs_shared_mem[0:8])
rhs_reg = read(rhs_shared_mem[0:8])
lhs_0 = slice(lhs_reg, [0 : 4])
rhs_0 = slice(rhs_reg, [0 : 4])
mma_0 = vector.contract(lhs_0, rhs_0)
lhs_1 = slice(lhs_reg, [4 : 8])
rhs_1 = slice(rhs_reg, [4 : 8])
mma_1 = vector.contract(lhs_1, rhs_1, mma_0)
```
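
To spell out why this is correct: the K reduction is a plain sum, so splitting the 32 K-elements into two native K = 16 MFMAs accumulates to the same value for every output element, provided lhs and rhs agree on which K elements feed each sub-MFMA (this just restates the associativity argument above):
```
C_{ij} = \sum_{k=0}^{31} A_{ik} B_{kj}
       = \sum_{k=0}^{15} A_{ik} B_{kj} + \sum_{k=16}^{31} A_{ik} B_{kj}
```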
Currently, we plumb this in as MMA intrinsic enums for the two F16 variants
with unrolled k == 2 (per discussion with @qedawkins and @Groverkss), as that
is the simplest and least entangled way to integrate it. In the future we can
expose this as a k-width attribute to maximize generality.
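
For reference, the virtual variant is selected exactly like any other intrinsic through the `mma_schedule` attribute; a minimal sketch, mirroring the lowering test added in this PR:
```
mma_schedule = #iree_gpu.mma_schedule<
    intrinsic = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>,
    subgroup_m_count = 2, subgroup_n_count = 2>
```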
---------
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
index 489b8a0..db39c0b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir
@@ -590,3 +590,96 @@
// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
// CHECK: return {{.*}} %[[R_SIMD]]
+
+// -----
+
+// Non-native MFMA_F32_32x32x16_F16, i.e. CDNA3 V_MFMA_F32_32x32x8_F16 with unrolled_k = 2.
+// This non-native layout maximizes reads from shared memory to register.
+
+#map1 = affine_map<(m, n, k) -> (m, k)>
+#map2 = affine_map<(m, n, k) -> (k, n)>
+#map3 = affine_map<(m, n, k) -> (m, n)>
+
+// A: shape = 32x16, layout = layoutA
+#layout_a = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [32, 2],
+ element_tile = [1, 8],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [1, 32]
+>
+
+// B: shape = 16x32, layout = layoutB
+#layout_b = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [2, 32],
+ element_tile = [8, 1],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [32, 1]
+>
+
+// C: shape = 32x32, layout = layoutC
+#layout_c = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [4, 1],
+ thread_tile = [2, 32],
+ element_tile = [4, 1],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [32, 1]
+>
+
+func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
+ %A = iree_vector_ext.to_layout %a to layout(#layout_a) : vector<32x16xf16>
+ %B = iree_vector_ext.to_layout %b to layout(#layout_b) : vector<16x32xf16>
+ %C = iree_vector_ext.to_layout %c to layout(#layout_c) : vector<32x32xf32>
+
+ %output = vector.contract {
+ indexing_maps = [#map1, #map2, #map3],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>,
+ iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
+ } %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>
+
+ %O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>
+ return %O : vector<32x32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// Notable things to look out for:
+// 1. We read 8xf16 instead of 4xf16 for the lhs and rhs operands.
+// 2. We slice each 8xf16 into two 4xf16 slices per operand, one for each of the 2 MMAs.
+// 3. The result of the first mma becomes the second mma's accumulator.
+
+// CHECK-LABEL: func @contract_to_vmfma_32x32x16_mm
+// CHECK: %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK: %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_1]] : vector<16xf32> to vector<4x1x4x1xf32>
+// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
+// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x4x1x4x1xf32> -> vector<32x32xf32>
+// CHECK: return {{.*}} %[[R_SIMD]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 93a2ca7..e53f915 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -263,6 +263,14 @@
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
}
+ // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
+ // along the k dimension.
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F16: {
+ return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
+ }
+ case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
+ return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
+ }
}
llvm_unreachable("unhandled mfma layout type");
return OpaqueMmaLayout{};
@@ -412,12 +420,14 @@
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
+ case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
@@ -461,7 +471,9 @@
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+ case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
@@ -484,7 +496,9 @@
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+ case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return 64;
}
@@ -549,6 +563,7 @@
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
/*element=*/{4, 1}};
}
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
@@ -563,6 +578,7 @@
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
+ case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
switch (fragment) {
case MMAFragment::Lhs:
@@ -616,6 +632,19 @@
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
}
+// Get virtual intrinsics that are composed from/based on the queried op.
+SmallVector<MMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
+ switch (getIntrinsic().getValue()) {
+ case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+ return {MMAIntrinsic::VMFMA_F32_16x16x32_F16};
+ case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+ return {MMAIntrinsic::VMFMA_F32_32x32x16_F16};
+ default:
+ return {};
+ }
+ return {};
+}
+
// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
// type.
FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -643,6 +672,37 @@
rhs, acc)
.getResult();
}
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
+ case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
+    // Generate mfmas by unrolling along the K dimension.
+ const int64_t unrollKFactor = 2;
+ auto [m, n, k] = getMNKShape();
+ // Compute actual/native intrinsic's K size.
+ int64_t nativeKSize = k / unrollKFactor;
+
+ auto [aType, bType, cType] = getABCVectorTypes();
+ if (aType.getShape()[0] != bType.getShape()[0]) {
+      // Currently only support the case where lhs and rhs
+      // have the same vector width.
+ return failure();
+ }
+ int64_t vectorWidth = aType.getShape()[0] / unrollKFactor;
+ for (int i = 0; i < unrollKFactor; i++) {
+ int64_t offset = vectorWidth * i;
+ Value sliced_lhs = builder.create<vector::ExtractStridedSliceOp>(
+ loc, lhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
+ ArrayRef<int64_t>{1});
+ Value sliced_rhs = builder.create<vector::ExtractStridedSliceOp>(
+ loc, rhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
+ ArrayRef<int64_t>{1});
+ acc = builder
+ .create<amdgpu::MFMAOp>(loc, resultType, m, n, nativeKSize,
+ getBlockSize(), sliced_lhs, sliced_rhs,
+ acc)
+ .getResult();
+ }
+ return acc;
+ }
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index d04e9fe..bbb7962 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -216,6 +216,8 @@
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;
+
+ SmallVector<MMAIntrinsic> getVirtualIntrinsics() const;
}];
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 9d4ac2e..1afdf0d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -98,7 +98,13 @@
let genSpecializedAttr = 0;
}
-// Format: <kind>_<output-type>_<M>x<N>x<K>_<input-type>
+// Format: <virtual><kind>_<output-type>_<M>x<N>x<K>_<input-type>
+//
+// "virtual": Prefixes intrinsic with "V" to represent Non native-MFMA
+// emulating a larger MMA with smaller ones. This is useful
+// to interleave reads in K-dim, S.T we can have wider reads
+// or align layouts between matmuls.
+//
// Values: 0xABCD where:
// * A = vendor:
// * 0 = AMD
@@ -121,6 +127,8 @@
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
+def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0x0912>;
+def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 0x0913>;
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
@@ -145,6 +153,8 @@
MFMA_F32_16x16x4_F32,
MFMA_F32_16x16x16_F16,
MFMA_F32_32x32x8_F16,
+ VMFMA_F32_16x16x32_F16,
+ VMFMA_F32_32x32x16_F16,
MFMA_F32_16x16x16_BF16,
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E4M3FNUZ,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index ede2d0b..b4567e3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -309,15 +309,32 @@
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};
- SmallVector<GPUMatmulShapeType> intrinsics;
- intrinsics.reserve(target.getWgp().getMma().size());
- for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+ // Helper fn to store mma information.
+ auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+ SmallVector<GPUMatmulShapeType> &intrinsics,
+ SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
+ intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+ mmaAttrs.emplace_back(mma);
+ };
+
+ SmallVector<GPUMatmulShapeType> intrinsics;
+ intrinsics.reserve(target.getWgp().getMma().size());
+ SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+ MLIRContext *context = op.getContext();
+ for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
if (mma.getSubgroupSize() != targetSubgroupSize)
continue;
- intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+ storeMmaInfo(mma, intrinsics, mmaAttrs);
+ // Store info on virtual intrinsics based on current mma if any
+ for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+ mma.getVirtualIntrinsics()) {
+ auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+ storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+ }
}
+
if (intrinsics.empty())
return failure();
@@ -379,7 +396,6 @@
reductionTileSizes[filterDim] = 1;
}
- MLIRContext *context = op.getContext();
Builder b(context);
SmallVector<NamedAttribute, 2> attrs;
attrs.emplace_back(StringAttr::get(context, "workgroup"),
@@ -395,8 +411,8 @@
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, target.getWgp().getMma()[schedule->index],
- schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+ context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+ schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);
@@ -506,15 +522,32 @@
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};
- SmallVector<GPUMatmulShapeType> intrinsics;
- intrinsics.reserve(target.getWgp().getMma().size());
- for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+ // Helper fn to store mma information.
+ auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+ SmallVector<GPUMatmulShapeType> &intrinsics,
+ SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
+ intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+ mmaAttrs.emplace_back(mma);
+ };
+
+ SmallVector<GPUMatmulShapeType> intrinsics;
+ intrinsics.reserve(target.getWgp().getMma().size());
+ SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+ MLIRContext *context = op.getContext();
+ for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
if (mma.getSubgroupSize() != targetSubgroupSize)
continue;
- intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+ storeMmaInfo(mma, intrinsics, mmaAttrs);
+ // Store info on virtual intrinsics based on current mma if any
+ for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+ mma.getVirtualIntrinsics()) {
+ auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+ storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+ }
}
+
if (intrinsics.empty())
return failure();
@@ -627,7 +660,6 @@
LLVM_DEBUG(debugPrintContractionInfo("Reduction tile sizes", op.getNumLoops(),
*contractionDims, reductionTileSizes));
- MLIRContext *context = op.getContext();
Builder b(context);
SmallVector<NamedAttribute, 2> attrs;
attrs.emplace_back(StringAttr::get(context, "workgroup"),
@@ -643,8 +675,8 @@
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, target.getWgp().getMma()[schedule->index],
- schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+ context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+ schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);
@@ -709,15 +741,32 @@
Value kMatrix = op.getKey();
Value vMatrix = op.getValue();
- SmallVector<GPUMatmulShapeType> intrinsics;
- intrinsics.reserve(target.getWgp().getMma().size());
- for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+ // Helper fn to store mma information.
+ auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+ SmallVector<GPUMatmulShapeType> &intrinsics,
+ SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
+ intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+ mmaAttrs.emplace_back(mma);
+ };
+
+ SmallVector<GPUMatmulShapeType> intrinsics;
+ intrinsics.reserve(target.getWgp().getMma().size());
+ SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
+ MLIRContext *context = op.getContext();
+ for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
if (mma.getSubgroupSize() != targetSubgroupSize)
continue;
- intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+ storeMmaInfo(mma, intrinsics, mmaAttrs);
+ // Store info on virtual intrinsics based on current mma if any
+ for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
+ mma.getVirtualIntrinsics()) {
+ auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
+ storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
+ }
}
+
if (intrinsics.empty())
return failure();
@@ -826,7 +875,6 @@
reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize;
- MLIRContext *context = op.getContext();
SmallVector<NamedAttribute, 2> attrs;
attrs.emplace_back(StringAttr::get(context, "workgroup"),
b.getI64ArrayAttr(workgroupTileSizes));
@@ -878,8 +926,8 @@
// for later access in the pipeline.
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
- context, target.getWgp().getMma()[schedule->index],
- schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]);
+ context, mmaAttrs[schedule->index], schedule->mSubgroupCounts[0],
+ schedule->nSubgroupCounts[0]);
pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
scheduleAttr);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 7e1ab62..cedec2d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -655,6 +655,69 @@
// -----
+// This test ensures that we can generate and decompose the right instructions from V(Virtual) MFMAs.
+
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128], promote_operands = [0, 1]}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @virtual_intrinsic_256x256x256_f16_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @virtual_intrinsic_256x256x256_f16_f32 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @virtual_intrinsic_256x256x256_f16_f32() attributes {translation_info = #translation} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+ %5 = tensor.empty() : tensor<256x256xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ return
+ }
+ }
+}
+}
+
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f16_f32
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x4x1x4x1xf32>)
+
+// Validate that VMFMA is decomposed into coalesced read and 2 MFMAs:
+
+// CHECK: %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
+// CHECK: %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
+// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
+// CHECK: %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+// CHECK: %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
+// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
+// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
+
+// Ensure right number of instructions are being generated.
+
+// CHECK-COUNT-14: vector.extract_strided_slice
+// CHECK-NEXT: amdgpu.mfma
+// CHECK: scf.yield
+
+// -----
+
#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index cd6f8eb..dd387f3 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -347,6 +347,10 @@
MMASchedule("MFMA_I32_32x32x16_I8", 2, 2, 1, 1, 2),
MMASchedule("MFMA_I32_32x32x16_I8", 4, 1, 1, 2, 2),
MMASchedule("MFMA_I32_32x32x16_I8", 4, 2, 2, 2, 2),
+ MMASchedule("VMFMA_F32_16x16x32_F16", 1, 1, 1, 1, 1),
+ MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
+ MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
+ MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
]
elif intrinsic == "WMMA":
schedules = [
@@ -393,13 +397,17 @@
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 8
elif (
- schedule.intrinsic == "MFMA_I32_16x16x32_I8"
+ schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
+ or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 32
- elif schedule.intrinsic == "MFMA_I32_32x32x16_I8":
+ elif (
+ schedule.intrinsic == "VMFMA_F32_32x32x16_F16"
+ or schedule.intrinsic == "MFMA_I32_32x32x16_I8"
+ ):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 32
wg_tile_n = schedule.n_count * schedule.n_tile_count * 32
wg_tile_k = schedule.k_tile_count * 16