[Codegen] Limit async scope in pipelining (#24350)

Currently, the pipelining assumes that every `amdgpu.gather_to_lds`
operation in the parent block of the loop belongs to the loop that is
being pipelined. It therefore marks all of them as async, while the
corresponding marks and waits are only inserted for the pipelined loop.

This assumption was fine for the workloads so far, but with work
underway to enable pipelining for attention, this no longer holds, as
the load of e.g. the `Q` matrix can be outside of the loop. Operations
outside the loop can also be marked async, but that requires inserting
marks and waits before the loop.

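This PR fixes this by recording, before pipelining, the operation that
immediately precedes the loop as a boundary marker. After pipelining,
`amdgpu.gather_to_lds` operations up to and including that marker are
marked async and get their own mark/wait pairs, separate from the async
groups of the pipelined loop.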

This is part of https://github.com/iree-org/iree/issues/23782.

Assisted-by: Claude Code and Codex

---------

Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp
index eb4d47b..b1ad158 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp
@@ -1075,11 +1075,48 @@
   return nullptr;
 }
 
-/// Sets the async flag on all gather_to_lds ops in the parent block so they
+/// Sets the async flag on gather_to_lds ops from `begin` to `end` so they
 /// lower to rocdl.load.async.to.lds instead of rocdl.load.to.lds.
-static void enableAsyncOnGatherOps(Block *parentBlock) {
-  parentBlock->walk(
-      [](amdgpu::GatherToLDSOp gatherOp) { gatherOp.setAsync(true); });
+static void enableAsyncOnGatherOps(Block::iterator begin, Block::iterator end) {
+  for (auto it = begin; it != end; ++it) {
+    it->walk([](amdgpu::GatherToLDSOp gatherOp) { gatherOp.setAsync(true); });
+  }
+}
+
+/// Inserts an asyncmark/wait.asyncmark 0 pair before `insertPt`.
+static void insertAsyncDrainBefore(RewriterBase &rewriter, Location loc,
+                                   Operation *insertPt) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(insertPt);
+  ROCDL::AsyncmarkOp::create(rewriter, loc);
+  ROCDL::WaitAsyncmarkOp::create(rewriter, loc, rewriter.getI16IntegerAttr(0));
+}
+
+/// Marks gather_to_lds ops placed directly in [preLoopStart, preLoopEnd) as
+/// async and, to keep them effectively synchronous, drains them before the
+/// next gpu.barrier or op flagged by hasNestedSharedRead, or at preLoopEnd.
+static void insertPreLoopAsyncMarkers(RewriterBase &rewriter, Location loc,
+                                      Block::iterator preLoopStart,
+                                      Block::iterator preLoopEnd) {
+  bool hasPendingAsyncGather = false;
+  for (auto it = preLoopStart; it != preLoopEnd;) {
+    Operation *op = &*it++;
+
+    if (hasPendingAsyncGather &&
+        (isa<gpu::BarrierOp>(op) || hasNestedSharedRead(op))) {
+      insertAsyncDrainBefore(rewriter, loc, op);
+      hasPendingAsyncGather = false;
+    }
+
+    if (auto gatherOp = dyn_cast<amdgpu::GatherToLDSOp>(op)) {
+      gatherOp.setAsync(true);
+      hasPendingAsyncGather = true;
+    }
+  }
+
+  if (hasPendingAsyncGather) {
+    insertAsyncDrainBefore(rewriter, loc, &*preLoopEnd);
+  }
 }
 
 /// Inserts asyncmark ops in the prologue to delineate DMA groups.
@@ -1088,11 +1125,11 @@
 /// Each iteration group gets one asyncmark after its last gather_to_lds.
 /// Groups are identified by evenly dividing the total prologue gather ops.
 static void insertPrologueAsyncMarks(RewriterBase &rewriter, Location loc,
-                                     Block *parentBlock,
+                                     Block::iterator prologueStart,
                                      Block::iterator loopStart,
                                      unsigned numStages) {
   SmallVector<Operation *> prologueGathers;
-  for (auto it = parentBlock->begin(); it != loopStart; ++it) {
+  for (auto it = prologueStart; it != loopStart; ++it) {
     if (isa<amdgpu::GatherToLDSOp>(&*it)) {
       prologueGathers.push_back(&*it);
     } else if (it->getNumRegions() > 0 && containsNestedGatherToLDS(&*it)) {
@@ -1175,15 +1212,16 @@
 /// new DMA group we wait until only (N-1) groups are in flight, ensuring the
 /// oldest group's data is ready for reading.
 static void insertExplicitAsyncMarkers(RewriterBase &rewriter,
-                                       scf::ForOp newForOp,
-                                       unsigned numStages) {
+                                       scf::ForOp newForOp, unsigned numStages,
+                                       Block::iterator prologueStart) {
   Block *parentBlock = newForOp->getBlock();
   Location loc = newForOp.getLoc();
   int16_t waitCount = static_cast<int16_t>(numStages - 1);
 
-  enableAsyncOnGatherOps(parentBlock);
-  insertPrologueAsyncMarks(rewriter, loc, parentBlock, newForOp->getIterator(),
-                           numStages);
+  insertPreLoopAsyncMarkers(rewriter, loc, parentBlock->begin(), prologueStart);
+  enableAsyncOnGatherOps(prologueStart, parentBlock->end());
+  insertPrologueAsyncMarks(rewriter, loc, prologueStart,
+                           newForOp->getIterator(), numStages);
   insertLoopBodyAsyncMarkers(rewriter, loc, newForOp, waitCount);
   insertEpilogueAsyncWait(rewriter, loc, std::next(newForOp->getIterator()),
                           parentBlock->end());
@@ -1278,11 +1316,17 @@
     return forOp;
   }
 
+  Operation *opBeforeLoop = nullptr;
   if (mode == PipelineMode::AsyncCopy) {
     // Apply multi-buffering: numStages buffers for N-stage pipelining.
     if (failed(multiBufferLDSAllocations(forOp, /*numBuffers=*/numStages))) {
       return failure();
     }
+    // Record the operation before the original loop. After pipelining, the next
+    // operation after this marker is the start of the generated prologue. Ops
+    // before that boundary are pre-existing parent-block ops and are explicitly
+    // drained separately from the pipelined async groups.
+    opBeforeLoop = forOp->getPrevNode();
   }
 
   // Re-run classification on the (potentially multi-buffered) IR to capture
@@ -1334,7 +1378,12 @@
   // allowing DMA writes to a new buffer slot to overlap with ds_reads from the
   // previous slot.
   if (mode == PipelineMode::AsyncCopy) {
-    insertExplicitAsyncMarkers(rewriter, newForOp, numStages);
+    // Compute the start of the pipelining prologue, skipping any pre-existing
+    // ops that were before the original loop.
+    Block::iterator prologueStart = opBeforeLoop
+                                        ? std::next(opBeforeLoop->getIterator())
+                                        : newForOp->getBlock()->begin();
+    insertExplicitAsyncMarkers(rewriter, newForOp, numStages, prologueStart);
   }
 
   return newForOp;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
index 019c0ab..bda838f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
@@ -799,3 +799,95 @@
   vector.transfer_write %result, %output[%c0] {in_bounds = [true]} : vector<1xf32>, memref<128xf32>
   return
 }
+
+// -----
+
+// Test that gather_to_lds ops before the loop (e.g., Q loads in attention) are
+// converted to async with explicit synchronization. This is intentionally two
+// pre-loop gathers to verify that they are grouped and drained before use.
+// CHECK-LABEL: @gather_to_lds_pre_loop_gets_explicit_wait
+func.func @gather_to_lds_pre_loop_gets_explicit_wait(
+    %A_global: memref<128x128xf32>,
+    %B_global: memref<128x128xf32>,
+    %C_global: memref<128xf32>) {
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+
+  %Q_lds = memref.alloc() : memref<2xf32, #gpu.address_space<workgroup>>
+  %K_lds = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
+
+  // Pre-loop gather_to_lds ops (e.g., loading Q) - async with explicit drain.
+  // CHECK: amdgpu.gather_to_lds async
+  amdgpu.gather_to_lds %A_global[%c0, %c0], %Q_lds[%c0] : vector<1xf32>, memref<128x128xf32>, memref<2xf32, #gpu.address_space<workgroup>>
+  // CHECK: amdgpu.gather_to_lds async
+  amdgpu.gather_to_lds %A_global[%c0, %c1], %Q_lds[%c1] : vector<1xf32>, memref<128x128xf32>, memref<2xf32, #gpu.address_space<workgroup>>
+  // CHECK: rocdl.asyncmark
+  // CHECK: rocdl.wait.asyncmark 0
+  // CHECK: gpu.barrier
+  // CHECK: vector.transfer_read
+
+  // Barrier + read Q from LDS before the loop, mirroring the real attention
+  // pattern where Q is loaded, synchronized, and consumed before K/V pipelining.
+  gpu.barrier
+  %q_val = vector.transfer_read %Q_lds[%c0], %cst_0 : memref<2xf32, #gpu.address_space<workgroup>>, vector<2xf32>
+
+  // Pipelined loop with gather_to_lds - these should be async.
+  // CHECK: amdgpu.gather_to_lds async
+  // CHECK: rocdl.asyncmark
+  // CHECK: scf.for
+  %result = scf.for %k = %c0 to %c128 step %c1 iter_args(%acc = %q_val) -> (vector<2xf32>) {
+    amdgpu.gather_to_lds %B_global[%c0, %k], %K_lds[%c0] : vector<1xf32>, memref<128x128xf32>, memref<1xf32, #gpu.address_space<workgroup>>
+    %k_val = vector.transfer_read %K_lds[%c0], %cst_0 : memref<1xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+    %k_splat = vector.broadcast %k_val : vector<1xf32> to vector<2xf32>
+    %prod = arith.mulf %k_splat, %acc : vector<2xf32>
+    scf.yield %prod : vector<2xf32>
+  }
+
+  vector.transfer_write %result, %C_global[%c0] {in_bounds = [true]} : vector<2xf32>, memref<128xf32>
+  return
+}
+
+// -----
+
+// Test that the pre-loop async drain is inserted before the first barrier even
+// when the shared read is not immediately adjacent to that barrier.
+// CHECK-LABEL: @gather_to_lds_pre_loop_wait_before_non_adjacent_read
+func.func @gather_to_lds_pre_loop_wait_before_non_adjacent_read(
+    %A_global: memref<128x128xf32>,
+    %B_global: memref<128x128xf32>,
+    %C_global: memref<128xf32>,
+    %offset: index) {
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+
+  %Q_lds = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
+  %K_lds = memref.alloc() : memref<1xf32, #gpu.address_space<workgroup>>
+
+  // CHECK: amdgpu.gather_to_lds async
+  amdgpu.gather_to_lds %A_global[%c0, %c0], %Q_lds[%c0] : vector<1xf32>, memref<128x128xf32>, memref<1xf32, #gpu.address_space<workgroup>>
+  // CHECK: rocdl.asyncmark
+  // CHECK: rocdl.wait.asyncmark 0
+  // CHECK: gpu.barrier
+  // CHECK: arith.addi
+  // CHECK: vector.transfer_read
+  gpu.barrier
+  %write_idx = arith.addi %offset, %c1 : index
+  %q_val = vector.transfer_read %Q_lds[%c0], %cst_0 : memref<1xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+
+  // CHECK: amdgpu.gather_to_lds async
+  // CHECK: rocdl.asyncmark
+  // CHECK: scf.for
+  %result = scf.for %k = %c0 to %c128 step %c1 iter_args(%acc = %q_val) -> (vector<1xf32>) {
+    amdgpu.gather_to_lds %B_global[%c0, %k], %K_lds[%c0] : vector<1xf32>, memref<128x128xf32>, memref<1xf32, #gpu.address_space<workgroup>>
+    %k_val = vector.transfer_read %K_lds[%c0], %cst_0 : memref<1xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+    %prod = arith.mulf %k_val, %acc : vector<1xf32>
+    scf.yield %prod : vector<1xf32>
+  }
+
+  vector.transfer_write %result, %C_global[%write_idx] {in_bounds = [true]} : vector<1xf32>, memref<128xf32>
+  return
+}