[Codegen] Add pass to reinsert swizzle hints from alloc attributes (#24001)

Part 2/3 of enabling XOR swizzle with software pipelining (#23919).

**Overall plan:** `SwizzleHintOp` in the SSA chain blocks both
`memref::multiBuffer` and `scf::pipelineForLoop`. The fix absorbs the
hint into an alloc attribute before pipelining, preserves it through
multi-buffering, then re-inserts hints at leaf users afterward.

**This PR:** Adds `ReinsertSwizzleHintsPass`, which traces `vector.load`
and `vector.store` ops back to attributed `memref.alloc`s and inserts
`collapse_shape -> swizzle_hint -> expand_shape` at each use site.
`gather_to_lds` swizzle is handled separately in
`AMDGPULowerCoalescedDMAToGatherLDS`, which applies an inverse swizzle
on the source indices directly -- `swizzle_hint` cannot be used there
because the source is strided global memory. The pass is not yet wired
into any pipeline.

Assisted-by: Cursor (Claude)

---------

Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 4a8911a..a62283d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -152,6 +152,7 @@
         "PropagateDispatchSizeBounds.cpp",
         "PropagateReshapesByExpansion.cpp",
         "ReconcileTranslationInfo.cpp",
+        "ReinsertSwizzleHints.cpp",
         "RematerializeParallelOps.cpp",
         "RemoveIndexHints.cpp",
         "RemoveSingleIterationLoop.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index fe04f75..0399d69 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -145,6 +145,7 @@
     "PropagateDispatchSizeBounds.cpp"
     "PropagateReshapesByExpansion.cpp"
     "ReconcileTranslationInfo.cpp"
+    "ReinsertSwizzleHints.cpp"
     "RematerializeParallelOps.cpp"
     "RemoveIndexHints.cpp"
     "RemoveSingleIterationLoop.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 17f5c62..dfcf98e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -1023,6 +1023,15 @@
   }];
 }
 
+def ReinsertSwizzleHintsPass :
+    InterfacePass<"iree-codegen-reinsert-swizzle-hints", "mlir::FunctionOpInterface"> {
+  let summary = "Re-inserts swizzle_hint ops from alloc attributes at vector.load/store sites";
+  let dependentDialects = [
+    "IREE::Codegen::IREECodegenDialect",
+    "memref::MemRefDialect",
+  ];
+}
+
 def RematerializeParallelOpsPass :
     InterfacePass<"iree-codegen-rematerialize-parallel-ops", "mlir::FunctionOpInterface"> {
   let summary = "Pass to rematerialize and merge parallel ops into consumers.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReinsertSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ReinsertSwizzleHints.cpp
new file mode 100644
index 0000000..6a9872c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/ReinsertSwizzleHints.cpp
@@ -0,0 +1,146 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_REINSERTSWIZZLEHINTSPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+struct ReinsertSwizzleHintsPass final
+    : impl::ReinsertSwizzleHintsPassBase<ReinsertSwizzleHintsPass> {
+  using Base::Base;
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Traces a memref value backward through defining ops and loop
+/// iter_args/results to find the root memref.alloc.
+static memref::AllocOp traceToAllocation(Value val) {
+  DenseSet<Value> visited;
+  SmallVector<Value> worklist = {val};
+  while (!worklist.empty()) {
+    Value current = worklist.pop_back_val();
+    if (!visited.insert(current).second) {
+      continue;
+    }
+    Operation *defOp = current.getDefiningOp();
+    if (!defOp) {
+      auto blockArg = cast<BlockArgument>(current);
+      auto loopOp =
+          dyn_cast<LoopLikeOpInterface>(blockArg.getOwner()->getParentOp());
+      if (!loopOp) {
+        continue;
+      }
+      if (OpOperand *init = loopOp.getTiedLoopInit(blockArg)) {
+        worklist.push_back(init->get());
+      }
+    } else if (auto allocOp = dyn_cast<memref::AllocOp>(defOp)) {
+      return allocOp;
+    } else if (auto loopOp = dyn_cast<LoopLikeOpInterface>(defOp)) {
+      if (OpOperand *init = loopOp.getTiedLoopInit(cast<OpResult>(current))) {
+        worklist.push_back(init->get());
+      }
+    } else {
+      for (Value operand : defOp->getOperands()) {
+        if (isa<MemRefType>(operand.getType())) {
+          worklist.push_back(operand);
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+/// Returns the swizzle attribute on the alloc that val traces to, using
+/// cache to avoid repeated tracing.
+static IREE::Codegen::SwizzleAttrInterface
+lookupSwizzleAttr(Value val,
+                  DenseMap<Value, IREE::Codegen::SwizzleAttrInterface> &cache) {
+  auto it = cache.find(val);
+  if (it != cache.end()) {
+    return it->second;
+  }
+  memref::AllocOp allocOp = traceToAllocation(val);
+  IREE::Codegen::SwizzleAttrInterface swizzle;
+  if (allocOp) {
+    swizzle = allocOp->getAttrOfType<IREE::Codegen::SwizzleAttrInterface>(
+        "iree_codegen.swizzle");
+  }
+  cache[val] = swizzle;
+  return swizzle;
+}
+
+/// Wraps |source| with collapse_shape -> swizzle_hint -> expand_shape so that
+/// downstream ResolveSwizzleHints can apply the XOR transform. For 1D memrefs,
+/// only the swizzle_hint is inserted.
+static Value insertSwizzleHint(IRRewriter &rewriter, Location loc, Value source,
+                               IREE::Codegen::SwizzleAttrInterface swizzle) {
+  auto sourceType = cast<MemRefType>(source.getType());
+
+  Value hintInput = source;
+  SmallVector<ReassociationIndices> reassoc;
+
+  if (sourceType.getRank() > 1) {
+    reassoc.push_back(
+        llvm::to_vector(llvm::seq<int64_t>(0, sourceType.getRank())));
+    hintInput = memref::CollapseShapeOp::create(rewriter, loc, source, reassoc);
+  }
+
+  auto hintOp =
+      IREE::Codegen::SwizzleHintOp::create(rewriter, loc, hintInput, swizzle);
+
+  if (sourceType.getRank() > 1) {
+    return memref::ExpandShapeOp::create(rewriter, loc, sourceType.getShape(),
+                                         hintOp.getResult(), reassoc);
+  }
+  return hintOp.getResult();
+}
+
+void ReinsertSwizzleHintsPass::runOnOperation() {
+  FunctionOpInterface funcOp = getOperation();
+  IRRewriter rewriter(funcOp->getContext());
+  DenseMap<Value, IREE::Codegen::SwizzleAttrInterface> swizzleCache;
+
+  // For each vector.load/store whose base traces to a swizzled alloc, wrap the
+  // base with collapse_shape -> swizzle_hint -> expand_shape.
+  funcOp.walk([&](Operation *op) {
+    Value base;
+    if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+      base = loadOp.getBase();
+    } else if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
+      base = storeOp.getBase();
+    } else {
+      return;
+    }
+    IREE::Codegen::SwizzleAttrInterface swizzle =
+        lookupSwizzleAttr(base, swizzleCache);
+    if (!swizzle) {
+      return;
+    }
+    rewriter.setInsertionPoint(op);
+    Value wrapped = insertSwizzleHint(rewriter, op->getLoc(), base, swizzle);
+    if (auto loadOp = dyn_cast<vector::LoadOp>(op)) {
+      loadOp.getBaseMutable().assign(wrapped);
+    } else if (auto storeOp = dyn_cast<vector::StoreOp>(op)) {
+      storeOp.getBaseMutable().assign(wrapped);
+    }
+  });
+
+  // Clean up the swizzle attributes from allocs now that hints are inserted.
+  funcOp.walk([](memref::AllocOp allocOp) {
+    allocOp->removeAttr("iree_codegen.swizzle");
+  });
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index e6ddc62..24d17ee 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -120,6 +120,7 @@
             "reconcile_translation_info_linearize.mlir",
             "reconcile_translation_info_pure.mlir",
             "reductions.mlir",
+            "reinsert_swizzle_hints.mlir",
             "rematerialize_parallel_ops.mlir",
             "remove_dead_allocs.mlir",
             "remove_index_hints.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 0d5889e..0098626 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -115,6 +115,7 @@
     "reconcile_translation_info_linearize.mlir"
     "reconcile_translation_info_pure.mlir"
     "reductions.mlir"
+    "reinsert_swizzle_hints.mlir"
     "rematerialize_parallel_ops.mlir"
     "remove_dead_allocs.mlir"
     "remove_index_hints.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reinsert_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reinsert_swizzle_hints.mlir
new file mode 100644
index 0000000..5fee9ea
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reinsert_swizzle_hints.mlir
@@ -0,0 +1,120 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-reinsert-swizzle-hints))" \
+// RUN:   --split-input-file --mlir-print-local-scope %s | FileCheck %s
+
+func.func @load_1d() -> vector<8xbf16> {
+  %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+    : memref<1024xbf16, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  %v = vector.load %alloc[%c0]
+    : memref<1024xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+  return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @load_1d
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<1024xbf16, #gpu.address_space<workgroup>>
+//       CHECK:   %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.xor_shuffle<128, 8>]
+//       CHECK:   vector.load %[[HINT]][%{{.+}}]
+
+// -----
+
+func.func @load_2d() -> vector<8xbf16> {
+  %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  %v = vector.load %alloc[%c0, %c0]
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+  return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @load_2d
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+//       CHECK:   %[[COLLAPSED:.+]] = memref.collapse_shape %[[ALLOC]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   %[[HINT:.+]] = iree_codegen.swizzle_hint %[[COLLAPSED]][#iree_codegen.xor_shuffle<128, 8>]
+//       CHECK:   %[[EXPANDED:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   vector.load %[[EXPANDED]][%{{.+}}, %{{.+}}]
+
+// -----
+
+func.func @store_2d(%v: vector<8xbf16>) {
+  %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  vector.store %v, %alloc[%c0, %c0]
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+  return
+}
+
+// CHECK-LABEL: func @store_2d
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+//       CHECK:   %[[COLLAPSED:.+]] = memref.collapse_shape %[[ALLOC]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   %[[HINT:.+]] = iree_codegen.swizzle_hint %[[COLLAPSED]][#iree_codegen.xor_shuffle<128, 8>]
+//       CHECK:   %[[EXPANDED:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   vector.store %{{.+}}, %[[EXPANDED]][%{{.+}}, %{{.+}}]
+
+// -----
+
+func.func @no_hint_without_swizzle() -> vector<8xbf16> {
+  %alloc = memref.alloc() : memref<1024xbf16, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  %v = vector.load %alloc[%c0]
+    : memref<1024xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+  return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @no_hint_without_swizzle
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc()
+//       CHECK:   vector.load %[[ALLOC]][%{{.+}}]
+//   CHECK-NOT:   swizzle_hint
+
+// -----
+
+func.func @trace_through_subview() -> vector<8xbf16> {
+  %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  %subview = memref.subview %alloc[0, 0][4, 128][1, 1]
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>
+    to memref<4x128xbf16, strided<[128, 1]>, #gpu.address_space<workgroup>>
+  %v = vector.load %subview[%c0, %c0]
+    : memref<4x128xbf16, strided<[128, 1]>, #gpu.address_space<workgroup>>, vector<8xbf16>
+  return %v : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @trace_through_subview
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]]
+//       CHECK:   %[[COLLAPSED:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   %[[HINT:.+]] = iree_codegen.swizzle_hint %[[COLLAPSED]][#iree_codegen.xor_shuffle<128, 8>]
+//       CHECK:   %[[EXPANDED:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   vector.load %[[EXPANDED]][%{{.+}}, %{{.+}}]
+
+// -----
+
+func.func @trace_through_scf_for() -> vector<8xbf16> {
+  %alloc = memref.alloc() {iree_codegen.swizzle = #iree_codegen.xor_shuffle<128, 8>}
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res = scf.for %iv = %c0 to %c4 step %c1
+      iter_args(%arg = %alloc) -> memref<8x128xbf16, #gpu.address_space<workgroup>> {
+    %v = vector.load %arg[%c0, %c0]
+      : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+    scf.yield %arg : memref<8x128xbf16, #gpu.address_space<workgroup>>
+  }
+  %epilogue = vector.load %res[%c0, %c0]
+    : memref<8x128xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
+  return %epilogue : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @trace_through_scf_for
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<8x128xbf16, #gpu.address_space<workgroup>>
+//       CHECK:   scf.for
+//       CHECK:     %[[C1:.+]] = memref.collapse_shape %{{.+}} {{\[\[}}0, 1{{\]\]}}
+//       CHECK:     %[[H1:.+]] = iree_codegen.swizzle_hint %[[C1]][#iree_codegen.xor_shuffle<128, 8>]
+//       CHECK:     %[[E1:.+]] = memref.expand_shape %[[H1]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:     vector.load %[[E1]]
+//       CHECK:   %[[C2:.+]] = memref.collapse_shape %{{.+}} {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   %[[H2:.+]] = iree_codegen.swizzle_hint %[[C2]][#iree_codegen.xor_shuffle<128, 8>]
+//       CHECK:   %[[E2:.+]] = memref.expand_shape %[[H2]] {{\[\[}}0, 1{{\]\]}}
+//       CHECK:   vector.load %[[E2]]