[Codegen] Add XOR swizzle for BF16 matmul with DMA (#23932)
## Results: BF16 square matmul (`transposed_rhs`)
| Shape | Intrinsic | tile_k | DMA (no swizzle) Bank Conflicts | DMA (`xor_shuffle<128,8>`) Bank Conflicts | DMA (`xor_shuffle<64,8>`) Bank Conflicts | DMA (no swizzle) Time | DMA (`xor_shuffle<128,8>`) Time | DMA (`xor_shuffle<64,8>`) Time |
| ---: | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 512 | 16x16x32 | 4 | 7.00 | **0.00** | 1.00 | 0.063ms | 0.058ms | **0.055ms** |
| 1024 | 16x16x32 | 4 | 7.00 | **0.00** | 1.00 | 0.068ms | 0.062ms | **0.057ms** |
| 2048 | 16x16x32 | 1 | 1.00 | 1.00 | **0.00** | 0.096ms | 0.092ms | **0.086ms** |
| 4096 | 16x16x32 | 1 | 1.00 | 1.00 | **0.00** | 0.222ms | 0.214ms | **0.210ms** |
| 8192 | 16x16x32 | 1 | 1.00 | 1.00 | **0.00** | 1.61ms | 1.61ms | **1.47ms** |
| 16384 | 32x32x16 | 2 | 3.00 | **0.00** | 1.00 | 10.0ms | **9.35ms** | 9.42ms |
## Results: Sweep on 320 product shapes
Geometric mean speedup vs the no-DMA baseline:
| Config | Geomean Speedup vs Baseline (positive is better) |
|--------|--------------------------------------------------|
| DMA (`xor_shuffle<64,8>`) | **+7.4%** |
| DMA (`xor_shuffle<128,8>`) | +1.7% |
| Oracle (best xor per shape) | +8.4% |
| Baseline (no DMA) | +0.0% |
| DMA (no swizzle) | -5.3% |
The oracle picks the best config per shape and shows only +1.0% additional
headroom over `<64,8>`, but that per-shape choice is difficult to capture in
a simple compile-time heuristic. We therefore default to `<64,8>`, which
gives the best single-config geomean.
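To see why the XOR swizzle helps, here is a toy Python model of LDS bank pressure. It is an illustrative sketch, not the hardware's exact arbitration: it assumes 32 banks of 4 bytes (typical for GCN/CDNA LDS), bf16 elements, and models `xor_shuffle<64, 8>` as XORing each 8-element chunk index with `row % (rowElems / accessElems)` before computing the byte offset. Lanes reading one 8-element chunk each from 32 consecutive rows (a column-like access after promotion) all land on the same few banks without the swizzle; with it, the chunks spread across all banks.

```python
# Toy model of LDS bank pressure with and without an XOR swizzle.
# Assumptions (illustrative only): 32 banks x 4 bytes, bf16 elements,
# and xor_shuffle<rowElems=64, accessElems=8> modeled as
# chunk ^= row % (rowElems // accessElems).

NUM_BANKS = 32
BANK_BYTES = 4
ELEM_BYTES = 2          # bf16
ROW_ELEMS = 64          # xor_shuffle rowElems
ACCESS_ELEMS = 8        # xor_shuffle accessElems (one 16-byte vector)

def chunk_offset(row: int, chunk: int, swizzle: bool) -> int:
    """Byte offset of an 8-element chunk in LDS, with/without swizzle."""
    if swizzle:
        chunk ^= row % (ROW_ELEMS // ACCESS_ELEMS)
    return (row * ROW_ELEMS + chunk * ACCESS_ELEMS) * ELEM_BYTES

def max_bank_pressure(swizzle: bool, chunk: int = 0) -> int:
    """Worst-case hits on a single bank when 32 lanes each read one
    8-element chunk from 32 consecutive rows (column-like access)."""
    hits = [0] * NUM_BANKS
    for lane in range(32):
        base = chunk_offset(row=lane, chunk=chunk, swizzle=swizzle)
        for b in range(0, ACCESS_ELEMS * ELEM_BYTES, BANK_BYTES):
            hits[(base + b) // BANK_BYTES % NUM_BANKS] += 1
    return max(hits)

print("no swizzle :", max_bank_pressure(False))   # all lanes hit banks 0-3
print("xor swizzle:", max_bank_pressure(True))    # spread over all 32 banks
```

In this model the unswizzled layout puts every lane's 16-byte access on the same four banks (32 hits per bank), while the XOR swizzle spreads the chunks so no bank is hit more than 4 times.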
Fixes: #23901
Assisted-by: Cursor (Claude)
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 665bd58..ab6673d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -942,8 +942,25 @@
// translation_info).
SmallVector<Attribute> promotionArray;
if (useDirectLoad) {
- Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
- promotionArray = {useGlobalDma, useGlobalDma};
+ Attribute lhsAttr = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
+ Attribute rhsAttr = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
+ // Apply XOR swizzle for BF16 DMA operands whose reduction dim is
+ // innermost (contiguous reads) to avoid LDS bank conflicts.
+ if (lhsElemType.isBF16() && !transposedLhs) {
+ FailureOr<Attribute> lhsSwizzleAttr = getXorShuffleAttr(
+ context, lhsAttr, target, kind, schedule->kTileSizes, kMMAOperandLhs);
+ if (succeeded(lhsSwizzleAttr)) {
+ lhsAttr = *lhsSwizzleAttr;
+ }
+ }
+ if (rhsElemType.isBF16() && transposedRhs) {
+ FailureOr<Attribute> rhsSwizzleAttr = getXorShuffleAttr(
+ context, rhsAttr, target, kind, schedule->kTileSizes, kMMAOperandRhs);
+ if (succeeded(rhsSwizzleAttr)) {
+ rhsAttr = *rhsSwizzleAttr;
+ }
+ }
+ promotionArray = {lhsAttr, rhsAttr};
}
SmallVector<int64_t> promotionList = {0, 1};
if (scaled) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
index dd830dd..a2941f8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
@@ -567,3 +567,19 @@
// IGEMM-SAME: reduction = [0, 0, 0, 0, 0, 8]
// IGEMM-SAME: subgroup = [0, 1, 1, 1, 1, 0]
// IGEMM-SAME: workgroup = [4, 16, 1, 3, 16, 0]
+
+// -----
+
+// BF16 matmul with DMA. Neither LHS nor RHS is transposed, so only the LHS gets the XOR swizzle.
+func.func @matmul_bf16(
+ %arg0: tensor<4096x4096xbf16>,
+ %arg1: tensor<4096x4096xbf16>,
+ %arg2: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<4096x4096xbf16>)
+ outs(%arg2 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
+ return %0 : tensor<4096x4096xf32>
+}
+
+// CHECK-DIRECT-LOAD-LABEL: func.func @matmul_bf16
+// CHECK-DIRECT-LOAD: linalg.matmul {lowering_config = #iree_gpu.lowering_config
+// CHECK-DIRECT-LOAD-SAME: promotion_types = [#iree_gpu.swizzle_operand<copy_config = #iree_gpu.use_global_load_dma, swizzle = #iree_codegen.xor_shuffle<64, 8>>, #iree_gpu.use_global_load_dma]
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 84dd7de..1df7230 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -814,6 +814,16 @@
return failure();
}
}
+ if (auto mma = dyn_cast<IREE::GPU::MMAAttr>(intrinsic)) {
+ switch (mma.getIntrinsic()) {
+ case IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x32_BF16:
+ case IREE::GPU::MMAIntrinsic::MFMA_F32_32x32x16_BF16:
+ return XorShuffleParams({/*rowElems=*/64,
+ /*accessElems=*/8});
+ default:
+ return failure();
+ }
+ }
return failure();
}