[LDS] Improve multiple transfers per lane (#22879)
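
Previously the lowering required the per-lane transfer size
(transferSizeBits / subgroupSize) to exactly match one of the target DMA
sizes, so an innermost dimension wider than a single subgroup-wide transfer
failed to lower. The pass now walks the target DMA sizes in descending
order and picks the largest one whose subgroup-wide transfer width evenly
tiles the innermost dimension, emitting N GatherToLDS ops per row when
N > 1.

A minimal standalone sketch of the new selection logic (simplified from
the pass: plain C++ with no rewriter plumbing; the free function and its
name are illustrative only, not part of the codebase):

    #include <algorithm>
    #include <cstdint>
    #include <functional>
    #include <vector>

    // Returns elements per lane for the largest DMA size (in bits) whose
    // subgroup-wide transfer evenly tiles the innermost dimension, or 0
    // if none fits.
    int64_t selectElementsPerLane(std::vector<int64_t> dmaSizes,
                                  int64_t elementBits, int64_t subgroupSize,
                                  int64_t innermostDimSize) {
      std::sort(dmaSizes.begin(), dmaSizes.end(), std::greater<>());
      for (int64_t dmaSize : dmaSizes) {
        // Skip DMA sizes that are not a whole number of elements.
        if (dmaSize % elementBits != 0)
          continue;
        int64_t elementsPerLane = dmaSize / elementBits;
        int64_t totalElementsPerTransfer = subgroupSize * elementsPerLane;
        if (innermostDimSize % totalElementsPerTransfer == 0)
          return elementsPerLane;
      }
      return 0;
    }

    // Worked example from the pipeline test below: f32 (32-bit) elements,
    // 64 lanes, innermost dimension 128, dma_sizes = [32, 128]:
    //   dmaSize=128 -> 4 elts/lane -> 256 per transfer; 128 % 256 != 0.
    //   dmaSize=32  -> 1 elt/lane  ->  64 per transfer; 128 % 64 == 0.
    // selectElementsPerLane({32, 128}, 32, 64, 128) == 1, i.e. two
    // transfers per 128-element row.
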
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
index a284fd9..824416f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPULowerCoalescedDMAToGatherLDS.cpp
@@ -118,16 +118,39 @@
}
LDBG() << "Subgroup size: " << *subgroupSize;
- // Check that transfer size matches one of the target DMA sizes.
- int64_t transferSizePerLane = transferSizeBits / *subgroupSize;
- LDBG() << "Transfer size per lane: " << transferSizePerLane << " bits";
+  // Find a suitable DMA size that allows the innermost dimension to be
+  // evenly divided into N subgroup-wide transfers. Sort the DMA sizes in
+  // descending order so that larger, more efficient sizes are tried first.
+ auto sortedDmaSizes = llvm::to_vector_of<int64_t>(targetDmaSizes);
+ llvm::sort(sortedDmaSizes, std::greater<>());
- if (!targetDmaSizes.empty() &&
- !llvm::is_contained(targetDmaSizes, transferSizePerLane)) {
- return rewriter.notifyMatchFailure(
- dmaOp, "transfer size does not match any target DMA size");
+ int64_t elementsPerLane = 0;
+ for (int64_t dmaSize : sortedDmaSizes) {
+ // Calculate elements per lane for this DMA size.
+ if (dmaSize % elementBits != 0)
+ continue;
+ int64_t candidateElementsPerLane = dmaSize / elementBits;
+
+ // Calculate total elements per transfer (all lanes combined).
+ int64_t totalElementsPerTransfer =
+ *subgroupSize * candidateElementsPerLane;
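+    // For example, f32 (32-bit) elements with a 128-bit DMA size and a
+    // 64-lane subgroup give 4 elements per lane and 256 elements per
+    // subgroup-wide transfer.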
+
+    // Make sure the subgroup-wide transfer width evenly divides the
+    // innermost dimension.
+ if (innermostDimSize % totalElementsPerTransfer == 0) {
+ elementsPerLane = candidateElementsPerLane;
+ LDBG() << "Selected DMA size: " << dmaSize
+ << " bits, elementsPerLane: " << elementsPerLane
+ << ", numTransfers: "
+ << (innermostDimSize / totalElementsPerTransfer);
+ break;
+ }
}
- LDBG() << "Transfer size matches target DMA sizes";
+
+ if (elementsPerLane == 0) {
+ return rewriter.notifyMatchFailure(
+        dmaOp, "innermost dimension is not evenly divisible by any "
+               "supported subgroup-wide transfer "
+               "(subgroupSize * dmaSize / elementBits elements)");
+ }
auto destType = cast<MemRefType>(dest.getType());
ArrayRef<int64_t> destShape = destType.getShape();
@@ -137,8 +160,7 @@
size_t numIndexDims = indices.size();
LDBG() << "Number of index dimensions: " << numIndexDims;
- int64_t elementsPerTransfer = innermostDimSize / *subgroupSize;
- auto transferType = VectorType::get({elementsPerTransfer}, elementType);
+ auto transferType = VectorType::get({elementsPerLane}, elementType);
// Actually create the GatherToLDS ops to perform the transfer.
rewriter.setInsertionPoint(dmaOp);
@@ -147,15 +169,17 @@
Location loc = dmaOp.getLoc();
Value laneOffset = arith::MulIOp::create(
rewriter, loc, laneId,
- arith::ConstantIndexOp::create(rewriter, loc, elementsPerTransfer));
+ arith::ConstantIndexOp::create(rewriter, loc, elementsPerLane));
- // Build tile sizes: [1, 1, ..., 1, subgroupSize * elementsPerTransfer].
+ // Build tile sizes: [1, 1, ..., 1, subgroupSize * elementsPerLane].
// This iterates over destination dimensions, with each lane handling
- // `elementsPerTransfer` contiguous elements in the innermost dimension.
+ // `elementsPerLane` contiguous elements in the innermost dimension.
// This approach uniformly handles 1D, 2D, and higher-dimensional cases,
// as well as both copy mode (no indices) and gather mode (with indices).
+ // When innermostDimSize > subgroupSize * elementsPerLane, multiple
+ // GatherToLDS ops are generated to cover the entire inner dimension.
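+  // For example, a 4x128 f32 destination with 64 lanes and a 32-bit DMA
+  // size yields tile sizes [1, 64] and hence eight GatherToLDS ops.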
SmallVector<int64_t> tileSizes(destShape.size(), 1);
- tileSizes.back() = *subgroupSize * elementsPerTransfer;
+ tileSizes.back() = *subgroupSize * elementsPerLane;
for (const SmallVector<int64_t> &offsets :
StaticTileOffsetRange(destShape, tileSizes)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
index 580f6f8..8a9b4c9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/amdgpu_lower_coalesced_dma_to_gather_lds.mlir
@@ -13,7 +13,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_64 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 32>
#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
@@ -79,7 +79,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
@@ -131,7 +131,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
@@ -175,7 +175,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
@@ -244,7 +244,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_256_wide = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 256>
@@ -297,7 +297,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
@@ -348,7 +348,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
@@ -398,3 +398,115 @@
} {mapping = [#gpu.thread<linear_dim_0>]}
return
}
+
+// -----
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx950", features = "", wgp = <
+ compute = fp64|fp32|fp16|int64|int32|int16|int8,
+ storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic,
+ dot = dp4xi8toi32, mma = [], subgroup_size_choices = [32, 32],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [128]>>}>
+
+#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
+
+// Test: N-transfer mode for wide innermost dimension.
+// When innermostDimSize > subgroupSize * elementsPerLane, multiple
+// GatherToLDS ops are generated to cover the entire dimension.
+// - innermostDimSize = 256, subgroupSize = 32, dma_sizes = [128]
+// - elementsPerLane = 128 bits / 32 bits = 4 f32s per lane
+// - totalElementsPerTransfer = 32 * 4 = 128 elements
+// - numTransfers = 256 / 128 = 2 transfers
+//
+// CHECK-LABEL: func.func @lower_coalesced_dma_multiple_transfers_1d
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[DST:[a-zA-Z0-9]+]]: memref<256xf32, #gpu.address_space<workgroup>>
+func.func @lower_coalesced_dma_multiple_transfers_1d(
+ %source: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %dest: memref<256xf32, #gpu.address_space<workgroup>>)
+ attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb,
+ translation_info = #translation_32} {
+ // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (32)
+ scf.forall (%arg6) in (32) {
+ // Each lane reads 4 elements per transfer.
+ // CHECK-DAG: %[[C4:[a-zA-Z0-9_]+]] = arith.constant 4
+ // CHECK-DAG: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C4]]
+ //
+ // Transfer 1: elements [0, 128), tile offset = 0
+ // CHECK-DAG: %[[TILE0:[a-zA-Z0-9_]+]] = arith.constant 0 : index
+ // CHECK: %[[SRC_IDX0:[a-zA-Z0-9_]+]] = arith.addi %[[TILE0]], %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX0]]], %[[DST]][%[[TILE0]]] : vector<4xf32>
+ //
+ // Transfer 2: elements [128, 256), tile offset = 128
+ // CHECK: %[[TILE1:[a-zA-Z0-9_]+]] = arith.constant 128 : index
+ // CHECK: %[[SRC_IDX1:[a-zA-Z0-9_]+]] = arith.addi %[[TILE1]], %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX1]]], %[[DST]][%[[TILE1]]] : vector<4xf32>
+ // CHECK-NOT: amdgpu.gather_to_lds
+ // CHECK-NOT: iree_gpu.coalesced_gather_dma
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, memref<256xf32, #gpu.address_space<workgroup>>, index
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ return
+}
+
+// -----
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx950", features = "", wgp = <
+ compute = fp64|fp32|fp16|int64|int32|int16|int8,
+ storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic,
+ dot = dp4xi8toi32, mma = [], subgroup_size_choices = [32, 32],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [128]>>}>
+
+#translation_32 = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
+
+// Test: N-transfer mode for 2D memref with wide innermost dimension.
+// - Shape: 2x256, innermostDimSize = 256
+// - subgroupSize = 32, dma_sizes = [128]
+// - elementsPerLane = 4, totalElementsPerTransfer = 128
+// - numTransfers per row = 256 / 128 = 2
+// - Total gather_to_lds ops = 2 rows * 2 transfers = 4
+//
+// CHECK-LABEL: func.func @lower_coalesced_dma_multiple_transfers_2d
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<2x256xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[DST:[a-zA-Z0-9]+]]: memref<2x256xf32, #gpu.address_space<workgroup>>
+func.func @lower_coalesced_dma_multiple_transfers_2d(
+ %source: memref<2x256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %dest: memref<2x256xf32, #gpu.address_space<workgroup>>)
+ attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb,
+ translation_info = #translation_32} {
+ // CHECK: scf.forall (%[[LANE_ID:[a-zA-Z0-9]+]]) in (32)
+ scf.forall (%arg6) in (32) {
+ // CHECK-DAG: %[[C4:[a-zA-Z0-9_]+]] = arith.constant 4
+ // CHECK-DAG: %[[LANE_OFFSET:[a-zA-Z0-9_]+]] = arith.muli %[[LANE_ID]], %[[C4]]
+ //
+ // Row 0, Transfer 1: [0, 0:128)
+ // CHECK: amdgpu.gather_to_lds %[[SRC]]{{.+}} : vector<4xf32>
+ //
+ // Row 0, Transfer 2: [0, 128:256)
+ // CHECK: amdgpu.gather_to_lds %[[SRC]]{{.+}} : vector<4xf32>
+ //
+ // Row 1, Transfer 1: [1, 0:128)
+ // CHECK: amdgpu.gather_to_lds %[[SRC]]{{.+}} : vector<4xf32>
+ //
+ // Row 1, Transfer 2: [1, 128:256)
+ // CHECK: amdgpu.gather_to_lds %[[SRC]]{{.+}} : vector<4xf32>
+ // CHECK-NOT: amdgpu.gather_to_lds
+ // CHECK-NOT: iree_gpu.coalesced_gather_dma
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) : memref<2x256xf32, #amdgpu.address_space<fat_raw_buffer>>, memref<2x256xf32, #gpu.address_space<workgroup>>, index
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ return
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_coalesced_dma_to_global_loads.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_coalesced_dma_to_global_loads.mlir
index 1290ec5..3235e1d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_coalesced_dma_to_global_loads.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_coalesced_dma_to_global_loads.mlir
@@ -4,7 +4,7 @@
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm",
"rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
- arch = "gfx1250", features = "", wgp = <
+ arch = "gfx950", features = "", wgp = <
compute = fp64|fp32|fp16|int64|int32|int16|int8,
storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic,
dot = dp4xi8toi32, mma = [], subgroup_size_choices = [32, 32],
@@ -13,7 +13,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 32>
@@ -44,7 +44,7 @@
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm",
"rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
- arch = "gfx1250", features = "", wgp = <
+ arch = "gfx950", features = "", wgp = <
compute = fp64|fp32|fp16|int64|int32|int16|int8,
storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic,
dot = dp4xi8toi32, mma = [], subgroup_size_choices = [32, 32],
@@ -53,7 +53,7 @@
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
max_load_instruction_bits = 128, simds_per_wgp = 4,
- vgpr_space_bits = 8192>>}>
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/pipeline_coalesced_dma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/pipeline_coalesced_dma.mlir
index 8b0ad21..6501fbc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/pipeline_coalesced_dma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/pipeline_coalesced_dma.mlir
@@ -189,3 +189,199 @@
}
}
}
+
+// -----
+
+// Test: Multiple DMA transfers per lane (N-transfer mode).
+// When the innermost dimension exceeds subgroupSize * elementsPerLane,
+// multiple GatherToLDS ops are generated to cover the entire dimension.
+//
+// With 4x128 f32 elements and 64 threads:
+// - innermost = 128, dma_sizes = [32, 128]
+// - dma_size=128: elementsPerLane=4, totalElementsPerTransfer=256, 128 % 256 != 0 -> skip
+// - dma_size=32: elementsPerLane=1, totalElementsPerTransfer=64, 128 % 64 = 0 -> 2 transfers
+// Each row requires 2 gather_to_lds ops (128/64 = 2 transfers per row).
+// 4 rows * 2 transfers = 8 total gather_to_lds ops.
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx942", features = "", wgp = <
+ compute = fp64|fp32|fp16|int64|int32|int16|int8,
+ storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic,
+ dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+]>
+
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: hal.executable public @coalesced_dma_multi_transfer
+hal.executable public @coalesced_dma_multi_transfer {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
+ hal.executable.export public @lower_coalesced_dma_multi_transfer ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
+ %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ // CHECK-LABEL: func.func @lower_coalesced_dma_multi_transfer
+ func.func @lower_coalesced_dma_multi_transfer()
+ attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb,
+ translation_info = #translation} {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[SRC:.+]] = amdgpu.fat_raw_buffer_cast
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<4x128xf32, #hal.descriptor_type<storage_buffer>>
+ %assumed = memref.assume_alignment %0, 64 : memref<4x128xf32, #hal.descriptor_type<storage_buffer>>
+ %source = amdgpu.fat_raw_buffer_cast %assumed resetOffset : memref<4x128xf32, #hal.descriptor_type<storage_buffer>> to memref<4x128xf32, #amdgpu.address_space<fat_raw_buffer>>
+ // CHECK: %[[DST:.+]] = memref.alloc()
+ %dest = memref.alloc() : memref<4x128xf32, #gpu.address_space<workgroup>>
+ // CHECK: scf.forall (%[[LANE_ID:.+]]) in (64)
+ scf.forall (%arg6) in (64) {
+ // With 4x128 f32 elements and 64 threads:
+ // - innermost=128, can't use 128-bit (128 % 256 != 0), use 32-bit
+ // - elementsPerLane = 1, totalElementsPerTransfer = 64
+ // - Each row needs 128/64 = 2 transfers (at offsets 0 and 64)
+ // - 4 rows * 2 transfers = 8 gather_to_lds ops total
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1
+ // CHECK-DAG: %[[LANE_OFFSET:.+]] = arith.muli %[[LANE_ID]], %[[C1]]
+ //
+ // Row 0, Transfer 1: source[0, 0 + lane_offset], dest[0, 0]
+ // CHECK: %[[SRC_COL0_T0:.+]] = arith.addi %{{c0.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{c0.+}}, %[[SRC_COL0_T0]]], %[[DST]][%{{c0.+}}, %{{c0.+}}] : vector<1xf32>
+ //
+ // Row 0, Transfer 2: source[0, 64 + lane_offset], dest[0, 64]
+ // CHECK: %[[C64:.+]] = arith.constant 64
+ // CHECK: %[[SRC_COL0_T1:.+]] = arith.addi %[[C64]], %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{c0.+}}, %[[SRC_COL0_T1]]], %[[DST]][%{{c0.+}}, %[[C64]]] : vector<1xf32>
+ //
+ // Row 1, Transfer 1: source[1, 0 + lane_offset], dest[1, 0]
+ // CHECK: %[[ROW1:.+]] = arith.constant 1
+ // CHECK: %[[SRC_COL1_T0:.+]] = arith.addi %{{c0.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[ROW1]], %[[SRC_COL1_T0]]], %[[DST]][%[[ROW1]], %{{c0.+}}] : vector<1xf32>
+ //
+ // Row 1, Transfer 2: source[1, 64 + lane_offset], dest[1, 64]
+ // CHECK: %[[SRC_COL1_T1:.+]] = arith.addi %{{c64.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{.+}}, %[[SRC_COL1_T1]]], %[[DST]][%{{.+}}, %{{c64.+}}] : vector<1xf32>
+ //
+ // Row 2, Transfer 1: source[2, 0 + lane_offset], dest[2, 0]
+ // CHECK: %[[ROW2:.+]] = arith.constant 2
+ // CHECK: %[[SRC_COL2_T0:.+]] = arith.addi %{{c0.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[ROW2]], %[[SRC_COL2_T0]]], %[[DST]][%[[ROW2]], %{{c0.+}}] : vector<1xf32>
+ //
+ // Row 2, Transfer 2: source[2, 64 + lane_offset], dest[2, 64]
+ // CHECK: %[[SRC_COL2_T1:.+]] = arith.addi %{{c64.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{.+}}, %[[SRC_COL2_T1]]], %[[DST]][%{{.+}}, %{{c64.+}}] : vector<1xf32>
+ //
+ // Row 3, Transfer 1: source[3, 0 + lane_offset], dest[3, 0]
+ // CHECK: %[[ROW3:.+]] = arith.constant 3
+ // CHECK: %[[SRC_COL3_T0:.+]] = arith.addi %{{c0.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[ROW3]], %[[SRC_COL3_T0]]], %[[DST]][%[[ROW3]], %{{c0.+}}] : vector<1xf32>
+ //
+ // Row 3, Transfer 2: source[3, 64 + lane_offset], dest[3, 64]
+ // CHECK: %[[SRC_COL3_T1:.+]] = arith.addi %{{c64.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{.+}}, %[[SRC_COL3_T1]]], %[[DST]][%{{.+}}, %{{c64.+}}] : vector<1xf32>
+ // CHECK-NOT: amdgpu.gather_to_lds
+ // CHECK-NOT: iree_gpu.coalesced_gather_dma
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) :
+ memref<4x128xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ memref<4x128xf32, #gpu.address_space<workgroup>>, index
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ return
+ }
+ }
+ }
+}
+
+// -----
+
+// Test: Multiple DMA transfers with 128-bit DMA size.
+// With 2x512 f32 elements and 64 threads:
+// - innermost = 512, dma_sizes = [32, 128]
+// - dma_size=128: elementsPerLane=4, totalElementsPerTransfer=256, 512 % 256 = 0 -> 2 transfers
+// Each row requires 2 gather_to_lds ops using 128-bit (4 f32) transfers.
+// 2 rows * 2 transfers = 4 total gather_to_lds ops.
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm",
+ "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target<
+ arch = "gfx942", features = "", wgp = <
+ compute = fp64|fp32|fp16|int64|int32|int16|int8,
+ storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic,
+ dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64, 64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647],
+ max_load_instruction_bits = 128, simds_per_wgp = 4,
+ vgpr_space_bits = 8192, dma_sizes = [32, 128]>>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+]>
+
+#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+
+// CHECK-LABEL: hal.executable public @coalesced_dma_multi_transfer_128bit
+hal.executable public @coalesced_dma_multi_transfer_128bit {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
+ hal.executable.export public @lower_coalesced_dma_multi_transfer_128bit ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
+ %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ // CHECK-LABEL: func.func @lower_coalesced_dma_multi_transfer_128bit
+ func.func @lower_coalesced_dma_multi_transfer_128bit()
+ attributes {
+ hal.executable.target = #executable_target_rocm_hsaco_fb,
+ translation_info = #translation} {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[SRC:.+]] = amdgpu.fat_raw_buffer_cast
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<2x512xf32, #hal.descriptor_type<storage_buffer>>
+ %assumed = memref.assume_alignment %0, 64 : memref<2x512xf32, #hal.descriptor_type<storage_buffer>>
+ %source = amdgpu.fat_raw_buffer_cast %assumed resetOffset : memref<2x512xf32, #hal.descriptor_type<storage_buffer>> to memref<2x512xf32, #amdgpu.address_space<fat_raw_buffer>>
+ // CHECK: %[[DST:.+]] = memref.alloc()
+ %dest = memref.alloc() : memref<2x512xf32, #gpu.address_space<workgroup>>
+ // CHECK: scf.forall (%[[LANE_ID:.+]]) in (64)
+ scf.forall (%arg6) in (64) {
+ // With 2x512 f32 elements and 64 threads:
+ // - innermost=512, use 128-bit (512 % 256 = 0)
+ // - elementsPerLane = 4, totalElementsPerTransfer = 256
+ // - Each row needs 512/256 = 2 transfers (at offsets 0 and 256)
+ // - 2 rows * 2 transfers = 4 gather_to_lds ops total
+ // CHECK-DAG: %[[C4:.+]] = arith.constant 4
+ // CHECK-DAG: %[[LANE_OFFSET:.+]] = arith.muli %[[LANE_ID]], %[[C4]]
+ //
+ // Row 0, Transfer 1: source[0, 0 + lane_offset], dest[0, 0]
+ // CHECK: %[[SRC_COL0_T0:.+]] = arith.addi %{{c0.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{c0.+}}, %[[SRC_COL0_T0]]], %[[DST]][%{{c0.+}}, %{{c0.+}}] : vector<4xf32>
+ //
+ // Row 0, Transfer 2: source[0, 256 + lane_offset], dest[0, 256]
+ // CHECK: %[[C256:.+]] = arith.constant 256
+ // CHECK: %[[SRC_COL0_T1:.+]] = arith.addi %[[C256]], %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{c0.+}}, %[[SRC_COL0_T1]]], %[[DST]][%{{c0.+}}, %[[C256]]] : vector<4xf32>
+ //
+ // Row 1, Transfer 1: source[1, 0 + lane_offset], dest[1, 0]
+ // CHECK: %[[ROW1:.+]] = arith.constant 1
+ // CHECK: %[[SRC_COL1_T0:.+]] = arith.addi %{{c0.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[ROW1]], %[[SRC_COL1_T0]]], %[[DST]][%[[ROW1]], %{{c0.+}}] : vector<4xf32>
+ //
+ // Row 1, Transfer 2: source[1, 256 + lane_offset], dest[1, 256]
+ // CHECK: %[[SRC_COL1_T1:.+]] = arith.addi %{{c256.+}}, %[[LANE_OFFSET]]
+ // CHECK: amdgpu.gather_to_lds %[[SRC]][%{{.+}}, %[[SRC_COL1_T1]]], %[[DST]][%{{.+}}, %{{c256.+}}] : vector<4xf32>
+ // CHECK-NOT: amdgpu.gather_to_lds
+ // CHECK-NOT: iree_gpu.coalesced_gather_dma
+ iree_gpu.coalesced_gather_dma %source into %dest lane(%arg6) :
+ memref<2x512xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ memref<2x512xf32, #gpu.address_space<workgroup>>, index
+ } {mapping = [#gpu.thread<linear_dim_0>]}
+ return
+ }
+ }
+ }
+}
diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel
index e40eb9c..3f42c1d 100644
--- a/tests/e2e/matmul/BUILD.bazel
+++ b/tests/e2e/matmul/BUILD.bazel
@@ -482,3 +482,60 @@
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
)
+
+# Third coalesced DMA test with wide K/N to trigger multiple DMA transfers per lane.
+# This tests the N-transfer mode where innermostDimSize > subgroupSize * elementsPerLane.
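+# For example, with f32, a 64-lane subgroup, and a 32-bit DMA size, a
+# 128-element innermost dimension takes two subgroup-wide transfers.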
+iree_generated_e2e_runner_test(
+ name = "e2e_matmul_cdna3_coalesced_dma_f32_multi_transfer",
+ compiler_flags = [
+ "--iree-hip-target=gfx942",
+ "--iree-llvmgpu-use-direct-load",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=f32",
+ "--acc_type=f32",
+ "--shapes=custom_mnk",
+ "--mnk=32,128,128",
+ ],
+ tags = [
+ "noasan",
+ "nomsan",
+ "notsan",
+ "noubsan",
+ "requires-gpu-cdna3",
+ ],
+ target_backends_and_drivers = [
+ ("rocm", "hip"),
+ ],
+ test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
+ test_type = "matmul",
+)
+
+# Tests coalesced DMA with larger dimensions requiring multiple transfers.
+iree_generated_e2e_runner_test(
+ name = "e2e_matmul_cdna3_coalesced_dma_f32_large",
+ compiler_flags = [
+ "--iree-hip-target=gfx942",
+ "--iree-llvmgpu-use-direct-load",
+ ],
+ generator = ":generate_e2e_matmul_tests",
+ generator_args = [
+ "--lhs_rhs_type=f32",
+ "--acc_type=f32",
+ "--shapes=custom_mnk",
+ "--mnk=128,256,512",
+ ],
+ tags = [
+ "noasan",
+ "nomsan",
+ "notsan",
+ "noubsan",
+ "requires-gpu-cdna3",
+ ],
+ target_backends_and_drivers = [
+ ("rocm", "hip"),
+ ],
+ test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
+ test_type = "matmul",
+)
diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt
index e8e872b..7e0418d 100644
--- a/tests/e2e/matmul/CMakeLists.txt
+++ b/tests/e2e/matmul/CMakeLists.txt
@@ -1200,6 +1200,64 @@
"requires-gpu-cdna3"
)
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_cdna3_coalesced_dma_f32_multi_transfer
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--acc_type=f32"
+ "--shapes=custom_mnk"
+ "--mnk=32,128,128"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ "--iree-hip-target=gfx942"
+ "--iree-llvmgpu-use-direct-load"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_matmul_cdna3_coalesced_dma_f32_large
+ TEST_TYPE
+ matmul
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--acc_type=f32"
+ "--shapes=custom_mnk"
+ "--mnk=128,256,512"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "rocm"
+ DRIVERS
+ "hip"
+ COMPILER_FLAGS
+ "--iree-hip-target=gfx942"
+ "--iree-llvmgpu-use-direct-load"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-cdna3"
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
# To distinguish between CDNA(gfx9), RDNA3(gfx11), and RDNA4(gfx12)