[VectorDistribute] Refactor VectorLayoutAnalysis into 2-phase forward/backward design (#23611)

Restructure the layout analysis from an interleaved forward+backward
worklist into a clean two-phase design:

Phase 1 (forward): Multi-candidate propagation from ToLayoutOp anchors
through uses. No IR mutation.

Resolve: Pick the first candidate per value. This is a placeholder cost
model that matches the old analysis; eventually we will consider
coalescing, compute ops, mma layouts, etc.
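The multi-candidate bookkeeping can be sketched on a toy value graph. This is
an illustrative C++ model, not the IREE code: integers stand in for SSA
values, strings for layouts, and one elementwise-style rule replaces the
per-op inference rules.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <string>
#include <vector>

constexpr int kMaxCandidatesPerValue = 4;

struct Analysis {
  // Per-value candidate layouts in insertion order (a stand-in for
  // llvm::SmallSetVector<VectorLayoutInterface, 4>).
  std::map<int, std::vector<std::string>> candidates;
  std::map<int, std::vector<int>> uses;  // value -> values consuming it
  std::queue<int> worklist;

  // Record a candidate; enqueue only when it is genuinely new. Together
  // with the per-value cap this bounds total work and guarantees
  // convergence of the forward phase.
  bool addCandidate(int val, const std::string &layout) {
    std::vector<std::string> &set = candidates[val];
    if ((int)set.size() >= kMaxCandidatesPerValue)
      return false;
    for (const std::string &existing : set)
      if (existing == layout)
        return false;
    set.push_back(layout);
    worklist.push(val);
    return true;
  }

  // Phase 1: drain the worklist, flowing every candidate of a value to
  // its users (assumes no self-loops in `uses`). No IR mutation.
  void runForward() {
    while (!worklist.empty()) {
      int val = worklist.front();
      worklist.pop();
      for (const std::string &layout : candidates[val])
        for (int user : uses[val])
          addCandidate(user, layout);
    }
  }

  // Resolve: first-wins, the placeholder cost model.
  std::map<int, std::string> resolve() const {
    std::map<int, std::string> resolved;
    for (const auto &[val, set] : candidates)
      if (!set.empty())
        resolved[val] = set.front();
    return resolved;
  }
};
```

With two anchors `L` and `L'` feeding the same consumer, both candidates are
tracked during propagation, and resolve picks whichever anchor was seeded
first.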

Phase 2 (backward fixup): Walk operations in reverse program order via
recursive fixupRegion/fixupOp. For each op, derive operand layouts from
resolved result layouts. Assign missing layouts, clone cheap ops
(constants, create_mask, step), or insert to_layout conversions on
conflict.
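The fixup's conflict handling can likewise be sketched in toy form. Again
illustrative C++ rather than the IREE code: the real setLayoutOrClone also
clones cheap ops per use site, which is omitted here.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Op {
  int result;
  std::vector<int> operands;
};

struct Fixup {
  std::map<int, std::string> resolved;  // value -> resolved layout
  // Conversions that would become inserted to_layout ops: (value, layout).
  std::vector<std::pair<int, std::string>> conversions;

  // Elementwise-style rule: an operand must match its result layout.
  // Assign if missing, no-op if equal, otherwise record a conversion
  // (the real analysis would first try to clone constants/create_mask/step).
  void setLayoutOrConvert(int operand, const std::string &layout) {
    auto it = resolved.find(operand);
    if (it == resolved.end()) {
      resolved[operand] = layout;
      return;
    }
    if (it->second == layout)
      return;
    conversions.push_back({operand, layout});
  }

  // Reverse program order: each op is visited after all its users, so its
  // result layout is final by the time operand layouts are derived from it.
  void run(const std::vector<Op> &program) {
    for (auto it = program.rbegin(); it != program.rend(); ++it) {
      auto res = resolved.find(it->result);
      if (res == resolved.end())
        continue;  // no layout demand on this op's result
      for (int operand : it->operands)
        setLayoutOrConvert(operand, res->second);
    }
  }
};
```

An operand that already holds a different layout `L'` is left in place and a
conversion to `L` is recorded, mirroring how the analysis leaves conflicts as
to_layout ops for the caller to lower.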

This handles conflicts in a predictable manner. The forward analysis is
the main driver: for a program to be well-formed for vector distribution,
there must be some way for the final store/return to get a layout.
Otherwise, there is not enough information in the program to determine how
distribution should be done. The forward analysis ensures that the final
return/store gets a layout in a well-formed program. The rest of the
program gets its layouts from backward propagation, since every value in
the program must eventually reach the store/return.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 912d1e9..4bead73 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -360,10 +360,7 @@
   // Run the analysis and determine the layouts.
   LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
   llvm::MapVector<Value, VectorLayoutInterface> layouts;
-  if (failed(propagateVectorLayoutInfo(root, layouts))) {
-    LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Failed\n");
-    return failure();
-  }
+  propagateVectorLayoutInfo(root, layouts);
   LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeeded\n");
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index fffe666..90958df 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -1069,7 +1069,7 @@
   let summary = "Test the PartitionableLoopsInterface";
 }
 
-def TestVectorLayoutAnalysisPass : Pass<"iree-codegen-test-vector-layout-analysis", ""> {
+def TestVectorLayoutAnalysisPass : InterfacePass<"iree-codegen-test-vector-layout-analysis", "mlir::FunctionOpInterface"> {
   let summary = "Test the vector layout analysis.";
   let description = [{
     Run VectorLayoutAnalysis on the root operation. The analysis emits remarks
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index ec17434..57a6804 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -81,69 +81,38 @@
 class VectorLayoutInterface;
 } // namespace IREE::VectorExt
 
-/// Analyzes the root op and its nested ops to propagate vector layouts
-/// originating from to_vector operations. Example:
+/// Propagates vector layouts through `root` and its nested operations.
+/// Populates `layouts` with the resolved layout for every vector-typed value.
 ///
-///    %root = vector.transfer_read
-///      |
-///      --> anchored to layout L (using a to_layout op)
-///    %root2 = vector.transfer_read
-///    %c = arith.mulf %root, %b
-///          |
-///          --> %root, %b and %c must have the same layout
-///    %e = arith.divf %b, %root2
-///          |
-///          --> %root2, %b and %e must have the same layout
+/// Layout anchors are `iree_vector_ext.to_layout` ops in the IR. These must
+/// be present before calling this function — they seed the analysis. On GPUs,
+/// anchors are typically placed on loads. IR without anchors is considered
+/// ill-formed for vector distribution.
 ///
-/// Here, the user provided an anchor point for %root, fixing its layout to L.
-/// The layout then uses its inference rules to find the layout of other
-/// values:
+/// Starting from anchors, layouts are inferred for the rest of the IR using
+/// op-specific rules (elementwise, transpose, contract, scf.for, etc.).
 ///
-///    %root = vector.transfer_read
-///     |
-///     --> inferred to layout L
-///    %root2 = vector.transfer_read
-///     |
-///     --> inferred to layout L
-///    %c = arith.mulf %root, %b
-///     |
-///     --> inferred to layout L
-///    %e = arith.divf %b, %root2
-///     |
-///     --> inferred to layout L
+/// Example — single anchor:
 ///
-/// If at any point, a value has a layout, but the user of that value requires
-/// a different layout, the analysis inserts a resolution operation. This
-/// resolution operation is `iree_vector_ext.to_layout`.
-/// For Example:
+///    %read = vector.transfer_read ...
+///    %anchored = iree_vector_ext.to_layout %read to layout(L)
+///    %c = arith.mulf %anchored, %b   --> inferred to layout L
+///    %e = arith.divf %b, %anchored   --> inferred to layout L
 ///
-/// %0 = vector.transfer_read
-///  |
-///  --> anchored to layout L
-/// %1 = vector.transfer_read
-///  |
-///  --> anchored to layout L'
-///  arith.addf %0, %1
-///     |
-///     --> %0 and %1 must have the same layout
+/// When a value receives conflicting layouts from different anchors, the
+/// analysis resolves the conflict by inserting a `to_layout` conversion.
+/// Cheap ops (constants, create_mask, step) are cloned per use site instead.
+/// The caller is responsible for lowering the inserted `to_layout` ops.
 ///
-/// To resolve the conflict, the analysis chooses one of the layouts, say
-/// L, and inserts a resolution operation to convert the other layout to L.
+/// Example — conflict resolution:
 ///
-/// %0 = vector.transfer_read
-///  |
-///  --> anchored to layout L
-/// %1 = vector.transfer_read
-///  |
-///  --> anchored to layout L'
-/// %resolved = iree_vector_ext.to_layout %1
-///  |
-///  --> inferred to layout L
-/// arith.addf %0, %resolved
-///
-/// The analysis itself will not try to resolve the conflict, but instead
-/// will leave it as a to_layout op, which can be rewritten by the caller.
-LogicalResult propagateVectorLayoutInfo(
+///    %a = ... to_layout ... to layout(L)
+///    %b = ... to_layout ... to layout(L')
+///    // %a and %b have different layouts but feed into the same op.
+///    // The analysis inserts a conversion on %b:
+///    %resolved = iree_vector_ext.to_layout %b to layout(L)
+///    arith.addf %a, %resolved   --> layout L
+void propagateVectorLayoutInfo(
     Operation *root,
     llvm::MapVector<Value, IREE::VectorExt::VectorLayoutInterface> &layouts);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 417f6e3..d4b539e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -11,6 +11,7 @@
 #include <cassert>
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
@@ -25,41 +26,140 @@
 
 using namespace IREE::VectorExt;
 
-struct LayoutInfo {
-  /// Given a value, propagate its layout information forward through its
-  /// users.
-  void propagateLayoutForward(Value val);
-  /// Given a value, propagate its layout information backward through its
-  /// defining operation.
-  void propagateLayoutBackward(Value val);
+/// Maximum number of candidate layouts tracked per value. Kept small to bound
+/// analysis cost; most values see 1-2 candidates in practice.
+static constexpr int kMaxCandidatesPerValue = 4;
 
-  void setLayoutIfUnset(Value val, VectorLayoutInterface layout) {
-    if (!isa<ShapedType>(val.getType())) {
-      // Don't set layouts on non-shaped types. This would anyway be an empty
-      // layout.
-      return;
-    }
-    if (hasLayout(val)) {
-      return;
-    }
-    layouts[val] = layout;
-    forward.push(val);
-    backward.push(val);
-  }
-  void setLayoutOrClone(OpOperand *val, VectorLayoutInterface layout);
-  VectorLayoutInterface getLayout(Value val) const {
-    return layouts.lookup(val);
-  }
-  bool hasLayout(Value val) const { return layouts.contains(val); }
+//===----------------------------------------------------------------------===//
+// Layout Analysis
+//
+// Phase 1: Forward propagation with multi-candidate tracking. Seeds from
+//   ToLayoutOp anchors, propagates forward through uses. Each value accumulates
+//   up to kMaxCandidatesPerValue candidate layouts. No IR mutation.
+//
+// Resolve: Pick first candidate for each value (first-wins). The multi-
+//   candidate data structure is ready for a cost model later.
+//
+// Phase 2: Backward fixup. Walks operations in reverse program order. For each
+//   op, determines operand layouts from resolved result/operand layouts.
+//   Assigns missing layouts, clones cheap ops, or inserts to_layout
+//   conversions.
+//
+// The forward analysis is the main driver. The reason for this is that for a
+// program to be well-formed for vector distribution, there must be some way
+// for the final store/return to get a layout. Otherwise, there is not enough
+// information in the program to determine how distribution should be done.
+// The forward analysis ensures that the final return/store gets a layout in a
+// well-formed program. The rest of the program gets its layouts from backward
+// propagation, since everything in the program must eventually reach the
+// store/return.
+//===----------------------------------------------------------------------===//
 
-  llvm::MapVector<Value, VectorLayoutInterface> layouts;
+namespace {
+
+struct LayoutAnalysis {
+  /// Multiple candidate layouts per value (Phase 1).
+  llvm::MapVector<Value, llvm::SmallSetVector<VectorLayoutInterface, 4>>
+      candidates;
+  /// Resolved layouts: single layout per value (after resolve, used by fixup).
+  llvm::MapVector<Value, VectorLayoutInterface> resolved;
+  /// Forward worklist (Phase 1 only).
   std::queue<Value> forward;
-  std::queue<Value> backward;
+
+  //===--- Phase 1: Forward propagation ---===//
+
+  bool addCandidate(Value val, VectorLayoutInterface layout);
+  VectorLayoutInterface getFirstCandidate(Value val) const;
+  void seed(Operation *root);
+  void propagateForward(Value val);
+  void propagateOneForward(Value val, VectorLayoutInterface layout);
+  void runForward();
+
+  //===--- Resolve ---===//
+
+  void resolve();
+
+  //===--- Phase 2: Backward fixup ---===//
+
+  VectorLayoutInterface getResolvedLayout(Value val) const {
+    return resolved.lookup(val);
+  }
+  bool hasResolvedLayout(Value val) const { return resolved.contains(val); }
+
+  void fixupRegion(Region &region);
+  void fixupOp(Operation *op);
+  void setLayoutOrClone(OpOperand *val, VectorLayoutInterface layout);
 };
 
-void LayoutInfo::propagateLayoutForward(Value val) {
-  LDBG() << "Propagating layout forward for value: " << val << "\n";
-  VectorLayoutInterface layout = getLayout(val);
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Phase 1: Forward Propagation (no IR mutation)
+//===----------------------------------------------------------------------===//
+
+/// Add a candidate layout for a value. Returns true if a new candidate was
+/// added. Schedules the value for forward propagation.
+bool LayoutAnalysis::addCandidate(Value val, VectorLayoutInterface layout) {
+  if (!layout) {
+    return false;
+  }
+  if (!isa<ShapedType>(val.getType())) {
+    return false;
+  }
+  llvm::SmallSetVector<VectorLayoutInterface, 4> &set = candidates[val];
+  if (set.size() >= kMaxCandidatesPerValue) {
+    return false;
+  }
+  if (!set.insert(layout)) {
+    return false;
+  }
+  forward.push(val);
+  return true;
+}
+
+/// Return the first candidate layout for a value, or null.
+VectorLayoutInterface LayoutAnalysis::getFirstCandidate(Value val) const {
+  auto it = candidates.find(val);
+  if (it == candidates.end() || it->second.empty()) {
+    return {};
+  }
+  return it->second.front();
+}
+
+/// Seed anchors from ToLayoutOps.
+void LayoutAnalysis::seed(Operation *root) {
+  root->walk([&](ToLayoutOp toLayout) {
+    LDBG() << "Seeding layout from to_layout op: " << toLayout << "\n";
+    addCandidate(toLayout.getResult(), toLayout.getLayout());
+  });
+}
+
+/// Propagate all candidates for a value forward through its users.
+void LayoutAnalysis::propagateForward(Value val) {
+  LDBG() << "Propagating forward for value: " << val << "\n";
+  auto it = candidates.find(val);
+  if (it == candidates.end()) {
+    return;
+  }
+  for (VectorLayoutInterface layout : it->second) {
+    propagateOneForward(val, layout);
+  }
+}
+
+/// Run Phase 1: drain forward queue. Convergence is guaranteed because each
+/// value can contribute at most kMaxCandidatesPerValue new candidates, and
+/// addCandidate only enqueues when a genuinely new candidate is inserted.
+void LayoutAnalysis::runForward() {
+  while (!forward.empty()) {
+    Value val = forward.front();
+    forward.pop();
+    propagateForward(val);
+  }
+}
+
+/// Propagate a single layout forward through all users of a value.
+void LayoutAnalysis::propagateOneForward(Value val,
+                                         VectorLayoutInterface layout) {
   for (OpOperand &use : val.getUses()) {
     unsigned operandIdx = use.getOperandNumber();
     Operation *user = use.getOwner();
@@ -67,8 +167,8 @@
     if (auto forOp = dyn_cast<scf::ForOp>(user)) {
       Value arg = forOp.getTiedLoopRegionIterArg(&use);
       Value result = forOp.getTiedLoopResult(&use);
-      setLayoutIfUnset(arg, layout);
-      setLayoutIfUnset(result, layout);
+      addCandidate(arg, layout);
+      addCandidate(result, layout);
       continue;
     }
 
@@ -77,17 +177,17 @@
       if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
         Value arg = forOp.getRegionIterArg(operandIdx);
         Value result = forOp->getResult(operandIdx);
-        setLayoutIfUnset(arg, layout);
-        setLayoutIfUnset(result, layout);
+        addCandidate(arg, layout);
+        addCandidate(result, layout);
         continue;
       }
       if (auto ifOp = dyn_cast<scf::IfOp>(parentOp)) {
         Value thenArg = ifOp.getThenRegion().getArgument(operandIdx);
         Value elseArg = ifOp.getElseRegion().getArgument(operandIdx);
         Value result = ifOp->getResult(operandIdx);
-        setLayoutIfUnset(thenArg, layout);
-        setLayoutIfUnset(elseArg, layout);
-        setLayoutIfUnset(result, layout);
+        addCandidate(thenArg, layout);
+        addCandidate(elseArg, layout);
+        addCandidate(result, layout);
         continue;
       }
     }
@@ -95,13 +195,13 @@
     if (auto yieldOp = dyn_cast<vector::YieldOp>(user)) {
       Operation *parentOp = cast<vector::MaskOp>(yieldOp->getParentOp());
       Value result = parentOp->getResult(operandIdx);
-      setLayoutIfUnset(result, layout);
+      addCandidate(result, layout);
       continue;
     }
 
     if (OpTrait::hasElementwiseMappableTraits(user)) {
       for (OpResult result : user->getOpResults()) {
-        setLayoutIfUnset(result, layout);
+        addCandidate(result, layout);
       }
       continue;
     }
@@ -110,139 +210,169 @@
       if (multiReduce.getSource() == val) {
         if (auto maskOp =
                 dyn_cast<vector::MaskOp>(multiReduce->getParentOp())) {
-          // We shouldn't have to do this... but vector.mask is badly designed
-          // and there is no mapping from the mask operand to the operation.
-          // TODO: Open vector.mask before vector distribute.
-          setLayoutOrClone(&maskOp.getMaskMutable(), layout);
+          addCandidate(maskOp.getMask(), layout);
         }
         SmallVector<bool> reductionMask = multiReduce.getReductionMask();
         VectorLayoutInterface reduceLayout = layout.project(reductionMask);
-        setLayoutIfUnset(multiReduce.getResult(), reduceLayout);
+        addCandidate(multiReduce.getResult(), reduceLayout);
         continue;
       }
       if (multiReduce.getAcc() == val) {
-        setLayoutIfUnset(multiReduce.getResult(), layout);
+        addCandidate(multiReduce.getResult(), layout);
         continue;
       }
     }
 
     if (auto transpose = dyn_cast<vector::TransposeOp>(user)) {
       if (transpose.getVector() == val) {
-        setLayoutIfUnset(transpose.getResult(),
-                         layout.permute(transpose.getPermutation()));
+        addCandidate(transpose.getResult(),
+                     layout.permute(transpose.getPermutation()));
         continue;
       }
     }
 
     if (auto contract = dyn_cast<vector::ContractionOp>(user)) {
       if (contract.getAcc() == val) {
-        setLayoutIfUnset(contract.getResult(), layout);
+        addCandidate(contract.getResult(), layout);
         continue;
       }
       if (contract.getLhs() == val || contract.getRhs() == val) {
-        if (contract->hasAttr("iree.gpu.mma")) {
-          // Intrinsic ops have fixed layouts, do not try to infer them through
-          // maps.
-          // TODO: Move to iree_gpu.multi_mma ops.
-          continue;
-        }
         if (auto maskOp = dyn_cast<vector::MaskOp>(contract->getParentOp())) {
-          // We shouldn't have to do this... but vector.mask is badly designed
-          // and there is no mapping from the mask operand to the operation.
-          // TODO: Open vector.mask before vector distribute.
           AffineMap map = contract.getMatchingIndexingMap(&use);
           if (map.isPermutation()) {
-            setLayoutOrClone(&maskOp.getMaskMutable(),
-                             layout.apply(inversePermutation(map)));
+            addCandidate(maskOp.getMask(),
+                         layout.apply(inversePermutation(map)));
           }
         }
-        // If lhs, rhs layout is known, infer result layout.
-        VectorLayoutInterface lhsLayout = getLayout(contract.getLhs());
-        VectorLayoutInterface rhsLayout = getLayout(contract.getRhs());
+        // Uses first candidate for each operand; first-wins avoids
+        // combinatorial explosion over candidate pairings.
+        // TODO: Consider all candidate combinations with a cost model.
+        VectorLayoutInterface lhsLayout = getFirstCandidate(contract.getLhs());
+        VectorLayoutInterface rhsLayout = getFirstCandidate(contract.getRhs());
         if (lhsLayout && rhsLayout) {
           AffineMap lhsMap = contract.getIndexingMapsArray()[0];
           AffineMap rhsMap = contract.getIndexingMapsArray()[1];
           AffineMap resMap = contract.getIndexingMapsArray()[2];
           VectorLayoutInterface resLayout = lhsLayout.getRecombinedLayout(
               {lhsLayout, rhsLayout}, {lhsMap, rhsMap}, resMap);
-          setLayoutIfUnset(contract.getResult(), resLayout);
+          addCandidate(contract.getResult(), resLayout);
         }
         continue;
       }
     }
 
     if (auto gather = dyn_cast<vector::GatherOp>(user)) {
-      setLayoutIfUnset(gather.getResult(), layout);
-      continue;
-    }
-
-    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
-      if (!write.getMask()) {
-        continue;
-      }
-      OpOperand &mask = write.getMaskMutable()[0];
-      AffineMap maskMap =
-          inversePermutation(compressUnusedDims(write.getPermutationMap()));
-      setLayoutOrClone(&mask, layout.apply(maskMap));
+      addCandidate(gather.getResult(), layout);
       continue;
     }
 
     if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(user)) {
-      setLayoutIfUnset(
-          shapeCast.getResult(),
-          layout.reshape(shapeCast.getResultVectorType().getShape()));
+      addCandidate(shapeCast.getResult(),
+                   layout.reshape(shapeCast.getResultVectorType().getShape()));
       continue;
     }
   }
 }
 
-void LayoutInfo::propagateLayoutBackward(Value val) {
-  LDBG() << "Propagating layout backward for value: " << val << "\n";
-  VectorLayoutInterface layout = getLayout(val);
-  if (auto blockArg = dyn_cast<BlockArgument>(val)) {
-    Operation *parent = val.getParentBlock()->getParentOp();
-    if (auto forOp = dyn_cast<scf::ForOp>(parent)) {
-      OpOperand *yielded = forOp.getTiedLoopYieldedValue(blockArg);
-      OpOperand *init = forOp.getTiedLoopInit(blockArg);
-      setLayoutOrClone(yielded, layout);
-      setLayoutOrClone(init, layout);
+//===----------------------------------------------------------------------===//
+// Resolve
+//===----------------------------------------------------------------------===//
+
+/// Pick first candidate for each value.
+void LayoutAnalysis::resolve() {
+  for (auto &[val, candidateSet] : candidates) {
+    if (!candidateSet.empty()) {
+      resolved[val] = candidateSet.front();
     }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Phase 2: Backward Fixup (mutates IR)
+//===----------------------------------------------------------------------===//
+
+/// Walk operations in reverse within a region, fixing up operand layouts.
+/// Ops are collected upfront so that newly inserted to_layout ops (from
+/// setLayoutOrClone) are not visited by the walk.
+void LayoutAnalysis::fixupRegion(Region &region) {
+  for (Block &block : region.getBlocks()) {
+    SmallVector<Operation *> ops;
+    for (Operation &op : llvm::reverse(block.getOperations())) {
+      ops.push_back(&op);
+    }
+    for (Operation *op : ops) {
+      fixupOp(op);
+    }
+  }
+}
+
+/// Fix up operand layouts for a single operation. Result layouts are fixed
+/// (from resolve); this determines what operand layouts should be.
+void LayoutAnalysis::fixupOp(Operation *op) {
+  // transfer_write: vector operand layout -> derive mask layout.
+  if (auto write = dyn_cast<vector::TransferWriteOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(write.getVector());
+    if (!write.getMask()) {
+      return;
+    }
+    AffineMap maskMap =
+        inversePermutation(compressUnusedDims(write.getPermutationMap()));
+    setLayoutOrClone(&write.getMaskMutable()[0], layout.apply(maskMap));
     return;
   }
 
-  Operation *defOp = val.getDefiningOp();
-  if (OpTrait::hasElementwiseMappableTraits(defOp)) {
-    for (OpOperand &operand : defOp->getOpOperands()) {
+  // transfer_read: result layout -> derive mask layout.
+  if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(read.getResult());
+    if (!read.getMask()) {
+      return;
+    }
+    AffineMap maskMap =
+        inversePermutation(compressUnusedDims(read.getPermutationMap()));
+    setLayoutOrClone(&read.getMaskMutable()[0], layout.apply(maskMap));
+    return;
+  }
+
+  // elementwise: result layout -> all operands get same layout.
+  if (OpTrait::hasElementwiseMappableTraits(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(op->getResult(0));
+    for (OpOperand &operand : op->getOpOperands()) {
       setLayoutOrClone(&operand, layout);
     }
     return;
   }
 
-  if (auto toLayout = dyn_cast<ToLayoutOp>(defOp)) {
+  // to_layout: result layout -> input gets same layout.
+  if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(toLayout.getResult());
     setLayoutOrClone(&toLayout.getInputMutable(), layout);
     return;
   }
 
-  if (auto multiReduce = dyn_cast<vector::MultiDimReductionOp>(defOp)) {
+  // multi_dim_reduction: result layout -> acc gets same layout.
+  if (auto multiReduce = dyn_cast<vector::MultiDimReductionOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(multiReduce.getResult());
     setLayoutOrClone(&multiReduce.getAccMutable(), layout);
     return;
   }
 
-  if (auto transpose = dyn_cast<vector::TransposeOp>(defOp)) {
+  // transpose: result layout -> input gets inverse-permuted layout.
+  if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(transpose.getResult());
     setLayoutOrClone(
         &transpose.getVectorMutable(),
         layout.permute(invertPermutationVector(transpose.getPermutation())));
     return;
   }
 
-  if (auto broadcast = dyn_cast<vector::BroadcastOp>(defOp)) {
-    // Ensure that there are no broadcasted unit dims as we do not know how to
-    // handle them as of now.
-    assert(broadcast.computeBroadcastedUnitDims().empty() &&
-           "Stretching in broadcasting not implemented yet.");
+  // broadcast: result layout -> source gets projected layout.
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(broadcast.getResult());
     if (!isa<VectorType>(broadcast.getSourceType())) {
       return;
     }
+    assert(broadcast.computeBroadcastedUnitDims().empty() &&
+           "Stretching in broadcasting not implemented yet.");
     int64_t numBroadcastedDims =
         broadcast.getResultVectorType().getRank() -
         cast<VectorType>(broadcast.getSourceType()).getRank();
@@ -254,39 +384,27 @@
     return;
   }
 
-  if (auto contract = dyn_cast<vector::ContractionOp>(defOp)) {
-    // TODO: We could determine lhs/rhs layout if we know one of them, but
-    // NYI for now.
+  // contract: result layout -> acc gets same layout.
+  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(contract.getResult());
     setLayoutOrClone(&contract.getAccMutable(), layout);
     return;
   }
 
-  if (auto gather = dyn_cast<vector::GatherOp>(defOp)) {
-    OpOperand &indices = gather.getIndicesMutable();
-    OpOperand &mask = gather.getMaskMutable();
-    OpOperand &passthru = gather.getPassThruMutable();
-    setLayoutOrClone(&indices, layout);
-    setLayoutOrClone(&mask, layout);
-    setLayoutOrClone(&passthru, layout);
+  // gather: result layout -> indices, mask, passthru get same layout.
+  if (auto gather = dyn_cast<vector::GatherOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(gather.getResult());
+    setLayoutOrClone(&gather.getIndicesMutable(), layout);
+    setLayoutOrClone(&gather.getMaskMutable(), layout);
+    setLayoutOrClone(&gather.getPassThruMutable(), layout);
     return;
   }
 
-  if (auto read = dyn_cast<vector::TransferReadOp>(defOp)) {
-    if (!read.getMask()) {
-      return;
-    }
-    OpOperand &mask = read.getMaskMutable()[0];
-    AffineMap maskMap =
-        inversePermutation(compressUnusedDims(read.getPermutationMap()));
-    setLayoutOrClone(&mask, layout.apply(maskMap));
-    return;
-  }
-
-  if (auto gather = dyn_cast<TransferGatherOp>(defOp)) {
+  // transfer_gather: result layout -> index vecs + mask get projected layouts.
+  if (auto gather = dyn_cast<TransferGatherOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(gather.getResult());
     SmallVector<AffineMap> maps = gather.getIndexingMapsArray();
     int64_t numIndexVecs = gather.getIndexVecs().size();
-    // Index vec maps are maps[1..numIndexVecs]. They only use dim exprs,
-    // so strip symbols before applying to layout.
     for (auto [i, operand] : llvm::enumerate(gather.getIndexVecsMutable())) {
       AffineMap indexVecMap = maps[1 + i];
       AffineMap projected =
@@ -304,28 +422,60 @@
     return;
   }
 
-  if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(defOp)) {
+  // shape_cast: result layout -> source gets reshaped layout.
+  if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
+    VectorLayoutInterface layout = getResolvedLayout(shapeCast.getResult());
     setLayoutOrClone(
         &shapeCast.getSourceMutable(),
         layout.reshape(shapeCast.getSourceVectorType().getShape()));
     return;
   }
+
+  // scf.for: fix init_args/yield from result layouts, then recurse into body.
+  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    for (auto [i, result] : llvm::enumerate(forOp.getResults())) {
+      VectorLayoutInterface layout = getResolvedLayout(result);
+      setLayoutOrClone(&yieldOp->getOpOperand(i), layout);
+      setLayoutOrClone(&forOp.getInitArgsMutable()[i], layout);
+    }
+    fixupRegion(forOp.getBodyRegion());
+    return;
+  }
+
+  // Default: recurse into nested regions for ops we don't explicitly handle
+  // (e.g. scf.forall, scf.if, vector.mask).
+  for (Region &region : op->getRegions()) {
+    fixupRegion(region);
+  }
 }
 
-void LayoutInfo::setLayoutOrClone(OpOperand *val,
-                                  VectorLayoutInterface layout) {
+/// Assign a layout to an operand, cloning cheap ops or inserting conversions
+/// on conflict.
+void LayoutAnalysis::setLayoutOrClone(OpOperand *val,
+                                      VectorLayoutInterface layout) {
   if (!layout) {
-    // No layout to set.
     return;
   }
   if (!isa<ShapedType>(val->get().getType())) {
-    // Don't set layouts on non-shaped types. This would anyway be an empty
-    // layout.
     return;
   }
-  // Always clone constant like ops and set the layout on them.
+
+  // No layout yet -- assign.
+  if (!hasResolvedLayout(val->get())) {
+    resolved[val->get()] = layout;
+    return;
+  }
+
+  // Same layout -- nothing to do.
+  if (getResolvedLayout(val->get()) == layout) {
+    return;
+  }
+
+  // Different layout -- clone cheap ops or insert to_layout conversion.
   OpBuilder b(val->getOwner());
   if (Operation *defOp = val->get().getDefiningOp()) {
+    // Clone constant-like and duplicatable ops per use site.
     bool isConstantLike = defOp->hasTrait<OpTrait::ConstantLike>();
     bool isDuplicatable =
         isa<vector::StepOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
@@ -334,54 +484,39 @@
       b.setInsertionPoint(defOp);
       Operation *cloned = b.clone(*defOp);
       val->set(cloned->getResult(0));
-      layouts[cloned->getResult(0)] = layout;
+      resolved[cloned->getResult(0)] = layout;
       return;
     }
   }
 
-  if (!hasLayout(val->get())) {
-    layouts[val->get()] = layout;
-    forward.push(val->get());
-    backward.push(val->get());
-    return;
-  }
-
-  if (getLayout(val->get()) != layout) {
-    // Create `to_layout` op to change layout if it's not the same as the
-    // existing.
-    Value v = val->get();
-    Value layourtedV = ToLayoutOp::create(b, v.getLoc(), v, layout);
-    val->set(layourtedV);
-    layouts[layourtedV] = layout;
-    return;
-  }
+  // Non-cheap op -- insert to_layout conversion.
+  Value v = val->get();
+  Value converted = ToLayoutOp::create(b, v.getLoc(), v, layout);
+  val->set(converted);
+  resolved[converted] = layout;
 }
 
-LogicalResult propagateVectorLayoutInfo(
+//===----------------------------------------------------------------------===//
+// Entry Point
+//===----------------------------------------------------------------------===//
+
+void propagateVectorLayoutInfo(
     Operation *root, llvm::MapVector<Value, VectorLayoutInterface> &layouts) {
-  LayoutInfo info;
-  // Initialize propagation info with to_layout operations;
-  root->walk([&](ToLayoutOp toLayout) {
-    LDBG() << "Initializing layout from to_layout op: " << toLayout << "\n";
-    info.setLayoutIfUnset(toLayout.getResult(), toLayout.getLayout());
-  });
-  // Propagate all layout information until fixpoint. Give priority to
-  // forward propagation and only do backward propagation when there is no
-  // forward propagation work left.
-  while (!info.forward.empty() || !info.backward.empty()) {
-    SmallVector<Value> changed;
-    if (!info.forward.empty()) {
-      Value val = info.forward.front();
-      info.forward.pop();
-      info.propagateLayoutForward(val);
-    } else {
-      Value val = info.backward.front();
-      info.backward.pop();
-      info.propagateLayoutBackward(val);
-    }
+  LayoutAnalysis analysis;
+
+  // Phase 1: Seed anchors and forward propagation (no IR mutation).
+  analysis.seed(root);
+  analysis.runForward();
+
+  // Resolve: pick first candidate for each value.
+  analysis.resolve();
+
+  // Phase 2: Backward fixup (mutates IR).
+  for (Region &region : root->getRegions()) {
+    analysis.fixupRegion(region);
   }
-  layouts = std::move(info.layouts);
-  return success();
+
+  layouts = std::move(analysis.resolved);
 }
 
 #define GEN_PASS_DEF_TESTVECTORLAYOUTANALYSISPASS
@@ -392,10 +527,7 @@
   void runOnOperation() override {
     Operation *root = getOperation();
     llvm::MapVector<Value, VectorLayoutInterface> layouts;
-    if (failed(propagateVectorLayoutInfo(root, layouts))) {
-      root->emitError("Layout Analysis Failed");
-      return signalPassFailure();
-    }
+    propagateVectorLayoutInfo(root, layouts);
 
     root->walk([&](Operation *op) {
       if (isa<ToLayoutOp>(op)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
index af3e4fc..15f8732 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-codegen-test-vector-layout-analysis --split-input-file %s --verify-diagnostics
+// RUN: iree-opt --pass-pipeline="builtin.module(any(iree-codegen-test-vector-layout-analysis))" --split-input-file %s --verify-diagnostics
 
 #layout = #iree_vector_ext.nested_layout<
   subgroup_tile = [1, 1],
@@ -572,17 +572,20 @@
   thread_strides   = [0]
 >
 
+// The to_layout anchors are placed after the arithmetic so that forward
+// propagation seeds from them and backward fixup assigns the operand layouts.
 func.func @handle_multiuse_step(%lhs: vector<64xindex>, %rhs: vector<64xindex>) -> (vector<64xindex>, vector<64xindex>) {
-  %l_lhs = iree_vector_ext.to_layout %lhs to layout(#layoutA) : vector<64xindex>
-  %r_lhs = iree_vector_ext.to_layout %rhs to layout(#layoutB) : vector<64xindex>
+  // Two remarks are expected: vector.step is cloned once per use site, each clone with a different layout.
   %cst = vector.step : vector<64xindex>
-  // expected-remark @above {{element_tile = [1]}}
   // expected-remark @above {{element_tile = [64]}}
+  // expected-remark @above {{element_tile = [1]}}
   %scaled_lhs = arith.muli %cst, %lhs : vector<64xindex>
   // expected-remark @above {{element_tile = [64]}}
+  %l = iree_vector_ext.to_layout %scaled_lhs to layout(#layoutA) : vector<64xindex>
   %scaled_rhs = arith.muli %cst, %rhs : vector<64xindex>
   // expected-remark @above {{element_tile = [1]}}
-  func.return %scaled_lhs, %scaled_rhs : vector<64xindex>, vector<64xindex>
+  %r = iree_vector_ext.to_layout %scaled_rhs to layout(#layoutB) : vector<64xindex>
+  func.return %l, %r : vector<64xindex>, vector<64xindex>
 }
 
 // -----
@@ -809,3 +812,122 @@
   // expected-remark @above {{subgroup_tile = [2], batch_tile = [4], outer_tile = [1], thread_tile = [4], element_tile = [4], subgroup_strides = [1], thread_strides = [1]}}
   func.return %reshape : vector<128xf16>
 }
+
+// -----
+
+// Test multi-candidate accumulation with first-wins resolution.
+// %c receives candidates from both %a (layoutA) and %b (layoutB).
+// First-wins means %c should get layoutA (first candidate in walk order).
+
+#layoutA = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [16, 16],
+
+  subgroup_strides = [0, 0],
+  thread_strides   = [0, 0]
+>
+
+#layoutB = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [2, 2],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [8, 8],
+
+  subgroup_strides = [0, 0],
+  thread_strides   = [0, 0]
+>
+
+func.func @multi_anchor_first_wins(%input: vector<16x16xf16>) -> vector<16x16xf16> {
+  %a = iree_vector_ext.to_layout %input to layout(#layoutA) : vector<16x16xf16>
+  %b = iree_vector_ext.to_layout %input to layout(#layoutB) : vector<16x16xf16>
+  // %c gets layoutA from %a (first anchor in walk order) via forward propagation.
+  %c = arith.addf %a, %b : vector<16x16xf16>
+  // expected-remark @above {{element_tile = [16, 16]}}
+  func.return %c : vector<16x16xf16>
+}
+
+// -----
+
+// Test backward fixup for scf.for: to_layout anchor after the loop causes
+// result layout to flow to init_args and yield via Phase 2 backward fixup.
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [16, 16],
+
+  subgroup_strides = [0, 0],
+  thread_strides   = [0, 0]
+>
+
+func.func @scffor_backward_fixup(%arr: memref<16x16xf16>, %init: vector<16x16xf16>) -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %cst_0 = arith.constant 0.0 : f16
+  %out = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg1 = %init) -> (vector<16x16xf16>) {
+    // expected-remark @above {{element_tile = [16, 16]}}
+    %val = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+    // expected-remark @above {{element_tile = [16, 16]}}
+    %sum = arith.addf %arg1, %val : vector<16x16xf16>
+    // expected-remark @above {{element_tile = [16, 16]}}
+    scf.yield %sum : vector<16x16xf16>
+  }
+  %outl = iree_vector_ext.to_layout %out to layout(#layout) : vector<16x16xf16>
+  func.return %outl : vector<16x16xf16>
+}
+
+// -----
+
+// Test that a shared create_mask gets cloned per use site when two
+// transfer_write ops have different data layouts (and thus need different
+// mask layouts). The backward fixup processes transfer_write ops in reverse
+// program order, deriving mask layouts from the vector operand's resolved
+// layout. The second write assigns the mask layout; the first write finds
+// a conflict and clones the create_mask.
+
+#layoutA = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [16, 16],
+
+  subgroup_strides = [0, 0],
+  thread_strides   = [0, 0]
+>
+
+#layoutB = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [2, 2],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [8, 8],
+
+  subgroup_strides = [0, 0],
+  thread_strides   = [0, 0]
+>
+
+func.func @clone_shared_mask_on_layout_conflict(
+    %arr: memref<16x16xf16>,
+    %a: vector<16x16xf16>,
+    %b: vector<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  // The mask gets cloned per transfer_write use site, each with the
+  // correct layout derived from the written value's layout.
+  %mask = vector.create_mask %c12, %c12 : vector<16x16xi1>
+  // expected-remark @above {{element_tile = [8, 8]}}
+  // expected-remark @above {{element_tile = [16, 16]}}
+  %al = iree_vector_ext.to_layout %a to layout(#layoutA) : vector<16x16xf16>
+  %bl = iree_vector_ext.to_layout %b to layout(#layoutB) : vector<16x16xf16>
+  vector.transfer_write %al, %arr[%c0, %c0], %mask {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  vector.transfer_write %bl, %arr[%c0, %c0], %mask {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  func.return
+}