[Codegen][DMA] Fix unaligned swizzle offset computation in gather-to-lds lowering (#24241)

The inverse XOR swizzle applied to DMA source offsets was incorrect in
two cases:

1. **Subgroup base offset**: When a subgroup's transfer size is not a
multiple of the swizzle period, different subgroups that share the same
subgroup-local offsets but occupy different rows of the full allocation
would compute identical swizzled addresses.

Fix: incorporate the subgroup's base offset within the full allocation
before swizzling, and subtract it again afterwards (see the sketch below).

2. **Access-width alignment**: When `elementsPerLane < accessWidth`, the
integer division inside `swizzleOffset` maps offsets that differ only
within an access-width group to the same swizzled value, dropping the
sub-access-width remainder.

Fix: strip the sub-access-width remainder before swizzling and restore it
afterwards. This fix is applied directly in `swizzleOffset` for both the
XOR and rotate_rows swizzles. While rotate_rows is not currently used with
DMA, the access-width alignment issue affects both swizzle types.
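
For intuition, here is a plain-integer model of both fixes in the default
configuration (row_stride = row_width, per_phase = 1). This is an
illustrative sketch, not the compiler code: `swizzle` and
`inverseSourceOffset` are hypothetical stand-ins for
`XORShuffleAttr::swizzleOffset` and
`applyInverseXorSwizzleToDMASourceOffset`.

```cpp
#include <cassert>
#include <cstdint>

// Model of xor_shuffle<rowWidth, accessWidth> with the access-width
// alignment fix: strip the sub-access-width remainder, swizzle the
// aligned offset, then restore the remainder.
int64_t swizzle(int64_t x, int64_t rowWidth, int64_t accessWidth) {
  int64_t rem = x % accessWidth; // fix 2: strip the remainder.
  int64_t aligned = x - rem;
  int64_t row = (aligned / rowWidth) % (rowWidth / accessWidth);
  int64_t col = (aligned % rowWidth) / accessWidth;
  int64_t swizzledCol = row ^ col;
  int64_t swizzled =
      (aligned / rowWidth) * rowWidth + swizzledCol * accessWidth;
  return swizzled + rem; // fix 2: restore the remainder.
}

// Fix 1: swizzle a subgroup-local offset relative to the full allocation
// by adding the subgroup's base offset first and subtracting it after.
int64_t inverseSourceOffset(int64_t local, int64_t base, int64_t rowWidth,
                            int64_t accessWidth) {
  return swizzle(base + local, rowWidth, accessWidth) - base;
}

int main() {
  // xor_shuffle<64, 8>: subgroups at bases 0 and 256 share local offset 8
  // but occupy different rows, so they must land on different addresses.
  assert(inverseSourceOffset(8, /*base=*/0, 64, 8) == 8);    // 0 ^ 1 = 1.
  assert(inverseSourceOffset(8, /*base=*/256, 64, 8) == 40); // 4 ^ 1 = 5.
  // xor_shuffle<128, 16>: 1955 = 1952 + 3; 1952 swizzles to 2000 and the
  // remainder 3 is restored (this is the new lit test case below).
  assert(swizzle(1955, 128, 16) == 2003);
  return 0;
}
```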

Both issues caused numerical mismatches for BF16 batch matmuls using DMA
with XOR swizzle enabled.

Assisted-by: Cursor (Claude)

---------

Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
index 64fe2f4..c46fafd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
@@ -218,16 +218,16 @@
 /// Trace a memref value through view-like ops to find a SwizzleHintOp.
 /// Returns the swizzle attribute if it is an XOR swizzle (which is
 /// self-inverse), std::nullopt otherwise.
-static std::optional<IREE::Codegen::SwizzleAttrInterface>
+static std::optional<IREE::Codegen::XORShuffleAttr>
 getDestSwizzleAttr(Value dest) {
   dest = getRootSource(dest);
   if (auto hintOp = dest.getDefiningOp<IREE::Codegen::SwizzleHintOp>()) {
-    auto swizzle = hintOp.getSwizzle();
     // Only XOR swizzle is self-inverse (swizzle(swizzle(x)) = x), so it can
     // be applied to source addresses as the inverse transformation. Other
     // swizzle types (e.g. rotate_rows) are not self-inverse.
-    if (isa<IREE::Codegen::XORShuffleAttr>(swizzle)) {
-      return swizzle;
+    if (auto xorSwizzle =
+            dyn_cast<IREE::Codegen::XORShuffleAttr>(hintOp.getSwizzle())) {
+      return xorSwizzle;
     }
   }
   return std::nullopt;
@@ -342,7 +342,7 @@
 
     Value source = dmaOp.getSource();
     Value dest = dmaOp.getInit();
-    std::optional<IREE::Codegen::SwizzleAttrInterface> destSwizzle =
+    std::optional<IREE::Codegen::XORShuffleAttr> destSwizzle =
         getDestSwizzleAttr(dest);
 
     auto sourceType = cast<MemRefType>(source.getType());
@@ -493,7 +493,7 @@
       ArrayRef<int64_t> destShape, int64_t numLinearDims, Type elementType,
       OperandRange indices, ArrayRef<TransferSegment> segments,
       ArrayRef<Value> segmentLaneOffsets, std::optional<ArrayAttr> inBoundsAttr,
-      std::optional<IREE::Codegen::SwizzleAttrInterface> destSwizzle) const {
+      std::optional<IREE::Codegen::XORShuffleAttr> destSwizzle) const {
     int64_t destRank = destShape.size();
     int64_t numOuterDims = destRank - numLinearDims;
     LDBG() << "Emitting transfers: " << numOuterDims << " outer dims, "
@@ -543,10 +543,8 @@
           // Apply inverse source swizzle when destination has XOR swizzle.
           // XOR swizzle is self-inverse, so swizzle(swizzle(x)) = x.
           if (destSwizzle) {
-            srcLinearOffset = getValueOrCreateConstantIndexOp(
-                rewriter, loc,
-                destSwizzle->swizzleOffset(rewriter, loc, srcLinearOffset,
-                                           dest));
+            srcLinearOffset = applyInverseXorSwizzleToDMASourceOffset(
+                rewriter, loc, srcLinearOffset, *destSwizzle, dest);
           }
           auto srcDelinearize = affine::AffineDelinearizeIndexOp::create(
               rewriter, loc, srcLinearOffset, basis, /*hasOuterBound=*/true);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
index 4b77f00..0a032ba 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
@@ -1658,7 +1658,7 @@
 //
 // Shape: 4x128 f32 dest. With 64 lanes, 128-bit DMA = 4 elements/lane.
 // 256 elements/transfer, 512 total = 2 transfers.
-// XOR swizzle with row_width=128, access_width=16 permutes source offsets.
+// XOR swizzle with row_width=128, access_width=4 permutes source offsets.
 
 #executable_target_rocm_hsaco_fb_swizzle = #hal.executable.target<"rocm",
   "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
@@ -1682,7 +1682,7 @@
     hal.executable.target = #executable_target_rocm_hsaco_fb_swizzle,
     translation_info = #translation_swizzle} {
   %alloc = memref.alloc() : memref<512xf32, #gpu.address_space<workgroup>>
-  %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>]
+  %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 4>]
       : memref<512xf32, #gpu.address_space<workgroup>>
   %dest = memref.expand_shape %swizzled [[0, 1]]
       output_shape [4, 128]
@@ -1696,12 +1696,15 @@
     // Transfer 1: linearOffset = 0, source gets XOR-swizzled offset.
     // CHECK: %[[C0:.+]] = arith.constant 0 : index
     // CHECK: %[[SRC_LIN0:.+]] = arith.addi %[[C0]], %[[LANE_OFFSET]]
-    // XOR swizzle: extractCol, extractRow, xori, updateCol, diff, add.
+    // Access-width alignment: strip remainder, swizzle, apply diff to original.
+    // CHECK: %[[REM0:.+]] = arith.remui %[[SRC_LIN0]], %[[C4]]
+    // CHECK: %[[ALIGNED0:.+]] = arith.subi %[[SRC_LIN0]], %[[REM0]]
+    // XOR swizzle: extractCol, extractRow, xori, updateCol.
     // CHECK: %[[COL0:.+]] = affine.apply
     // CHECK: %[[ROW0:.+]] = affine.apply
     // CHECK: %[[XOR0:.+]] = arith.xori %[[ROW0]], %[[COL0]]
     // CHECK: %[[UCOL0:.+]] = affine.apply
-    // CHECK: %[[DIFF0:.+]] = arith.subi %[[UCOL0]], %[[SRC_LIN0]]
+    // CHECK: %[[DIFF0:.+]] = arith.subi %[[UCOL0]], %[[ALIGNED0]]
     // CHECK: %[[SWIZZLED0:.+]] = arith.addi %[[SRC_LIN0]], %[[DIFF0]]
     // Source delinearized from swizzled offset.
     // CHECK: %[[SRC_DELIN0:.+]]:2 = affine.delinearize_index %[[SWIZZLED0]] into (4, 128)
@@ -1812,3 +1815,160 @@
   } {mapping = [#iree_gpu.lane_id<0>]}
   return
 }
+
+// -----
+
+// Test: Subgroup base offset adjustment for swizzled DMA.
+// The dest subview starts at offset 256 and covers 256 elements, less than
+// the swizzle period (64*64/8 = 512). The base offset must be added before
+// swizzling and subtracted after.
+
+#executable_target_base_offset = #hal.executable.target<"rocm",
+  "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+  arch = "gfx950", features = "", wgp = <
+    compute = fp32, storage = b32, subgroup = none, dot = none, mma = [],
+    subgroup_size_choices = [64, 64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+    max_load_instruction_bits = 128, simds_per_wgp = 4,
+    vgpr_space_bits = 8192, dma_sizes = [32]>>}>
+
+#translation_base_offset = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse> workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @lower_dma_swizzle_base_offset
+func.func @lower_dma_swizzle_base_offset(
+    %source: memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>)
+  attributes {
+    hal.executable.target = #executable_target_base_offset,
+    translation_info = #translation_base_offset} {
+  %alloc = memref.alloc() : memref<512xbf16, #gpu.address_space<workgroup>>
+  %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<64, 8>]
+      : memref<512xbf16, #gpu.address_space<workgroup>>
+  %expanded = memref.expand_shape %swizzled [[0, 1]]
+      output_shape [8, 64]
+      : memref<512xbf16, #gpu.address_space<workgroup>>
+      into memref<8x64xbf16, #gpu.address_space<workgroup>>
+  %dest = memref.subview %expanded[4, 0] [4, 64] [1, 1]
+      : memref<8x64xbf16, #gpu.address_space<workgroup>>
+      to memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>
+  // CHECK: scf.forall (%[[LANE:.+]]) in (64)
+  scf.forall (%lane) in (64) {
+    // Base offset 256 added before swizzle, subtracted after.
+    // CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index
+    // CHECK: arith.addi %[[C256]],
+    // CHECK: arith.xori
+    // CHECK: arith.subi {{.*}}, %[[C256]]
+    // CHECK: amdgpu.gather_to_lds
+    iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+        : memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>,
+          memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>, index
+  } {mapping = [#iree_gpu.lane_id<0>]}
+  return
+}
+
+// -----
+
+// Test: Access-width alignment for swizzled DMA.
+// With dma_sizes=[32] and bf16, elementsPerLane = 32/16 = 2.
+// accessWidth = 8 (from xor_shuffle<64, 8>), so 2 < 8 and the sub-access-width
+// remainder must be stripped before swizzling and restored after.
+
+#executable_target_access_align = #hal.executable.target<"rocm",
+  "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+  arch = "gfx950", features = "", wgp = <
+    compute = fp32, storage = b32, subgroup = none, dot = none, mma = [],
+    subgroup_size_choices = [64, 64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+    max_load_instruction_bits = 128, simds_per_wgp = 4,
+    vgpr_space_bits = 8192, dma_sizes = [32]>>}>
+
+#translation_access_align = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse> workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @lower_dma_swizzle_access_width_align
+func.func @lower_dma_swizzle_access_width_align(
+    %source: memref<8x64xbf16, #amdgpu.address_space<fat_raw_buffer>>)
+  attributes {
+    hal.executable.target = #executable_target_access_align,
+    translation_info = #translation_access_align} {
+  %alloc = memref.alloc() : memref<512xbf16, #gpu.address_space<workgroup>>
+  %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<64, 8>]
+      : memref<512xbf16, #gpu.address_space<workgroup>>
+  %dest = memref.expand_shape %swizzled [[0, 1]]
+      output_shape [8, 64]
+      : memref<512xbf16, #gpu.address_space<workgroup>>
+      into memref<8x64xbf16, #gpu.address_space<workgroup>>
+  // CHECK: scf.forall (%[[LANE:.+]]) in (64)
+  scf.forall (%lane) in (64) {
+    // elementsPerLane=2 < accessWidth=8: remainder stripped and restored.
+    // CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+    // CHECK: %[[REM:.+]] = arith.remui %[[SRC_LIN:.+]], %[[C8]]
+    // CHECK: %[[ALIGNED:.+]] = arith.subi %[[SRC_LIN]], %[[REM]]
+    // CHECK: arith.xori
+    // CHECK: %[[DIFF:.+]] = arith.subi {{.*}}, %[[ALIGNED]]
+    // CHECK: arith.addi %[[SRC_LIN]], %[[DIFF]]
+    // CHECK: amdgpu.gather_to_lds
+    iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+        : memref<8x64xbf16, #amdgpu.address_space<fat_raw_buffer>>,
+          memref<8x64xbf16, #gpu.address_space<workgroup>>, index
+  } {mapping = [#iree_gpu.lane_id<0>]}
+  return
+}
+
+// -----
+
+// Test: Combined subgroup base offset AND access-width alignment for swizzled DMA.
+
+#executable_target_combined = #hal.executable.target<"rocm",
+  "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+  arch = "gfx950", features = "", wgp = <
+    compute = fp32, storage = b32, subgroup = none, dot = none, mma = [],
+    subgroup_size_choices = [64, 64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+    max_load_instruction_bits = 128, simds_per_wgp = 4,
+    vgpr_space_bits = 8192, dma_sizes = [32]>>}>
+
+#translation_combined = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse> workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: func.func @lower_dma_swizzle_combined_base_and_access_width
+func.func @lower_dma_swizzle_combined_base_and_access_width(
+    %source: memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>)
+  attributes {
+    hal.executable.target = #executable_target_combined,
+    translation_info = #translation_combined} {
+  %alloc = memref.alloc() : memref<512xbf16, #gpu.address_space<workgroup>>
+  %swizzled = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<64, 8>]
+      : memref<512xbf16, #gpu.address_space<workgroup>>
+  %expanded = memref.expand_shape %swizzled [[0, 1]]
+      output_shape [8, 64]
+      : memref<512xbf16, #gpu.address_space<workgroup>>
+      into memref<8x64xbf16, #gpu.address_space<workgroup>>
+  %dest = memref.subview %expanded[4, 0] [4, 64] [1, 1]
+      : memref<8x64xbf16, #gpu.address_space<workgroup>>
+      to memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>
+  // CHECK: scf.forall (%[[LANE:.+]]) in (64)
+  scf.forall (%lane) in (64) {
+    // Base offset 256 added, then access-width remainder stripped/restored.
+    // CHECK: %[[C256:.+]] = arith.constant 256 : index
+    // CHECK: %[[SRC_LIN:.+]] = arith.addi %[[C256]],
+    // CHECK: %[[C8:.+]] = arith.constant 8 : index
+    // CHECK: %[[REM:.+]] = arith.remui %[[SRC_LIN]], %[[C8]]
+    // CHECK: %[[ALIGNED:.+]] = arith.subi %[[SRC_LIN]], %[[REM]]
+    // CHECK: arith.xori
+    // CHECK: %[[DIFF:.+]] = arith.subi {{.*}}, %[[ALIGNED]]
+    // CHECK: arith.addi %[[SRC_LIN]], %[[DIFF]]
+    // CHECK: arith.subi {{.*}}, %[[C256]]
+    // CHECK: amdgpu.gather_to_lds
+    iree_gpu.coalesced_gather_dma %source into %dest lane(%lane)
+        : memref<4x64xbf16, #amdgpu.address_space<fat_raw_buffer>>,
+          memref<4x64xbf16, strided<[64, 1], offset: 256>, #gpu.address_space<workgroup>>, index
+  } {mapping = [#iree_gpu.lane_id<0>]}
+  return
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
index ee7a9d7..0df2999 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir
@@ -141,14 +141,18 @@
 //   CHECK-DAG:   %[[ROW_WIDTH:.+]] = arith.constant 64 : index
 //   CHECK-DAG:   %[[GROUP_COUNT:.+]] = arith.constant 16 : index
 //   CHECK-DAG:   %[[GROUP_WIDTH:.+]] = arith.constant 4 : index
-//       CHECK:   %[[I:.+]] = arith.divui %[[OFFSET]], %[[ROW_WIDTH]] : index
-//       CHECK:   %[[JELEM:.+]] = arith.remui %[[OFFSET]], %[[ROW_WIDTH]] : index
+//       CHECK:   %[[REM:.+]] = arith.remui %[[OFFSET]], %[[GROUP_WIDTH]] : index
+//       CHECK:   %[[ALIGNED:.+]] = arith.subi %[[OFFSET]], %[[REM]] : index
+//       CHECK:   %[[I:.+]] = arith.divui %[[ALIGNED]], %[[ROW_WIDTH]] : index
+//       CHECK:   %[[JELEM:.+]] = arith.remui %[[ALIGNED]], %[[ROW_WIDTH]] : index
 //       CHECK:   %[[J:.+]] = arith.divui %[[JELEM]], %[[GROUP_WIDTH]] : index
 //       CHECK:   %[[ADD:.+]] = arith.addi %[[I]], %[[J]] : index
 //       CHECK:   %[[ROTATEJ:.+]] = arith.remui %[[ADD]], %[[GROUP_COUNT]] : index
 //       CHECK:   %[[ROTATEJELEM:.+]] = arith.muli %[[ROTATEJ]], %[[GROUP_WIDTH]] : index
 //       CHECK:   %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
-//       CHECK:   %[[SWOFF:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+//       CHECK:   %[[SWIZZLED:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+//       CHECK:   %[[DIFF:.+]] = arith.subi %[[SWIZZLED]], %[[ALIGNED]] : index
+//       CHECK:   %[[SWOFF:.+]] = arith.addi %[[OFFSET]], %[[DIFF]] : index
 
 // Make sure both the load and store get the same calculation.
 //       CHECK:   %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
@@ -177,20 +181,23 @@
 //   CHECK-DAG:   %[[GROUP_WIDTH:.+]] = arith.constant 4 : index
 //   CHECK-DAG:   %[[C1040:.+]] = arith.constant 1040 : index
 //       CHECK:   %[[APPLY_BASE:.+]] = arith.addi %[[OFFSET]], %[[GROUP_COUNT]] overflow<nsw> : index
-//       CHECK:   %[[I:.+]] = arith.divui %[[APPLY_BASE]], %[[ROW_WIDTH]] : index
-//       CHECK:   %[[JELEM:.+]] = arith.remui %[[APPLY_BASE]], %[[ROW_WIDTH]] : index
+//       CHECK:   %[[REM:.+]] = arith.remui %[[APPLY_BASE]], %[[GROUP_WIDTH]] : index
+//       CHECK:   %[[ALIGNED:.+]] = arith.subi %[[APPLY_BASE]], %[[REM]] : index
+//       CHECK:   %[[I:.+]] = arith.divui %[[ALIGNED]], %[[ROW_WIDTH]] : index
+//       CHECK:   %[[JELEM:.+]] = arith.remui %[[ALIGNED]], %[[ROW_WIDTH]] : index
 //       CHECK:   %[[J:.+]] = arith.divui %[[JELEM]], %[[GROUP_WIDTH]] : index
 //       CHECK:   %[[ADD:.+]] = arith.addi %[[I]], %[[J]] : index
 //       CHECK:   %[[ROTATEJ:.+]] = arith.remui %[[ADD]], %[[GROUP_COUNT]] : index
 //       CHECK:   %[[ROTATEJELEM:.+]] = arith.muli %[[ROTATEJ]], %[[GROUP_WIDTH]] : index
 //       CHECK:   %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
-//       CHECK:   %[[SWOFF:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+//       CHECK:   %[[SWIZZLED:.+]] = arith.addi %[[ROTATEJELEM]], %[[IELEM]] : index
+//       CHECK:   %[[DIFF:.+]] = arith.subi %[[SWIZZLED]], %[[ALIGNED]] : index
 
-//       CHECK:   %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+//       CHECK:   %[[LOAD_SWOFF:.+]] = arith.addi %[[APPLY_BASE]], %[[DIFF]] : index
+//       CHECK:   %[[VECTOR:.+]] = vector.load %[[SRC]][%[[LOAD_SWOFF]]]
 
 //       CHECK:   %[[STORE_BASE:.+]] = arith.addi %[[OFFSET]], %[[C1040]] overflow<nsw> : index
-//       CHECK:   %[[OFFSET_DIFF:.+]] = arith.subi %[[SWOFF]], %[[APPLY_BASE]] : index
-//       CHECK:   %[[STORE_SWOFF:.+]] = arith.addi %[[STORE_BASE]], %[[OFFSET_DIFF]] : index
+//       CHECK:   %[[STORE_SWOFF:.+]] = arith.addi %[[STORE_BASE]], %[[DIFF]] : index
 //       CHECK:   vector.store %[[VEC]], %[[SRC]][%[[STORE_SWOFF]]]
 //       CHECK:   return %[[VECTOR]]
 
@@ -232,6 +239,23 @@
 
 // -----
 
+func.func @swizzle_load_xor_unaligned(%src: memref<?xi8>) -> vector<16xi8> {
+  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16>] : memref<?xi8>
+
+  // Offset 1955 = 1952 + 3, where 1952 swizzles to 2000, so result = 2000 + 3 = 2003.
+  %offset = arith.constant 1955 : index
+  %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+  return %1: vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor_unaligned
+//  CHECK-SAME:   %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+//       CHECK:   %[[SWOFF:.+]] = arith.constant 2003 : index
+//       CHECK:   %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+//       CHECK:   return %[[VECTOR]]
+
+// -----
+
 func.func @swizzle_load_xor_phase2(%src: memref<?xi8>) -> vector<16xi8> {
   %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16, 128, 2>] : memref<?xi8>
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index 46570c9..9b0bade 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -662,10 +662,8 @@
   // to get the simplest offset possible in case we are accessing values from
   // successive rows. This allows us to CSE the swizzling computation more
   // effectively.
-  int64_t rotationInvariant =
-      getRowWidth() * (getRowWidth() / getAccessWidth());
   OpFoldResult id =
-      getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+      getMinimumConstantOffsetValue(b, loc, offset, getSwizzlePeriod());
 
   // Number of elements per row.
   Value rowAlignmentVal = arith::ConstantIndexOp::create(b, loc, getRowWidth());
@@ -676,11 +674,16 @@
   Value accessWidthVal =
       arith::ConstantIndexOp::create(b, loc, getAccessWidth());
 
+  // Strip the sub-access_width remainder to align the offset, since the
+  // swizzle operates at access_width granularity.
   Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+  Value remainder = arith::RemUIOp::create(b, loc, idVal, accessWidthVal);
+  Value alignedId = arith::SubIOp::create(b, loc, idVal, remainder);
+
   // i = row # = |offset| floordiv |Num elements per row|
-  Value i = arith::DivUIOp::create(b, loc, idVal, rowAlignmentVal);
+  Value i = arith::DivUIOp::create(b, loc, alignedId, rowAlignmentVal);
   // jByte = element column # = |offset| % |Num elements per row|
-  Value jElem = arith::RemUIOp::create(b, loc, idVal, rowAlignmentVal);
+  Value jElem = arith::RemUIOp::create(b, loc, alignedId, rowAlignmentVal);
   // j = group column # = jElem / |Num elements per group|
   Value j = arith::DivUIOp::create(b, loc, jElem, accessWidthVal);
 
@@ -700,7 +703,13 @@
   // This increases the chance of being able to CSE the offset calculation. When
   // multiple accesses to a memref only differ by a constant value (very common
   // when working with statically shaped memrefs like shared/scratch memory).
-  Value diff = arith::SubIOp::create(b, loc, swizzledId, idVal);
+  //
+  // The diff is computed from |alignedId| but applied to the original |offset|.
+  // This implicitly restores the sub-access_width remainder:
+  //   offset + (swizzle(alignedId) - alignedId)
+  //   = (alignedId + rem) + (swizzle(alignedId) - alignedId)
+  //   = swizzle(alignedId) + rem
+  Value diff = arith::SubIOp::create(b, loc, swizzledId, alignedId);
   return arith::AddIOp::create(
              b, loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
       .getResult();
@@ -782,21 +791,23 @@
 OpFoldResult XORShuffleAttr::swizzleOffset(OpBuilder &b, Location loc,
                                            OpFoldResult offset,
                                            Value src) const {
-  int64_t rotationInvariant =
-      getRowWidth() * (getRowWidth() / getAccessWidth());
   int64_t rowStride =
       getRowStride() != int64_t() ? getRowStride() : getRowWidth();
   int64_t perPhase = getPerPhase() != int64_t() ? getPerPhase() : 1;
 
   OpFoldResult id =
-      getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+      getMinimumConstantOffsetValue(b, loc, offset, getSwizzlePeriod());
+
+  // Strip the sub-access_width remainder to align the offset, since the
+  // swizzle operates at access_width granularity.
   Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+  Value accessWidthVal =
+      arith::ConstantIndexOp::create(b, loc, getAccessWidth());
+  Value remainder = arith::RemUIOp::create(b, loc, idVal, accessWidthVal);
+  Value alignedId = arith::SubIOp::create(b, loc, idVal, remainder);
 
   // Number of elements per row.
   Value rowAlignmentVal = arith::ConstantIndexOp::create(b, loc, getRowWidth());
-  // Number of elements per group.
-  Value accessWidthVal =
-      arith::ConstantIndexOp::create(b, loc, getAccessWidth());
   // Number of rows per phase.
   Value perPhaseVal = arith::ConstantIndexOp::create(b, loc, perPhase);
   // Buffer stride.
@@ -805,15 +816,20 @@
   Value rowAccessAlignmentVal =
       arith::ConstantIndexOp::create(b, loc, getRowWidth() / getAccessWidth());
 
-  Value colVal = extractCol(b, loc, idVal, rowAlignmentVal, accessWidthVal);
-  Value rowVal = extractRow(b, loc, idVal, rowStrideVal, perPhaseVal,
+  Value colVal = extractCol(b, loc, alignedId, rowAlignmentVal, accessWidthVal);
+  Value rowVal = extractRow(b, loc, alignedId, rowStrideVal, perPhaseVal,
                             rowAccessAlignmentVal);
   auto colSwizzled = arith::XOrIOp::create(b, loc, rowVal, colVal);
 
-  // Update colSwizzled to initial id
-  Value swizzledIdVal =
-      updateCol(b, loc, idVal, colSwizzled, rowAlignmentVal, accessWidthVal);
-  Value diff = arith::SubIOp::create(b, loc, swizzledIdVal, idVal);
+  // Update colSwizzled to alignedId
+  Value swizzledIdVal = updateCol(b, loc, alignedId, colSwizzled,
+                                  rowAlignmentVal, accessWidthVal);
+  // The diff is computed from |alignedId| but applied to the original |offset|.
+  // This implicitly restores the sub-access_width remainder:
+  //   offset + (swizzle(alignedId) - alignedId)
+  //   = (alignedId + rem) + (swizzle(alignedId) - alignedId)
+  //   = swizzle(alignedId) + rem
+  Value diff = arith::SubIOp::create(b, loc, swizzledIdVal, alignedId);
   return arith::AddIOp::create(
              b, loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
       .getResult();
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 59176ea..f692a93 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -473,6 +473,11 @@
   let assemblyFormat = [{
     `<` $row_width `,` $access_width `>`
   }];
+  let extraClassDeclaration = [{
+    int64_t getSwizzlePeriod() const {
+      return getRowWidth() * (getRowWidth() / getAccessWidth());
+    }
+  }];
   let genVerifyDecl = 1;
 }
 
@@ -596,6 +601,11 @@
   let assemblyFormat = [{
     `<` $row_width `,` $access_width  (`,` $row_stride^)? (`,` $per_phase^)? `>`
   }];
+  let extraClassDeclaration = [{
+    int64_t getSwizzlePeriod() const {
+      return getRowWidth() * (getRowWidth() / getAccessWidth());
+    }
+  }];
   let genVerifyDecl = 1;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 19603bc..e1adfa5 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -21,7 +21,10 @@
 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -1367,4 +1370,61 @@
   return rootOps;
 }
 
+Value applyInverseXorSwizzleToDMASourceOffset(
+    OpBuilder &builder, Location loc, Value srcLinearOffset,
+    IREE::Codegen::XORShuffleAttr swizzle, Value dest) {
+  // Compute the subgroup's base offset within the full allocation by tracing
+  // through view-like ops and accumulating the linear element offset.
+  Value destTrace = dest;
+  OpFoldResult totalOffset = builder.getIndexAttr(0);
+  while (auto viewOp = destTrace.getDefiningOp<ViewLikeOpInterface>()) {
+    if (auto subviewOp = dyn_cast<memref::SubViewOp>(viewOp.getOperation())) {
+      auto parentType = cast<MemRefType>(subviewOp.getSource().getType());
+      SmallVector<int64_t> strides;
+      int64_t offset;
+      LogicalResult res = parentType.getStridesAndOffset(strides, offset);
+      assert(succeeded(res) && "expected strided layout on subview source");
+      (void)res;
+      SmallVector<OpFoldResult> strideOFRs =
+          getAsIndexOpFoldResult(builder.getContext(), strides);
+      auto &&[expr, values] = computeLinearIndex(totalOffset, strideOFRs,
+                                                 subviewOp.getMixedOffsets());
+      totalOffset =
+          affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+    } else {
+      assert((isa<memref::ExpandShapeOp, memref::CollapseShapeOp>(
+                 viewOp.getOperation())) &&
+             "unexpected view-like op in dest chain");
+    }
+    destTrace = viewOp.getViewSource();
+  }
+
+  // Only the offset unaligned to the swizzle period affects the XOR
+  // computation.
+  auto cst = getConstantIntValue(totalOffset);
+  Value swizzleBaseOffset;
+  int64_t swizzlePeriod = swizzle.getSwizzlePeriod();
+  if (!cst || (*cst % swizzlePeriod != 0)) {
+    swizzleBaseOffset =
+        getValueOrCreateConstantIndexOp(builder, loc, totalOffset);
+  }
+
+  // Add the subgroup's base offset within the full allocation.
+  Value swizzleInput = srcLinearOffset;
+  if (swizzleBaseOffset) {
+    swizzleInput =
+        arith::AddIOp::create(builder, loc, swizzleBaseOffset, srcLinearOffset);
+  }
+
+  // Apply the swizzle (handles access-width alignment internally).
+  Value swizzled = getValueOrCreateConstantIndexOp(
+      builder, loc, swizzle.swizzleOffset(builder, loc, swizzleInput, dest));
+
+  // Subtract the subgroup base offset.
+  if (swizzleBaseOffset) {
+    return arith::SubIOp::create(builder, loc, swizzled, swizzleBaseOffset);
+  }
+  return swizzled;
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index a39e095..631e020 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -282,6 +282,27 @@
                   IREE::Codegen::InnerTileDescAttrInterface intrinsic,
                   ArrayRef<int64_t> reductionTileSizes, int operandIndex,
                   bool skipUntunedFallback = false);
+
+/// Apply inverse XOR swizzle to a sub-tile-local source offset so that the
+/// DMA write-side permutation matches the read-side (ResolveSwizzleHints).
+///
+/// When the subgroup transfer size is not a multiple of the swizzle period,
+/// we add the subgroup's base offset within the full allocation before
+/// swizzling and subtract it after.
+///
+///    Example: xor_shuffle<64, 8>, period = 64*64/8 = 512 elements.
+///    Workgroup tile = 32x32, 4 subgroups of 8x32 = 256 each.
+///    256 < 512, so the fix is needed:
+///    ```
+///      BUG: swizzle(local)
+///           Subgroups 0 and 1 see the same local offsets but occupy
+///           different rows in the full allocation.
+///      FIX: swizzle(local + base) - base
+///    ```
+Value applyInverseXorSwizzleToDMASourceOffset(
+    OpBuilder &builder, Location loc, Value srcLinearOffset,
+    IREE::Codegen::XORShuffleAttr swizzle, Value dest);
+
 //===----------------------------------------------------------------------===//
 // GPU CodeGen op filter
 //===----------------------------------------------------------------------===//
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index 8b4e15c..e89d533 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1627,6 +1627,34 @@
 
 iree_generated_e2e_runner_test(
   NAME
+    e2e_batch_matmul_${_CDNA_ARCH}_coalesced_dma_bf16
+  TEST_TYPE
+    matmul
+  GENERATOR
+    "generate_e2e_batch_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=bf16"
+    "--acc_type=f32"
+    "--shapes=block_static"
+  TEST_RUNNER
+    iree_tools_testing_e2e_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "rocm"
+  DRIVERS
+    "hip"
+  COMPILER_FLAGS
+    ${IREE_HIP_TEST_COMPILER_FLAGS}
+    "--iree-llvmgpu-use-direct-load"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-${_CDNA_ARCH}"
+)
+
+iree_generated_e2e_runner_test(
+  NAME
     e2e_matmul_${_CDNA_ARCH}_vecdistmfma_f16
   TEST_TYPE
     matmul
diff --git a/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py b/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py
index 08a3d88..80d9eb0 100644
--- a/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_batch_matmul_tests.py
@@ -28,6 +28,7 @@
     "block_static": [
         BatchTestShape(batch=32, m=128, k=128, n=128),
         BatchTestShape(batch=32, m=256, k=64, n=256),
+        BatchTestShape(batch=32, m=96, k=96, n=96),
     ],
 }