[Codegen] Add pass to reinsert swizzle hints from alloc attributes (#24001)
Part 2/3 of enabling XOR swizzle with software pipelining (#23919).
**Overall plan:** `SwizzleHintOp` in the SSA chain blocks both
`memref::multiBuffer` and `scf::pipelineForLoop`. The fix absorbs the
hint into an alloc attribute before pipelining, preserves it through
multi-buffering, then re-inserts hints at leaf users afterward.
**This PR:** Adds `ReinsertSwizzleHintsPass`, which traces `vector.load`
and `vector.store` ops back to attributed `memref.alloc`s and inserts
`collapse_shape -> swizzle_hint -> expand_shape` at each use site.
`gather_to_lds` swizzle is handled separately in
`AMDGPULowerCoalescedDMAToGatherLDS`, which applies an inverse swizzle
on the source indices directly -- `swizzle_hint` cannot be used there
because the source is strided global memory. The pass is not yet wired
into any pipeline.
Assisted-by: Cursor (Claude)
---------
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 4a8911a..a62283d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -152,6 +152,7 @@
"PropagateDispatchSizeBounds.cpp",
"PropagateReshapesByExpansion.cpp",
"ReconcileTranslationInfo.cpp",
+ "ReinsertSwizzleHints.cpp",
"RematerializeParallelOps.cpp",
"RemoveIndexHints.cpp",
"RemoveSingleIterationLoop.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index fe04f75..0399d69 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -145,6 +145,7 @@
"PropagateDispatchSizeBounds.cpp"
"PropagateReshapesByExpansion.cpp"
"ReconcileTranslationInfo.cpp"
+ "ReinsertSwizzleHints.cpp"
"RematerializeParallelOps.cpp"
"RemoveIndexHints.cpp"
"RemoveSingleIterationLoop.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 17f5c62..dfcf98e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -1023,6 +1023,15 @@
}];
}
+def ReinsertSwizzleHintsPass :
+ InterfacePass<"iree-codegen-reinsert-swizzle-hints", "mlir::FunctionOpInterface"> {
+ let summary = "Re-inserts swizzle_hint ops from alloc attributes at vector.load/store sites";
+ let dependentDialects = [
+ "IREE::Codegen::IREECodegenDialect",
+ "memref::MemRefDialect",
+ ];
+}
+
def RematerializeParallelOpsPass :
InterfacePass<"iree-codegen-rematerialize-parallel-ops", "mlir::FunctionOpInterface"> {
let summary = "Pass to rematerialize and merge parallel ops into consumers.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReinsertSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ReinsertSwizzleHints.cpp
new file mode 100644
index 0000000..6a9872c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/ReinsertSwizzleHints.cpp
@@ -0,0 +1,146 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_REINSERTSWIZZLEHINTSPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+struct ReinsertSwizzleHintsPass final
+ : impl::ReinsertSwizzleHintsPassBase<ReinsertSwizzleHintsPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Traces a memref value backward through defining ops and loop
+/// iter_args/results to find the root memref.alloc.
+static memref::AllocOp traceToAllocation(Value val) {
+ DenseSet<Value> visited;
+ SmallVector<Value> worklist = {val};
+ while (!worklist.empty()) {
+ Value current = worklist.pop_back_val();
+ if (!visited.insert(current).second) {
+ continue;
+ }
+ Operation *defOp = current.getDefiningOp();
+ if (!defOp) {
+ auto blockArg = cast<BlockArgument>(current);
+ auto loopOp =
+ dyn_cast<LoopLikeOpInterface>(blockArg.getOwner()->getParentOp());
+ if (!loopOp) {
+ continue;
+ }
+ if (OpOperand *init = loopOp.getTiedLoopInit(blockArg)) {
+ worklist.push_back(init->get());
+ }
+ } else if (auto allocOp = dyn_cast<memref::AllocOp>(defOp)) {
+ return allocOp;
+ } else if (auto loopOp = dyn_cast<LoopLikeOpInterface>(defOp)) {
+ if (OpOperand *init = loopOp.getTiedLoopInit(cast<OpResult>(current))) {
+ worklist.push_back(init->get());
+ }
+ } else {
+ for (Value operand : defOp->getOperands()) {
+ if (isa<MemRefType>(operand.getType())) {
+ worklist.push_back(operand);
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+/// Returns the swizzle attribute on the alloc that val traces to, using
+/// cache to avoid repeated tracing.
+static IREE::Codegen::SwizzleAttrInterface
+lookupSwizzleAttr(Value val,
+ DenseMap<Value, IREE::Codegen::SwizzleAttrInterface> &cache) {
+ auto it = cache.find(val);
+ if (it != cache.end()) {
+ return it->second;
+ }
+ memref::AllocOp allocOp = traceToAllocation(val);
+ IREE::Codegen::SwizzleAttrInterface swizzle;
+ if (allocOp) {
+ swizzle = allocOp->getAttrOfType<IREE::Codegen::SwizzleAttrInterface>(
+ "iree_codegen.swizzle");
+ }
+ cache[val] = swizzle;
+ return swizzle;
+}
+
+/// Wraps |source| with collapse_shape -> swizzle_hint -> expand_shape so that
+/// downstream ResolveSwizzleHints can apply the XOR transform. For 1D memrefs,
+/// only the swizzle_hint is inserted.
+static Value insertSwizzleHint(IRRewriter &rewriter, Location loc, Value source,
+ IREE::Codegen::SwizzleAttrInterface swizzle) {
+ auto sourceType = cast<MemRefType>(source.getType());
+
+ Value hintInput = source;
+ SmallVector<ReassociationIndices> reassoc;
+
+ if (sourceType.getRank() > 1) {
+ reassoc.push_back(
+ llvm::to_vector(llvm::seq<int64_t>(0, sourceType.getRank())));
+ hintInput = memref::CollapseShapeOp::create(rewriter, loc, source, reassoc);
+ }
+
+ auto hintOp =
+ IREE::Codegen::SwizzleHintOp::create(rewriter, loc, hintInput, swizzle);
+
+ if (sourceType.getRank() > 1) {
+ return memref::ExpandShapeOp::create(rewriter, loc, sourceType.getShape(),
+ hintOp.getResult(), reassoc);
+ }
+ return hintOp.getResult();
+}
+
+void ReinsertSwizzleHintsPass::runOnOperation() {
+ FunctionOpInterface funcOp = getOperation();
+ IRRewriter rewriter(funcOp->getContext());
+ DenseMap<Value, IREE::Codegen::SwizzleAttrInterface> swizzleCache;
+
+ // For each vector.load/store whose base traces to a swizzled alloc, wrap the
+ // base with collapse_shape -> swizzle_hint -> expand_shape.
+ funcOp.walk([&](Operation *op) {
+ Value base;
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+ base = loadOp.getBase();
+ } else if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
+ base = storeOp.getBase();
+ } else {
+ return;
+ }
+ IREE::Codegen::SwizzleAttrInterface swizzle =
+ lookupSwizzleAttr(base, swizzleCache);
+ if (!swizzle) {
+ return;
+ }
+ rewriter.setInsertionPoint(op);
+ Value wrapped = insertSwizzleHint(rewriter, op->getLoc(), base, swizzle);
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+ loadOp.getBaseMutable().assign(wrapped);
+ } else if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
+ storeOp.getBaseMutable().assign(wrapped);
+ }
+ });
+
+ // Clean up the swizzle attributes from allocs now that hints are inserted.
+ funcOp.walk([](memref::AllocOp allocOp) {
+ allocOp->removeAttr("iree_codegen.swizzle");
+ });
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index e6ddc62..24d17ee 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -120,6 +120,7 @@
"reconcile_translation_info_linearize.mlir",
"reconcile_translation_info_pure.mlir",
"reductions.mlir",
+ "reinsert_swizzle_hints.mlir",
"rematerialize_parallel_ops.mlir",
"remove_dead_allocs.mlir",
"remove_index_hints.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 0d5889e..0098626 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -115,6 +115,7 @@
"reconcile_translation_info_linearize.mlir"
"reconcile_translation_info_pure.mlir"
"reductions.mlir"
+ "reinsert_swizzle_hints.mlir"
"rematerialize_parallel_ops.mlir"
"remove_dead_allocs.mlir"
"remove_index_hints.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reinsert_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reinsert_swizzle_hints.mlir
new file mode 100644
index 0000000..5fee9ea
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reinsert_swizzle_hints.mlir
@@ -0,0 +1,120 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-reinsert-swizzle-hints))" \
+// RUN: --split-input-file --mlir-print-local-scope %s | FileCheck %s
+
+func.func @load_1d() -> vector<8xbf16> {
+ %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+ : memref<1024xbf16, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ %v = vector.load %alloc[%c0]
+ : memref<1024xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+ return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @load_1d
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1024xbf16, #gpu.address_space<workgroup>>
+// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.xor_shuffle<128, 8>]
+// CHECK: vector.load %[[HINT]][%{{.+}}]
+
+// -----
+
+func.func @load_2d() -> vector<8xbf16> {
+ %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ %v = vector.load %alloc[%c0, %c0]
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+ return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @load_2d
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ALLOC]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[COLLAPSED]][#iree_codegen.xor_shuffle<128, 8>]
+// CHECK: %[[EXPANDED:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: vector.load %[[EXPANDED]][%{{.+}}, %{{.+}}]
+
+// -----
+
+func.func @store_2d(%v: vector<8xbf16>) {
+ %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ vector.store %v, %alloc[%c0, %c0]
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+ return
+}
+
+// CHECK-LABEL: func @store_2d
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ALLOC]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[COLLAPSED]][#iree_codegen.xor_shuffle<128, 8>]
+// CHECK: %[[EXPANDED:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: vector.store %{{.+}}, %[[EXPANDED]][%{{.+}}, %{{.+}}]
+
+// -----
+
+func.func @no_hint_without_swizzle() -> vector<8xbf16> {
+ %alloc = memref.alloc() : memref<1024xbf16, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ %v = vector.load %alloc[%c0]
+ : memref<1024xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+ return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @no_hint_without_swizzle
+// CHECK: %[[ALLOC:.+]] = memref.alloc()
+// CHECK: vector.load %[[ALLOC]][%{{.+}}]
+// CHECK-NOT: swizzle_hint
+
+// -----
+
+func.func @trace_through_subview() -> vector<8xbf16> {
+ %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ %subview = memref.subview %alloc[0, 0][4, 128][1, 1]
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>
+ to memref<4x128xbf16, strided<[128, 1]>, #gpu.address_space<workgroup>>
+ %v = vector.load %subview[%c0, %c0]
+ : memref<4x128xbf16, strided<[128, 1]>, #gpu.address_space<workgroup>>, vector<8xbf16>
+ return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @trace_through_subview
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[COLLAPSED]][#iree_codegen.xor_shuffle<128, 8>]
+// CHECK: %[[EXPANDED:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: vector.load %[[EXPANDED]][%{{.+}}, %{{.+}}]
+
+// -----
+
+func.func @trace_through_scf_for() -> vector<8xbf16> {
+ %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res = scf.for %iv = %c0 to %c4 step %c1
+ iter_args(%arg = %alloc) -> memref<8x128xbf16, #gpu.address_space<workgroup>> {
+ %v = vector.load %arg[%c0, %c0]
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+ scf.yield %arg : memref<8x128xbf16, #gpu.address_space<workgroup>>
+ }
+ %epilogue = vector.load %res[%c0, %c0]
+ : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+ return %epilogue : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @trace_through_scf_for
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+// CHECK: scf.for
+// CHECK: %[[C1:.+]] = memref.collapse_shape %{{.+}} {{\[\[}}0, 1{{\]\]}}
+// CHECK: %[[H1:.+]] = iree_codegen.swizzle_hint %[[C1]][#iree_codegen.xor_shuffle<128, 8>]
+// CHECK: %[[E1:.+]] = memref.expand_shape %[[H1]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: vector.load %[[E1]]
+// CHECK: %[[C2:.+]] = memref.collapse_shape %{{.+}} {{\[\[}}0, 1{{\]\]}}
+// CHECK: %[[H2:.+]] = iree_codegen.swizzle_hint %[[C2]][#iree_codegen.xor_shuffle<128, 8>]
+// CHECK: %[[E2:.+]] = memref.expand_shape %[[H2]] {{\[\[}}0, 1{{\]\]}}
+// CHECK: vector.load %[[E2]]