[LLVMGPU][ROCDL] Add pass to group global loads for better instruction scheduling (#24247)
Adds LLVMGPUGroupGlobalLoadsPass which moves global loads in the same
block to be adjacent to each other when they are separated by pure
address computation ops. The pass moves each load along with its
transitive dependency chain to be right after the preceding global load.
This improves performance in situations where LLVM is not able to
convert address computation into a single base + constant offset. In
such cases, instruction scheduling can become pessimistic and each
global load needs to be waited on before the next is issued. With this
instruction reordering, all global loads are issued together after
address computation is completed.
Based on benchmarks with this change alone, we don't have any cases in
our suite of kernels that runs into this issue today. However, some
convolution shapes run into the issue after the changes in
https://github.com/iree-org/iree/pull/24245, and this PR prevents such
regressions.
This is only enabled for ROCDL in this PR, because we don't have any
data points to support adding it to other pipelines yet.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index f17fa96..b304f6b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -101,6 +101,7 @@
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConstraintGenerator.cpp",
+ "LLVMGPUGroupGlobalLoads.cpp",
"LLVMGPULegalizeNDVectors.cpp",
"LLVMGPULinkExecutables.cpp",
"LLVMGPULowerExecutableTarget.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index cc8c125..bc5b5fa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -81,6 +81,7 @@
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConstraintGenerator.cpp"
+ "LLVMGPUGroupGlobalLoads.cpp"
"LLVMGPULegalizeNDVectors.cpp"
"LLVMGPULinkExecutables.cpp"
"LLVMGPULowerExecutableTarget.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUGroupGlobalLoads.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUGroupGlobalLoads.cpp
new file mode 100644
index 0000000..d6c295f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUGroupGlobalLoads.cpp
@@ -0,0 +1,211 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define DEBUG_TYPE "iree-codegen-llvmgpu-group-global-loads"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUGROUPGLOBALLOADSPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+/// Returns true if the operation is a load from global memory.
+static bool isGlobalLoad(Operation *op) {
+ Type memrefType;
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+ memrefType = loadOp.getBase().getType();
+ } else if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+ memrefType = loadOp.getMemref().getType();
+ } else {
+ return false;
+ }
+ auto memref = dyn_cast<MemRefType>(memrefType);
+ return memref && hasGlobalMemoryAddressSpace(memref);
+}
+
+/// Collects all ops in the same block that `op` transitively depends on
+/// and that are strictly between `boundary` and `op`. These are the ops
+/// that must be moved along with `op` if it is hoisted above `boundary`.
+static void collectDepsInRange(Operation *op, Operation *boundary,
+ llvm::SetVector<Operation *> &deps) {
+ for (Value operand : op->getOperands()) {
+ Operation *defOp = operand.getDefiningOp();
+ if (!defOp || defOp->getBlock() != op->getBlock()) {
+ continue;
+ }
+ if (!boundary->isBeforeInBlock(defOp)) {
+ continue;
+ }
+ if (deps.insert(defOp)) {
+ collectDepsInRange(defOp, boundary, deps);
+ }
+ }
+}
+
+/// Returns true if `op` can be moved before `boundary` while preserving SSA
+/// dominance. `movedDeps` contains earlier dependencies that will also be
+/// moved before `boundary`.
+static bool
+canMoveBeforeBoundary(Operation *op, Operation *boundary,
+ const llvm::SetVector<Operation *> &movedDeps) {
+ for (Value operand : op->getOperands()) {
+ Operation *defOp = operand.getDefiningOp();
+ if (!defOp || defOp->getBlock() != op->getBlock()) {
+ continue;
+ }
+ if (defOp->isBeforeInBlock(boundary) || movedDeps.contains(defOp)) {
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
+/// Returns true if `op` writes to global memory. Used to decide
+/// whether a non-dependent op left between the previous global load and the
+/// load being hoisted would invalidate the load by changing observable
+/// memory.
+static bool writesToGlobalMemory(Operation *op) {
+ if (isPure(op)) {
+ return false;
+ }
+ auto effectOp = dyn_cast<MemoryEffectOpInterface>(op);
+ if (!effectOp) {
+ // Non-pure op without the memory-effects interface; conservatively assume
+ // it could write to buffer or global memory.
+ return true;
+ }
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ effectOp.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (!isa<MemoryEffects::Write>(effect.getEffect())) {
+ continue;
+ }
+ Value value = effect.getValue();
+ if (!value) {
+ // Write to an unknown resource; be conservative.
+ return true;
+ }
+ auto memrefType = dyn_cast<MemRefType>(value.getType());
+ if (!memrefType) {
+ continue;
+ }
+ if (hasGlobalMemoryAddressSpace(memrefType)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Groups global loads within each block by hoisting each load (along with
+/// its pure address-computation dependencies) to be adjacent to the
+/// preceding global load.
+static void groupGlobalLoadsInBlock(Block &block) {
+ SmallVector<Operation *> globalLoads;
+ for (Operation &op : block) {
+ if (isGlobalLoad(&op)) {
+ globalLoads.push_back(&op);
+ }
+ }
+
+ Operation *prevGlobalLoad = nullptr;
+ for (Operation *load : globalLoads) {
+ if (!prevGlobalLoad) {
+ prevGlobalLoad = load;
+ continue;
+ }
+
+ if (!prevGlobalLoad->isBeforeInBlock(load)) {
+ prevGlobalLoad = load;
+ continue;
+ }
+
+ // Collect the ops between `prevGlobalLoad` and `load` that the load
+ // transitively depends on and would need to be hoisted alongside it.
+ llvm::SetVector<Operation *> deps;
+ collectDepsInRange(load, prevGlobalLoad, deps);
+
+ // The dependencies move with the load, but they must be pure so that
+ // hoisting them above unrelated ops in the range is safe.
+ if (!llvm::all_of(deps, [](Operation *op) { return isPure(op); })) {
+ prevGlobalLoad = load;
+ continue;
+ }
+
+ // Non-dependent ops in (prevGlobalLoad, load) are left in place after the
+ // load is hoisted. If any of them writes to global memory the
+ // hoisted load could observe stale memory, so the move is unsafe.
+ bool unsafe = false;
+ for (Operation *cur = prevGlobalLoad->getNextNode(); cur && cur != load;
+ cur = cur->getNextNode()) {
+ if (deps.contains(cur)) {
+ continue;
+ }
+ if (writesToGlobalMemory(cur)) {
+ unsafe = true;
+ break;
+ }
+ }
+ if (unsafe) {
+ prevGlobalLoad = load;
+ continue;
+ }
+
+ // Move deps in topological order, then the load itself.
+ // Since deps are collected via DFS, we need to sort them by their
+ // original position in the block to maintain valid SSA ordering.
+ SmallVector<Operation *> sortedDeps(deps.begin(), deps.end());
+ llvm::sort(sortedDeps, [](Operation *a, Operation *b) {
+ return a->isBeforeInBlock(b);
+ });
+
+ // If an address-computation dependency does not depend on the previous
+ // global load, move it before that load. That lets the global loads become
+ // adjacent while preserving the dependency order.
+ llvm::SetVector<Operation *> depsBeforePrevLoad;
+ SmallVector<Operation *> depsAfterPrevLoad;
+ for (Operation *dep : sortedDeps) {
+ if (canMoveBeforeBoundary(dep, prevGlobalLoad, depsBeforePrevLoad)) {
+ depsBeforePrevLoad.insert(dep);
+ continue;
+ }
+ depsAfterPrevLoad.push_back(dep);
+ }
+
+ for (Operation *dep : depsBeforePrevLoad) {
+ dep->moveBefore(prevGlobalLoad);
+ }
+
+ Operation *insertAfter = prevGlobalLoad;
+ for (Operation *dep : depsAfterPrevLoad) {
+ dep->moveAfter(insertAfter);
+ insertAfter = dep;
+ }
+ load->moveAfter(insertAfter);
+
+ prevGlobalLoad = load;
+ }
+}
+
+struct LLVMGPUGroupGlobalLoadsPass final
+ : impl::LLVMGPUGroupGlobalLoadsPassBase<LLVMGPUGroupGlobalLoadsPass> {
+ void runOnOperation() override {
+ FunctionOpInterface funcOp = getOperation();
+ funcOp.walk([](Block *block) { groupGlobalLoadsInBlock(*block); });
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d68d07b..b8cf60c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1049,6 +1049,13 @@
clLLVMGPUEnableSmallFloatEmulation});
});
+ // Group global loads together to improve AMDGPU instruction scheduling.
+ // The transformation is target-agnostic, but currently only enabled for
+ // ROCDL targets until there is data to support that it benefits other
+ // targets.
+ funcPassManager.addPredicatedPass(forROCDL,
+ createLLVMGPUGroupGlobalLoadsPass);
+
// Commit the func-level adaptor before adding module-level passes.
funcPassManager.commitPass();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index a0217e9..82bd866 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -48,6 +48,18 @@
let summary = "Pass to set layouts on tensors for later vector distribution";
}
+def LLVMGPUGroupGlobalLoadsPass :
+ InterfacePass<"iree-llvmgpu-group-global-loads", "mlir::FunctionOpInterface"> {
+ let summary = "Group adjacent global loads to improve GPU instruction scheduling";
+ let description = [{
+ Moves vector.load and memref.load operations from global-memory memrefs next
+ to each other when they are separated only by operations that do not depend
+ on the preceding load's result. This enables the GPU backend to issue
+ multiple global loads before waiting, instead of serializing each load
+ behind its own waitcount.
+ }];
+}
+
def LLVMGPU1DVectorCanonicalizationsPass :
InterfacePass<"iree-llvmgpu-1d-vector-canonicalizations", "mlir::FunctionOpInterface"> {
let summary = "Canonicalization patterns for 1-D vectors after legalization.";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index b531490..9f6f8e4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -51,6 +51,7 @@
"linalg_transform.mlir",
"link_executables.mlir",
"llvmgpu_bufferize.mlir",
+ "llvmgpu_group_global_loads.mlir",
"nvvm_pipeline_test.mlir",
"pack_shared_memory_alloc.mlir",
"pipeline_coalesced_dma.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index fc70e25..8b60996 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -46,6 +46,7 @@
"linalg_transform.mlir"
"link_executables.mlir"
"llvmgpu_bufferize.mlir"
+ "llvmgpu_group_global_loads.mlir"
"nvvm_pipeline_test.mlir"
"pack_shared_memory_alloc.mlir"
"pipeline_coalesced_dma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_group_global_loads.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_group_global_loads.mlir
new file mode 100644
index 0000000..ee2f380
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_group_global_loads.mlir
@@ -0,0 +1,208 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-group-global-loads))' %s | FileCheck %s
+
+// Two global loads separated by a pure (arith.addf) op that doesn't depend on
+// either load. The pass hoists the second load to sit immediately after the
+// first load.
+// CHECK-LABEL: func.func @group_two_loads
+// CHECK: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @group_two_loads(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %b: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %x: f32, %y: f32) -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ %sum = arith.addf %x, %y : f32
+ %v1 = vector.load %b[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// The second load's index is computed in the gap, but it only depends on
+// values available before the first load. The pass hoists that address
+// computation before the first load so the global loads can be adjacent.
+// CHECK-LABEL: func.func @hoists_independent_index_before_load_group
+// CHECK: %[[OFF1:.+]] = arith.addi
+// CHECK-NEXT: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = vector.load %{{.+}}[%[[OFF1]]]
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @hoists_independent_index_before_load_group(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %off: index,
+ %stride: index) -> (vector<4xf32>, vector<4xf32>) {
+ %v0 = vector.load %a[%off] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ %off1 = arith.addi %off, %stride : index
+ %v1 = vector.load %a[%off1] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// The second load reads from the same fat raw buffer that an intervening
+// vector.store writes to. Hoisting the load above the store would
+// change observable memory, so the pass must leave the loads in place.
+// CHECK-LABEL: func.func @blocked_by_buffer_write
+// CHECK: vector.load
+// CHECK-NEXT: vector.store
+// CHECK-NEXT: vector.load
+func.func @blocked_by_buffer_write(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %v: vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ vector.store %v, %a[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ %v1 = vector.load %a[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// Plain global memref store between two global loads is a global-memory write,
+// so hoisting must be blocked.
+// CHECK-LABEL: func.func @blocked_by_global_write
+// CHECK: vector.load
+// CHECK-NEXT: vector.store
+// CHECK-NEXT: vector.load
+func.func @blocked_by_global_write(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %g: memref<256xf32>,
+ %v: vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ vector.store %v, %g[%c4] : memref<256xf32>, vector<4xf32>
+ %v1 = vector.load %a[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// A write to workgroup (shared) memory between the loads is not on the global
+// address space, so hoisting is allowed.
+// CHECK-LABEL: func.func @allowed_through_shared_write
+// CHECK: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: vector.store
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @allowed_through_shared_write(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %s: memref<256xf32, #gpu.address_space<workgroup>>,
+ %v: vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ vector.store %v, %s[%c4] : memref<256xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+ %v1 = vector.load %a[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// The second load's index is computed in the gap from a value the first load
+// produces. The pure address-computation ops (vector.extract, arith.index_cast,
+// arith.muli) are transitive deps that must be hoisted along with the second
+// load.
+// CHECK-LABEL: func.func @hoists_index_dependencies
+// CHECK: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: vector.extract
+// CHECK-NEXT: arith.index_cast
+// CHECK-NEXT: arith.muli
+// CHECK-NEXT: %[[L1:.+]] = vector.load
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @hoists_index_dependencies(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %b: memref<256xi32, #amdgpu.address_space<fat_raw_buffer>>)
+ -> (vector<1xi32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %b[%c0] : memref<256xi32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xi32>
+ %scalar = vector.extract %v0[0] : i32 from vector<1xi32>
+ %idx = arith.index_cast %scalar : i32 to index
+ %off = arith.muli %idx, %c4 : index
+ %v1 = vector.load %a[%off] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<1xi32>, vector<4xf32>
+}
+
+// -----
+
+// vector.load from a plain memref is also a global load, so the pass groups
+// these loads just like fat raw buffer loads.
+// CHECK-LABEL: func.func @groups_plain_global_loads
+// CHECK: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @groups_plain_global_loads(%a: memref<256xf32>,
+ %b: memref<256xf32>,
+ %x: f32, %y: f32) -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32>, vector<4xf32>
+ %sum = arith.addf %x, %y : f32
+ %v1 = vector.load %b[%c4] : memref<256xf32>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// memref.load is also a global load, so the pass groups it with other global
+// loads.
+// CHECK-LABEL: func.func @groups_memref_loads
+// CHECK: %[[L0:.+]] = memref.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = memref.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @groups_memref_loads(%a: memref<256xf32>,
+ %b: memref<256xf32>,
+ %x: f32, %y: f32) -> (f32, f32) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = memref.load %a[%c0] : memref<256xf32>
+ %sum = arith.addf %x, %y : f32
+ %v1 = memref.load %b[%c4] : memref<256xf32>
+ return %v0, %v1 : f32, f32
+}
+
+// -----
+
+// Two adjacent global loads (no ops between them) — the pass should be a
+// no-op and not perturb the IR.
+// CHECK-LABEL: func.func @already_adjacent
+// CHECK: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: return %[[L0]], %[[L1]]
+func.func @already_adjacent(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %b: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>)
+ -> (vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ %v1 = vector.load %b[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// Three global loads with pure ops scattered between them. All three should
+// end up grouped at the position of the first load.
+// CHECK-LABEL: func.func @group_three_loads
+// CHECK: %[[L0:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L1:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: %[[L2:.+]] = vector.load %{{.+}}[%{{.+}}]
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: return %[[L0]], %[[L1]], %[[L2]]
+func.func @group_three_loads(%a: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %b: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %c: memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>,
+ %x: f32, %y: f32)
+ -> (vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %v0 = vector.load %a[%c0] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ %sum = arith.addf %x, %y : f32
+ %v1 = vector.load %b[%c4] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ %prod = arith.mulf %sum, %y : f32
+ %v2 = vector.load %c[%c8] : memref<256xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %v0, %v1, %v2 : vector<4xf32>, vector<4xf32>, vector<4xf32>
+}