[LLVMGPU] Add VMFMA for FP8 to align layouts between chained F8 contractions. (#19020)
This PR introduces a virtual intrinsic on F8 MFMA that breaks apart a
single 8xF8 read into two interleaved 4xF8 reads from shared memory.
The main motivation for this virtual intrinsic is to enable faster F8
attention/chained matmuls. By doing interleaved reads on the K-dimension,
we can match the native F8 intrinsic output layout coming from the 1st
matmul to the RHS read layout of the 2nd matmul (which uses the
interleaved virtual MFMA layout).
Once the layouts are aligned, we just need to handle them in the
to_layout lowering, which does a reshape on the SIMT vector.
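For intuition, here is a minimal sketch (hypothetical helper names, not code from this PR) of which K offsets a single lane reads for the Lhs fragment under the native intrinsic versus the virtual one, based on the VMFMA_F32_16x16x32_F8E4M3FNUZ layout added below (outer = {1, 2}, thread = {16, 4}, element = {1, 4}) and assuming a dimension index flattens outer-major, then thread, then element:

```python
# Hypothetical sketch: per-lane K offsets along the K = 32 dimension of the Lhs fragment,
# assuming a dimension flattens as (outer * num_threads + thread) * num_elements + element.
K_THREADS = 4  # threads along K (thread = {16, 4})


def native_k_offsets(k_thread: int) -> list[int]:
    # MFMA_F32_16x16x32_F8E4M3FNUZ: one contiguous 8xF8 read per lane.
    return [k_thread * 8 + e for e in range(8)]


def vmfma_k_offsets(k_thread: int) -> list[int]:
    # VMFMA_F32_16x16x32_F8E4M3FNUZ: outer = 2, element = 4 -> two interleaved 4xF8 reads.
    return [(o * K_THREADS + k_thread) * 4 + e for o in range(2) for e in range(4)]


print(native_k_offsets(0))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(vmfma_k_offsets(0))   # [0, 1, 2, 3, 16, 17, 18, 19]
```

With the interleaved variant, aligning the 2nd matmul's RHS read with the 1st matmul's output layout then only needs the SIMT-vector reshape in the to_layout lowering mentioned above.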
This PR has been tested on attention of shape:
[B: 1, M: 4096, K1: 64, K2: 4096, N: 64]
as seen in this IR:
[(link)](https://gist.githubusercontent.com/raikonenfnu/4d33b5addfa9c4ec9e76918704251e39/raw/5b20c0c359e3e2df7f8db4890d3cc0590352d18a/attention_f8_perf.mlir)
and using this spec to specify VMFMA on the 2nd matmul and a regular MFMA
on the 1st matmul:
([link](https://gist.githubusercontent.com/raikonenfnu/4d33b5addfa9c4ec9e76918704251e39/raw/5b20c0c359e3e2df7f8db4890d3cc0590352d18a/attn_config.mlir))
we were able to get a 1.63x speedup over the reference (same config but
using MFMA_F32_16x16x32_F16 on both contractions), with correct/same
numerics.
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 0bbb8d1..7d145f2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -236,7 +236,8 @@
case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32};
}
- case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
+ case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
@@ -420,6 +421,7 @@
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
@@ -471,6 +473,7 @@
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
@@ -496,6 +499,7 @@
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
@@ -578,6 +582,18 @@
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
+ switch (fragment) {
+ case MMAFragment::Lhs:
+ return {/*outer=*/{1, 2}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16},
+ /*element=*/{1, 4}};
+ case MMAFragment::Rhs:
+ return {/*outer=*/{2, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
+ /*element=*/{4, 1}};
+ case MMAFragment::Acc:
+ return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
+ /*element=*/{4, 1}};
+ }
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
switch (fragment) {
@@ -711,6 +727,7 @@
case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+ case MMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto [m, n, k] = getMNKShape();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
index 1afdf0d..49f210e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -133,6 +133,9 @@
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>;
+// The V-Intrinsic below interleaves the K-dim read, splitting one 8xF8 read into two 4xF8 reads.
+// (Useful in F8 chained-MM to align the B-layout of the 2nd MM with the C-layout of the 1st MM.)
+def VMFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"VMFMA_F32_16x16x32_F8E4M3FNUZ", 0x0941>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>;
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x0981>;
@@ -159,6 +162,7 @@
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_F32_16x16x32_F8E5M2FNUZ,
+ VMFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,
MFMA_I32_16x16x16_I8,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 1722b89..57ebf56 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -718,6 +718,76 @@
// -----
+// This test ensures we can generate correct instructions from V(Virtual) MFMAs that only differ in their read layouts.
+
+#config = #iree_gpu.lowering_config<{workgroup = [32, 32, 0], reduction = [0, 0, 128], promote_operands = [0, 1], mma_kind = #iree_gpu.mma_layout<VMFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2}>
+#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32 layout(#pipeline_layout) {
+ ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32() attributes {translation_info = #translation} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
+ %5 = tensor.empty() : tensor<256x256xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+ return
+ }
+ }
+}
+}
+
+// CHECK-LABEL: func @virtual_intrinsic_256x256x256_f8E4M3FNUZ_f32
+// CHECK-DAG: %[[ALLOC_LHS:.+]] = memref.alloc() : memref<32x136xf8E4M3FNUZ, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[ALLOC_RHS:.+]] = memref.alloc() : memref<128x40xf8E4M3FNUZ, #gpu.address_space<workgroup>>
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x1x1x1x4x1xf32>)
+
+// Validate that VMFMA does 2 interleaved reads and combines them for every MFMA:
+
+// CHECK-COUNT-6: vector.transfer_read %[[ALLOC_LHS]]
+// CHECK: %[[SLICE_LHS_0:.+]] = vector.transfer_read %[[ALLOC_LHS]]
+// CHECK: %[[VECTOR_LHS_0:.+]] = vector.insert_strided_slice %[[SLICE_LHS_0]], %{{.*}}
+// CHECK: %[[SLICE_LHS_1:.+]] = vector.transfer_read %[[ALLOC_LHS]]
+// CHECK: %[[VECTOR_LHS_1:.+]] = vector.insert_strided_slice %[[SLICE_LHS_1]], %[[VECTOR_LHS_0]] {{.*}} : vector<1x4xf8E4M3FNUZ> into vector<1x4x1x2x1x4xf8E4M3FNUZ>
+
+// CHECK-COUNT-6: vector.transfer_read %[[ALLOC_RHS]]
+// CHECK: %[[SLICE_RHS_0:.+]] = vector.transfer_read %[[ALLOC_RHS]]
+// CHECK: %[[VECTOR_RHS_0:.+]] = vector.insert_strided_slice %[[SLICE_RHS_0]], %{{.*}}
+// CHECK: %[[SLICE_RHS_1:.+]] = vector.transfer_read %[[ALLOC_RHS]]
+// CHECK: %[[VECTOR_RHS_1:.+]] = vector.insert_strided_slice %[[SLICE_RHS_1]], %[[VECTOR_RHS_0]] {{.*}} : vector<4x1xf8E4M3FNUZ> into vector<4x1x2x1x4x1xf8E4M3FNUZ>
+
+// CHECK: %[[EXTRACT_LHS:.+]] = vector.extract %[[VECTOR_LHS_1]][{{.*}}, {{.*}}] : vector<1x2x1x4xf8E4M3FNUZ> from vector<1x4x1x2x1x4xf8E4M3FNUZ>
+// CHECK: %[[EXTRACT_RHS:.+]] = vector.extract %[[VECTOR_RHS_1]][{{.*}}, {{.*}}] : vector<2x1x4x1xf8E4M3FNUZ> from vector<4x1x2x1x4x1xf8E4M3FNUZ>
+
+// CHECK: %[[LHS_CAST:.+]] = vector.shape_cast %[[EXTRACT_LHS]] : vector<1x2x1x4xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
+// CHECK: %[[RHS_CAST:.+]] = vector.shape_cast %[[EXTRACT_RHS]] : vector<2x1x4x1xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
+// CHECK: amdgpu.mfma %[[LHS_CAST]] * %[[RHS_CAST]] + %{{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}
+// CHECK-SAME: : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+
+// Ensure right number of instructions are being generated.
+// CHECK-COUNT-3: amdgpu.mfma
+
+// CHECK: scf.yield
+
+// -----
+
#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
#translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 8761f1f..9ced22c 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -350,6 +350,8 @@
MMASchedule("VMFMA_F32_16x16x32_F16", 4, 2, 1, 2, 4),
MMASchedule("VMFMA_F32_32x32x16_F16", 1, 1, 1, 1, 1),
MMASchedule("VMFMA_F32_32x32x16_F16", 4, 2, 1, 2, 4),
+ MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 1, 1, 1, 1, 1),
+ MMASchedule("VMFMA_F32_16x16x32_F8E4M3FNUZ", 4, 1, 4, 1, 1),
]
elif intrinsic == "WMMA":
schedules = [
@@ -399,6 +401,7 @@
schedule.intrinsic == "VMFMA_F32_16x16x32_F16"
or schedule.intrinsic == "MFMA_I32_16x16x32_I8"
or schedule.intrinsic == "MFMA_F32_16x16x32_F8E4M3FNUZ"
+ or schedule.intrinsic == "VMFMA_F32_16x16x32_F8E4M3FNUZ"
):
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16