[Codegen][DMA] Fix unaligned swizzle offset computation in gather-to-lds lowering (#24241)
The inverse XOR swizzle applied to DMA source offsets was incorrect in
two cases:
1. **Subgroup base offset**: When a subgroup's transfer size is not a
multiple of the swizzle period, different subgroups that share the same
local offsets but occupy different rows of the full allocation get
identical swizzled addresses.
Fix: incorporate the subgroup's base offset within the full allocation
before swizzling.
2. **Access-width alignment**: When `elementsPerLane < accessWidth`, the
integer division inside `swizzleOffset` collapses offsets that differ only
within an access-width group to the same swizzled value, losing the
sub-access-width remainder.
Fix: strip the sub-access-width remainder before swizzling and restore it
after. This fix is applied directly in `swizzleOffset` for both the XOR and
rotate_rows swizzles; while rotate_rows isn't currently used with DMA, the
access-width alignment issue affects both swizzle types.
Both issues caused numerical mismatches for BF16 batch matmuls using DMA
with XOR swizzle enabled.
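To make the arithmetic concrete, here is a minimal Python model of the
`xor_shuffle<64, 8>` offset permutation (illustrative only; the names and
helpers below are not the compiler's API):

```python
# Model of xor_shuffle<row_width=64, access_width=8>.
ROW_WIDTH, ACCESS_WIDTH = 64, 8
PERIOD = ROW_WIDTH * (ROW_WIDTH // ACCESS_WIDTH)  # 512 elements

def swizzle(offset: int) -> int:
    # Fix 2: strip the sub-access-width remainder so the divisions below
    # operate at access_width granularity, and restore it at the end.
    rem = offset % ACCESS_WIDTH
    aligned = offset - rem
    row = (aligned // ROW_WIDTH) % (ROW_WIDTH // ACCESS_WIDTH)
    col = (aligned % ROW_WIDTH) // ACCESS_WIDTH
    new_col = row ^ col
    return aligned - col * ACCESS_WIDTH + new_col * ACCESS_WIDTH + rem

# XOR swizzle is self-inverse, which is what lets the lowering apply the
# same permutation to source addresses as the inverse transformation.
assert all(swizzle(swizzle(x)) == x for x in range(2 * PERIOD))

# Fix 1: a subgroup whose tile starts at base = 256 (< PERIOD) must
# swizzle relative to the full allocation, then shift back.
base, local = 256, 3
assert swizzle(local) == 3                 # buggy: collides with subgroup 0
assert swizzle(base + local) - base == 35  # fixed: row 4 swizzles differently
```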
Assisted-by: Cursor (Claude)
---------
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
index 64fe2f4..c46fafd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
@@ -218,16 +218,16 @@
/// Trace a memref value through view-like ops to find a SwizzleHintOp.
/// Returns the swizzle attribute if it is an XOR swizzle (which is
/// self-inverse), std::nullopt otherwise.
-static std::optional<IREE::Codegen::SwizzleAttrInterface>
+static std::optional<IREE::Codegen::XORShuffleAttr>
getDestSwizzleAttr(Value dest) {
dest = getRootSource(dest);
if (auto hintOp = dest.getDefiningOp<IREE::Codegen::SwizzleHintOp>()) {
- auto swizzle = hintOp.getSwizzle();
// Only XOR swizzle is self-inverse (swizzle(swizzle(x)) = x), so it can
// be applied to source addresses as the inverse transformation. Other
// swizzle types (e.g. rotate_rows) are not self-inverse.
- if (isa<IREE::Codegen::XORShuffleAttr>(swizzle)) {
- return swizzle;
+ if (auto xorSwizzle =
+ dyn_cast<IREE::Codegen::XORShuffleAttr>(hintOp.getSwizzle())) {
+ return xorSwizzle;
}
}
return std::nullopt;
@@ -342,7 +342,7 @@
Value source = dmaOp.getSource();
Value dest = dmaOp.getInit();
- std::optional<IREE::Codegen::SwizzleAttrInterface> destSwizzle =
+ std::optional<IREE::Codegen::XORShuffleAttr> destSwizzle =
getDestSwizzleAttr(dest);
auto sourceType = cast<MemRefType>(source.getType());
@@ -493,7 +493,7 @@
ArrayRef<int64_t> destShape, int64_t numLinearDims, Type elementType,
OperandRange indices, ArrayRef<TransferSegment> segments,
ArrayRef<Value> segmentLaneOffsets, std::optional<ArrayAttr> inBoundsAttr,
- std::optional<IREE::Codegen::SwizzleAttrInterface> destSwizzle) const {
+ std::optional<IREE::Codegen::XORShuffleAttr> destSwizzle) const {
int64_t destRank = destShape.size();
int64_t numOuterDims = destRank - numLinearDims;
LDBG() << "Emitting transfers: " << numOuterDims << " outer dims, "
@@ -543,10 +543,8 @@
// Apply inverse source swizzle when destination has XOR swizzle.
// XOR swizzle is self-inverse, so swizzle(swizzle(x)) = x.
if (destSwizzle) {
- srcLinearOffset = getValueOrCreateConstantIndexOp(
- rewriter, loc,
- destSwizzle->swizzleOffset(rewriter, loc, srcLinearOffset,
- dest));
+ srcLinearOffset = applyInverseXorSwizzleToDMASourceOffset(
+ rewriter, loc, srcLinearOffset, *destSwizzle, dest);
}
auto srcDelinearize = affine::AffineDelinearizeIndexOp::create(
rewriter, loc, srcLinearOffset, basis, /*hasOuterBound=*/true);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
index 4b77f00..0a032ba 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
@@ -1658,7 +1658,7 @@
//
// Shape: 4x128 f32 dest. With 64 lanes, 128-bit DMA = 4 elements/lane.
// 256 elements/transfer, 512 total = 2 transfers.
-// XOR swizzle with row_width=128, access_width=16 permutes source offsets.
+// XOR swizzle with row_width=128, access_width=4 permutes source offsets.
#executable_target_rocm_hsaco_fb_swizzle = #hal.executable.target<"rocm",
"rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
@@ -1682,7 +1682,7 @@
hal.executable.target = #executable_target_rocm_hsaco_fb_swizzle,
translation_info = #translation_swizzle} {
%alloc = memref.alloc() : memref<512xf32, #gpu.address_space<workgroup>>
- %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>]
+ %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 4>]
: memref<512xf32, #gpu.address_space<workgroup>>
%dest = memref.expand_shape %swizzled [[0, 1]]
output_shape [4, 128]
@@ -1696,12 +1696,15 @@
// Transfer 1: linearOffset = 0, source gets XOR-swizzled offset.
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]]
- // XOR swizzle: extractCol, extractRow, xori, updateCol, diff, add.
+ // Access-width alignment: strip remainder, swizzle, apply diff to original.
+ // CHECK: %[[REM0:.+]] = arith.remui %[[SRC_LIN0]], %[[C4]]
+ // CHECK: %[[ALIGNED0:.+]] = arith.subi %[[SRC_LIN0]], %[[REM0]]
+ // XOR swizzle: extractCol, extractRow, xori, updateCol.
// CHECK: %[[COL0:.+]] = affine.apply
// CHECK: %[[ROW0:.+]] = affine.apply
// CHECK: %[[XOR0:.+]] = arith.xori %[[ROW0]], %[[COL0]]
// CHECK: %[[UCOL0:.+]] = affine.apply
- // CHECK: %[[DIFF0:.+]] = arith.subi %[[UCOL0]], %[[SRC_LIN0]]
+ // CHECK: %[[DIFF0:.+]] = arith.subi %[[UCOL0]], %[[ALIGNED0]]
// CHECK: %[[SWIZZLED0:.+]] = arith.addi %[[SRC_LIN0]], %[[DIFF0]]
// Source delinearized from swizzled offset.
// CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SWIZZLED0]] into (4, 128)
@@ -1812,3 +1815,160 @@
} {mapping = [#iree_gpu.lane_id<0>]}
return
}
+
+// -----
+
+// Test: Subgroup base offset adjustment for swizzled DMA.
+// The dest subview starts at linear offset 256 and spans 256 elements,
+// fewer than the swizzle period (64*64/8 = 512). The base offset must be
+// added before swizzling and subtracted after.
+
+#executable_target_base_offset = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx950", features = "", wgp = <
+ compute = fp32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [32]>>}>
+
+#translation_base_offset = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse> workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @lower_dma_swizzle_base_offset
+func.func @lower_dma_swizzle_base_offset(
+ %source: memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>)
+ attributes {
+ hal.executable.target = #executable_target_base_offset,
+ translation_info = #translation_base_offset} {
+ %alloc = memref.alloc() : memref<512xbf16, #gpu.address_space<workgroup>>
+ %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<64, 8>]
+ : memref<512xbf16, #gpu.address_space<workgroup>>
+ %expanded = memref.expand_shape %swizzled [[0, 1]]
+ output_shape [8, 64]
+ : memref<512xbf16, #gpu.address_space<workgroup>>
+ into memref<8x64xbf16, #gpu.address_space<workgroup>>
+ %dest = memref.subview %expanded[4, 0] [4, 64] [1, 1]
+ : memref<8x64xbf16, #gpu.address_space<workgroup>>
+ to memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>
+ // CHECK: scf.forall (%[[LANE:.+]]) in (64)
+ scf.forall (%lane) in (64) {
+ // Base offset 256 added before swizzle, subtracted after.
+ // CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index
+ // CHECK: arith.addi %[[C256]],
+ // CHECK: arith.xori
+ // CHECK: arith.subi {{.*}}, %[[C256]]
+ // CHECK: amdgpu.gather_to_lds
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+ : memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>,
+ memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>, index
+ } {mapping = [#iree_gpu.lane_id<0>]}
+ return
+}
+
+// -----
+
+// Test: Access-width alignment for swizzled DMA.
+// With dma_sizes=[32] and bf16, elementsPerLane = 32/16 = 2.
+// accessWidth = 8 (from xor_shuffle<64, 8>), so 2 < 8 and the sub-access-width
+// remainder must be stripped before swizzling and restored after.
+
+#executable_target_access_align = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx950", features = "", wgp = <
+ compute = fp32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [32]>>}>
+
+#translation_access_align = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse> workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @lower_dma_swizzle_access_width_align
+func.func @lower_dma_swizzle_access_width_align(
+ %source: memref<8x64xbf16, #amdgpu.address_space<fat_raw_buffer>>)
+ attributes {
+ hal.executable.target = #executable_target_access_align,
+ translation_info = #translation_access_align} {
+ %alloc = memref.alloc() : memref<512xbf16, #gpu.address_space<workgroup>>
+ %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<64, 8>]
+ : memref<512xbf16, #gpu.address_space<workgroup>>
+ %dest = memref.expand_shape %swizzled [[0, 1]]
+ output_shape [8, 64]
+ : memref<512xbf16, #gpu.address_space<workgroup>>
+ into memref<8x64xbf16, #gpu.address_space<workgroup>>
+ // CHECK: scf.forall (%[[LANE:.+]]) in (64)
+ scf.forall (%lane) in (64) {
+ // elementsPerLane=2 < accessWidth=8: remainder stripped and restored.
+ // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[REM:.+]] = arith.remui %[[SRC_LIN:.+]], %[[C8]]
+ // CHECK: %[[ALIGNED:.+]] = arith.subi %[[SRC_LIN]], %[[REM]]
+ // CHECK: arith.xori
+ // CHECK: %[[DIFF:.+]] = arith.subi {{.*}}, %[[ALIGNED]]
+ // CHECK: arith.addi %[[SRC_LIN]], %[[DIFF]]
+ // CHECK: amdgpu.gather_to_lds
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+ : memref<8x64xbf16, #amdgpu.address_space<fat_raw_buffer>>,
+ memref<8x64xbf16, #gpu.address_space<workgroup>>, index
+ } {mapping = [#iree_gpu.lane_id<0>]}
+ return
+}
+
+// -----
+
+// Test: Combined subgroup base offset AND access-width alignment for swizzled DMA.
+
+#executable_target_combined = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx950", features = "", wgp = <
+ compute = fp32, storage = b32, subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [32]>>}>
+
+#translation_combined = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse> workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @lower_dma_swizzle_combined_base_and_access_width
+func.func @lower_dma_swizzle_combined_base_and_access_width(
+ %source: memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>)
+ attributes {
+ hal.executable.target = #executable_target_combined,
+ translation_info = #translation_combined} {
+ %alloc = memref.alloc() : memref<512xbf16, #gpu.address_space<workgroup>>
+ %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<64, 8>]
+ : memref<512xbf16, #gpu.address_space<workgroup>>
+ %expanded = memref.expand_shape %swizzled [[0, 1]]
+ output_shape [8, 64]
+ : memref<512xbf16, #gpu.address_space<workgroup>>
+ into memref<8x64xbf16, #gpu.address_space<workgroup>>
+ %dest = memref.subview %expanded[4, 0] [4, 64] [1, 1]
+ : memref<8x64xbf16, #gpu.address_space<workgroup>>
+ to memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>
+ // CHECK: scf.forall (%[[LANE:.+]]) in (64)
+ scf.forall (%lane) in (64) {
+ // Base offset 256 added, then access-width remainder stripped/restored.
+ // CHECK: %[[C256:.+]] = arith.constant 256 : index
+ // CHECK: %[[SRC_LIN:.+]] = arith.addi %[[C256]],
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[REM:.+]] = arith.remui %[[SRC_LIN]], %[[C8]]
+ // CHECK: %[[ALIGNED:.+]] = arith.subi %[[SRC_LIN]], %[[REM]]
+ // CHECK: arith.xori
+ // CHECK: %[[DIFF:.+]] = arith.subi {{.*}}, %[[ALIGNED]]
+ // CHECK: arith.addi %[[SRC_LIN]], %[[DIFF]]
+ // CHECK: arith.subi {{.*}}, %[[C256]]
+ // CHECK: amdgpu.gather_to_lds
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+ : memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>,
+ memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>, index
+ } {mapping = [#iree_gpu.lane_id<0>]}
+ return
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
index ee7a9d7..0df2999 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
@@ -141,14 +141,18 @@
// CHECK-DAG: %[[ROW_WIDTH:.+]] = arith.constant 64 : index
// CHECK-DAG: %[[GROUP_COUNT:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[GROUP_WIDTH:.+]] = arith.constant 4 : index
-// CHECK: %[[I:.+]] = arith.divui %[[OFFSET]], %[[ROW_WIDTH]] : index
-// CHECK: %[[JELEM:.+]] = arith.remui %[[OFFSET]], %[[ROW_WIDTH]] : index
+// CHECK: %[[REM:.+]] = arith.remui %[[OFFSET]], %[[GROUP_WIDTH]] : index
+// CHECK: %[[ALIGNED:.+]] = arith.subi %[[OFFSET]], %[[REM]] : index
+// CHECK: %[[I:.+]] = arith.divui %[[ALIGNED]], %[[ROW_WIDTH]] : index
+// CHECK: %[[JELEM:.+]] = arith.remui %[[ALIGNED]], %[[ROW_WIDTH]] : index
// CHECK: %[[J:.+]] = arith.divui %[[JELEM]], %[[GROUP_WIDTH]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[J]] : index
// CHECK: %[[ROTATEJ:.+]] = arith.remui %[[ADD]], %[[GROUP_COUNT]] : index
// CHECK: %[[ROTATEJELEM:.+]] = arith.muli %[[ROTATEJ]], %[[GROUP_WIDTH]] : index
// CHECK: %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
-// CHECK: %[[SWOFF:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+// CHECK: %[[SWIZZLED:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+// CHECK: %[[DIFF:.+]] = arith.subi %[[SWIZZLED]], %[[ALIGNED]] : index
+// CHECK: %[[SWOFF:.+]] = arith.addi %[[OFFSET]], %[[DIFF]] : index
// Make sure both the load and store get the same calculation.
// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
@@ -177,20 +181,23 @@
// CHECK-DAG: %[[GROUP_WIDTH:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[C1040:.+]] = arith.constant 1040 : index
// CHECK: %[[APPLY_BASE:.+]] = arith.addi %[[OFFSET]], %[[GROUP_COUNT]] overflow<nsw> : index
-// CHECK: %[[I:.+]] = arith.divui %[[APPLY_BASE]], %[[ROW_WIDTH]] : index
-// CHECK: %[[JELEM:.+]] = arith.remui %[[APPLY_BASE]], %[[ROW_WIDTH]] : index
+// CHECK: %[[REM:.+]] = arith.remui %[[APPLY_BASE]], %[[GROUP_WIDTH]] : index
+// CHECK: %[[ALIGNED:.+]] = arith.subi %[[APPLY_BASE]], %[[REM]] : index
+// CHECK: %[[I:.+]] = arith.divui %[[ALIGNED]], %[[ROW_WIDTH]] : index
+// CHECK: %[[JELEM:.+]] = arith.remui %[[ALIGNED]], %[[ROW_WIDTH]] : index
// CHECK: %[[J:.+]] = arith.divui %[[JELEM]], %[[GROUP_WIDTH]] : index
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[J]] : index
// CHECK: %[[ROTATEJ:.+]] = arith.remui %[[ADD]], %[[GROUP_COUNT]] : index
// CHECK: %[[ROTATEJELEM:.+]] = arith.muli %[[ROTATEJ]], %[[GROUP_WIDTH]] : index
// CHECK: %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
-// CHECK: %[[SWOFF:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+// CHECK: %[[SWIZZLED:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+// CHECK: %[[DIFF:.+]] = arith.subi %[[SWIZZLED]], %[[ALIGNED]] : index
-// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+// CHECK: %[[LOAD_SWOFF:.+]] = arith.addi %[[APPLY_BASE]], %[[DIFF]] : index
+// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[LOAD_SWOFF]]]
// CHECK: %[[STORE_BASE:.+]] = arith.addi %[[OFFSET]], %[[C1040]] overflow<nsw> : index
-// CHECK: %[[OFFSET_DIFF:.+]] = arith.subi %[[SWOFF]], %[[APPLY_BASE]] : index
-// CHECK: %[[STORE_SWOFF:.+]] = arith.addi %[[STORE_BASE]], %[[OFFSET_DIFF]] : index
+// CHECK: %[[STORE_SWOFF:.+]] = arith.addi %[[STORE_BASE]], %[[DIFF]] : index
// CHECK: vector.store %[[VEC]], %[[SRC]][%[[STORE_SWOFF]]]
// CHECK: return %[[VECTOR]]
@@ -232,6 +239,23 @@
// -----
+func.func @swizzle_load_xor_unaligned(%src: memref<?xi8>) -> vector<16xi8> {
+ %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16>] : memref<?xi8>
+
+ // Offset 1955 = 1952 + 3, where 1952 swizzles to 2000, so result = 2000 + 3 = 2003.
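+ // (row = (1952/128) mod 8 = 7, col = (1952%128)/16 = 2, 7 xor 2 = 5,
+ // so the column element moves from 2*16 to 5*16: 1952 + 48 = 2000.)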
+ %offset = arith.constant 1955 : index
+ %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+ return %1: vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor_unaligned
+// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+// CHECK: %[[SWOFF:.+]] = arith.constant 2003 : index
+// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+// CHECK: return %[[VECTOR]]
+
+// -----
+
func.func @swizzle_load_xor_phase2(%src: memref<?xi8>) -> vector<16xi8> {
%0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16, 128, 2>] : memref<?xi8>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index 46570c9..9b0bade 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -662,10 +662,8 @@
// to get the simplest offset possible in case we are accessing values from
// successive rows. This allows us to CSE the swizzling computation more
// effectively.
- int64_t rotationInvariant =
- getRowWidth() * (getRowWidth() / getAccessWidth());
OpFoldResult id =
- getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+ getMinimumConstantOffsetValue(b, loc, offset, getSwizzlePeriod());
// Number of elements per row.
Value rowAlignmentVal = arith::ConstantIndexOp::create(b, loc, getRowWidth());
@@ -676,11 +674,16 @@
Value accessWidthVal =
arith::ConstantIndexOp::create(b, loc, getAccessWidth());
+ // Strip the sub-access_width remainder to align the offset, since the
+ // swizzle operates at access_width granularity.
Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+ Value remainder = arith::RemUIOp::create(b, loc, idVal, accessWidthVal);
+ Value alignedId = arith::SubIOp::create(b, loc, idVal, remainder);
+
// i = row # = |offset| floordiv |Num elements per row|
- Value i = arith::DivUIOp::create(b, loc, idVal, rowAlignmentVal);
+ Value i = arith::DivUIOp::create(b, loc, alignedId, rowAlignmentVal);
// jByte = element column # = |offset| % |Num elements per row|
- Value jElem = arith::RemUIOp::create(b, loc, idVal, rowAlignmentVal);
+ Value jElem = arith::RemUIOp::create(b, loc, alignedId, rowAlignmentVal);
// j = group column # = jElem / |Num elements per group|
Value j = arith::DivUIOp::create(b, loc, jElem, accessWidthVal);
@@ -700,7 +703,13 @@
// This increases the chance of being able to CSE the offset calculation. When
// multiple accesses to a memref only differ by a constant value (very common
// when working with statically shaped memrefs like shared/scratch memory).
- Value diff = arith::SubIOp::create(b, loc, swizzledId, idVal);
+ //
+ // The diff is computed from |alignedId| but applied to the original |offset|.
+ // This implicitly restores the sub-access_width remainder:
+ // offset + (swizzle(alignedId) - alignedId)
+ // = (alignedId + rem) + (swizzle(alignedId) - alignedId)
+ // = swizzle(alignedId) + rem
+ Value diff = arith::SubIOp::create(b, loc, swizzledId, alignedId);
return arith::AddIOp::create(
b, loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
.getResult();
@@ -782,21 +791,23 @@
OpFoldResult XORShuffleAttr::swizzleOffset(OpBuilder &b, Location loc,
OpFoldResult offset,
Value src) const {
- int64_t rotationInvariant =
- getRowWidth() * (getRowWidth() / getAccessWidth());
int64_t rowStride =
getRowStride() != int64_t() ? getRowStride() : getRowWidth();
int64_t perPhase = getPerPhase() != int64_t() ? getPerPhase() : 1;
OpFoldResult id =
- getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+ getMinimumConstantOffsetValue(b, loc, offset, getSwizzlePeriod());
+
+ // Strip the sub-access_width remainder to align the offset, since the
+ // swizzle operates at access_width granularity.
Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+ Value accessWidthVal =
+ arith::ConstantIndexOp::create(b, loc, getAccessWidth());
+ Value remainder = arith::RemUIOp::create(b, loc, idVal, accessWidthVal);
+ Value alignedId = arith::SubIOp::create(b, loc, idVal, remainder);
// Number of elements per row.
Value rowAlignmentVal = arith::ConstantIndexOp::create(b, loc, getRowWidth());
- // Number of elements per group.
- Value accessWidthVal =
- arith::ConstantIndexOp::create(b, loc, getAccessWidth());
// Number of rows per phase.
Value perPhaseVal = arith::ConstantIndexOp::create(b, loc, perPhase);
// Buffer stride.
@@ -805,15 +816,20 @@
Value rowAccessAlignmentVal =
arith::ConstantIndexOp::create(b, loc, getRowWidth() / getAccessWidth());
- Value colVal = extractCol(b, loc, idVal, rowAlignmentVal, accessWidthVal);
- Value rowVal = extractRow(b, loc, idVal, rowStrideVal, perPhaseVal,
+ Value colVal = extractCol(b, loc, alignedId, rowAlignmentVal, accessWidthVal);
+ Value rowVal = extractRow(b, loc, alignedId, rowStrideVal, perPhaseVal,
rowAccessAlignmentVal);
auto colSwizzled = arith::XOrIOp::create(b, loc, rowVal, colVal);
- // Update colSwizzled to initial id
- Value swizzledIdVal =
- updateCol(b, loc, idVal, colSwizzled, rowAlignmentVal, accessWidthVal);
- Value diff = arith::SubIOp::create(b, loc, swizzledIdVal, idVal);
+ // Update colSwizzled to alignedId
+ Value swizzledIdVal = updateCol(b, loc, alignedId, colSwizzled,
+ rowAlignmentVal, accessWidthVal);
+ // The diff is computed from |alignedId| but applied to the original |offset|.
+ // This implicitly restores the sub-access_width remainder:
+ // offset + (swizzle(alignedId) - alignedId)
+ // = (alignedId + rem) + (swizzle(alignedId) - alignedId)
+ // = swizzle(alignedId) + rem
+ Value diff = arith::SubIOp::create(b, loc, swizzledIdVal, alignedId);
return arith::AddIOp::create(
b, loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
.getResult();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 59176ea..f692a93 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -473,6 +473,11 @@
let assemblyFormat = [{
`<` $row_width `,` $access_width `>`
}];
+ let extraClassDeclaration = [{
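+ /// Period (in elements) after which the swizzle pattern repeats.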
+ int64_t getSwizzlePeriod() const {
+ return getRowWidth() * (getRowWidth() / getAccessWidth());
+ }
+ }];
let genVerifyDecl = 1;
}
@@ -596,6 +601,11 @@
let assemblyFormat = [{
`<` $row_width `,` $access_width (`,` $row_stride^)? (`,` $per_phase^)? `>`
}];
+ let extraClassDeclaration = [{
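+ /// Period (in elements) after which the swizzle pattern repeats.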
+ int64_t getSwizzlePeriod() const {
+ return getRowWidth() * (getRowWidth() / getAccessWidth());
+ }
+ }];
let genVerifyDecl = 1;
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 19603bc..e1adfa5 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -21,7 +21,10 @@
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -1367,4 +1370,61 @@
return rootOps;
}
+Value applyInverseXorSwizzleToDMASourceOffset(
+ OpBuilder &builder, Location loc, Value srcLinearOffset,
+ IREE::Codegen::XORShuffleAttr swizzle, Value dest) {
+ // Compute the subgroup's base offset within the full allocation by tracing
+ // through view-like ops and accumulating the linear element offset.
+ Value destTrace = dest;
+ OpFoldResult totalOffset = builder.getIndexAttr(0);
+ while (auto viewOp = destTrace.getDefiningOp<ViewLikeOpInterface>()) {
+ if (auto subviewOp = dyn_cast<memref::SubViewOp>(viewOp.getOperation())) {
+ auto parentType = cast<MemRefType>(subviewOp.getSource().getType());
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ LogicalResult res = parentType.getStridesAndOffset(strides, offset);
+ assert(succeeded(res) && "expected strided layout on subview source");
+ (void)res;
+ SmallVector<OpFoldResult> strideOFRs =
+ getAsIndexOpFoldResult(builder.getContext(), strides);
+ auto &&[expr, values] = computeLinearIndex(totalOffset, strideOFRs,
+ subviewOp.getMixedOffsets());
+ totalOffset =
+ affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+ } else {
+ assert((isa<memref::ExpandShapeOp, memref::CollapseShapeOp>(
+ viewOp.getOperation())) &&
+ "unexpected view-like op in dest chain");
+ }
+ destTrace = viewOp.getViewSource();
+ }
+
+ // Only the component of the base offset that is not a multiple of the
+ // swizzle period affects the XOR computation.
+ auto cst = getConstantIntValue(totalOffset);
+ Value swizzleBaseOffset;
+ int64_t swizzlePeriod = swizzle.getSwizzlePeriod();
+ if (!cst || (*cst % swizzlePeriod != 0)) {
+ swizzleBaseOffset =
+ getValueOrCreateConstantIndexOp(builder, loc, totalOffset);
+ }
+
+ // Add the subgroup's base offset within the full allocation.
+ Value swizzleInput = srcLinearOffset;
+ if (swizzleBaseOffset) {
+ swizzleInput =
+ arith::AddIOp::create(builder, loc, swizzleBaseOffset, srcLinearOffset);
+ }
+
+ // Apply the swizzle (handles access-width alignment internally).
+ Value swizzled = getValueOrCreateConstantIndexOp(
+ builder, loc, swizzle.swizzleOffset(builder, loc, swizzleInput, dest));
+
+ // Subtract the subgroup base offset.
+ if (swizzleBaseOffset) {
+ return arith::SubIOp::create(builder, loc, swizzled, swizzleBaseOffset);
+ }
+ return swizzled;
+}
+
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index a39e095..631e020 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -282,6 +282,27 @@
IREE::Codegen::InnerTileDescAttrInterface intrinsic,
ArrayRef<int64_t> reductionTileSizes, int operandIndex,
bool skipUntunedFallback = false);
+
+/// Apply inverse XOR swizzle to a sub-tile-local source offset so that the
+/// DMA write-side permutation matches the read-side (ResolveSwizzleHints).
+///
+/// When the subgroup transfer size is not a multiple of the swizzle period,
+/// we add the subgroup's base offset within the full allocation before
+/// swizzling and subtract it after.
+///
+/// Example: xor_shuffle<64, 8>, period = 64*64/8 = 512 elements.
+/// Workgroup tile = 32x32, 4 subgroups of 8x32 = 256 elements each.
+/// Since 256 < 512, the fix is needed:
+/// ```
+/// BUG: swizzle(local)
+/// Subgroups 0 and 1 see the same local offsets but occupy
+/// different rows in the full allocation.
+/// FIX: swizzle(local + base) - base
+/// ```
+Value applyInverseXorSwizzleToDMASourceOffset(
+ OpBuilder &builder, Location loc, Value srcLinearOffset,
+ IREE::Codegen::XORShuffleAttr swizzle, Value dest);
+
//===----------------------------------------------------------------------===//
// GPU CodeGen op filter
//===----------------------------------------------------------------------===//
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 8b4e15c..e89d533 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1627,6 +1627,34 @@
iree_generated_e2e_runner_test(
NAME
+ e2e_batch_matmul_${_CDNA_ARCH}_coalesced_dma_bf16
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_batch_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=bf16"
+ "--acc_type=f32"
+ "--shapes=block_static"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ ${IREE_HIP_TEST_COMPILER_FLAGS}
+ "--iree-llvmgpu-use-direct-load"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-${_CDNA_ARCH}"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
e2e_matmul_${_CDNA_ARCH}_vecdistmfma_f16
TEST_TYPE
matmul
diff --git a/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py b/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py
index 08a3d88..80d9eb0 100644
--- a/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py
@@ -28,6 +28,7 @@
"block_static": [
BatchTestShape(batch=32, m=128, k=128, n=128),
BatchTestShape(batch=32, m=256, k=64, n=256),
+ BatchTestShape(batch=32, m=96, k=96, n=96),
],
}