[VectorDistribute] Refactor VectorLayoutAnalysis into 2-phase forward/backward design (#23611)
Restructure the layout analysis from an interleaved forward+backward
worklist into a clean two-phase design:
Phase 1 (forward): Multi-candidate propagation from ToLayoutOp anchors
through uses. No IR mutation.
Resolve: Pick the first candidate per value. This is a placeholder cost
model that matches the old analysis; eventually it will consider
coalescing, compute ops, MMA layouts, etc.
Phase 2 (backward fixup): Walk operations in reverse program order via
recursive fixupRegion/fixupOp. For each op, derive operand layouts from
resolved result layouts. Assign missing layouts, clone cheap ops
(constants, create_mask, step), or insert to_layout conversions on
conflict.
This handles conflicts in a predictable manner.

The forward analysis is the main driver. The reason for this is that for
a program to be well-formed for vector distribution, there must be some
way for the final store/return to get a layout. Otherwise, there is not
enough information in the program to determine how distribution should
be done. The forward analysis ensures that the final return/store gets a
layout in a well-formed program. The rest of the program gets its
layouts from backward propagation, since everything in the program must
eventually reach the store/return.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 912d1e9..4bead73 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -360,10 +360,7 @@
// Run the analysis and determine the layouts.
LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
llvm::MapVector<Value, VectorLayoutInterface> layouts;
- if (failed(propagateVectorLayoutInfo(root, layouts))) {
- LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Failed\n");
- return failure();
- }
+ propagateVectorLayoutInfo(root, layouts);
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeeded\n");
LLVM_DEBUG(llvm::dbgs() << "\n\n");
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index fffe666..90958df 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -1069,7 +1069,7 @@
let summary = "Test the PartitionableLoopsInterface";
}
-def TestVectorLayoutAnalysisPass : Pass<"iree-codegen-test-vector-layout-analysis", ""> {
+def TestVectorLayoutAnalysisPass : InterfacePass<"iree-codegen-test-vector-layout-analysis", "mlir::FunctionOpInterface"> {
let summary = "Test the vector layout analysis.";
let description = [{
Run VectorLayoutAnalysis on the root operation. The analysis emits remarks
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index ec17434..57a6804 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -81,69 +81,38 @@
class VectorLayoutInterface;
} // namespace IREE::VectorExt
-/// Analyzes the root op and its nested ops to propagate vector layouts
-/// originating from to_vector operations. Example:
+/// Propagates vector layouts through `root` and its nested operations.
+/// Populates `layouts` with the resolved layout for every vector-typed value.
///
-/// %root = vector.transfer_read
-/// |
-/// --> anchored to layout L (using a to_layout op)
-/// %root2 = vector.transfer_read
-/// %c = arith.mulf %root, %b
-/// |
-/// --> %root, %b and %c must have the same layout
-/// %e = arith.divf %b, %root2
-/// |
-/// --> %root2, %b and %e must have the same layout
+/// Layout anchors are `iree_vector_ext.to_layout` ops in the IR. These must
+/// be present before calling this function — they seed the analysis. On GPUs,
+/// anchors are typically placed on loads. IR without anchors is considered
+/// ill-formed for vector distribution.
///
-/// Here, the user provided an anchor point for %root, fixing its layout to L.
-/// The layout then uses its inference rules to find the layout of other
-/// values:
+/// Starting from anchors, layouts are inferred for the rest of the IR using
+/// op-specific rules (elementwise, transpose, contract, scf.for, etc.).
///
-/// %root = vector.transfer_read
-/// |
-/// --> inferred to layout L
-/// %root2 = vector.transfer_read
-/// |
-/// --> inferred to layout L
-/// %c = arith.mulf %root, %b
-/// |
-/// --> inferred to layout L
-/// %e = arith.divf %b, %root2
-/// |
-/// --> inferred to layout L
+/// Example — single anchor:
///
-/// If at any point, a value has a layout, but the user of that value requires
-/// a different layout, the analysis inserts a resolution operation. This
-/// resolution operation is `iree_vector_ext.to_layout`.
-/// For Example:
+/// %read = vector.transfer_read ...
+/// %anchored = iree_vector_ext.to_layout %read to layout(L)
+/// %c = arith.mulf %anchored, %b --> inferred to layout L
+/// %e = arith.divf %b, %anchored --> inferred to layout L
///
-/// %0 = vector.transfer_read
-/// |
-/// --> anchored to layout L
-/// %1 = vector.transfer_read
-/// |
-/// --> anchored to layout L'
-/// arith.addf %0, %1
-/// |
-/// --> %0 and %1 must have the same layout
+/// When a value receives conflicting layouts from different anchors, the
+/// analysis resolves the conflict by inserting a `to_layout` conversion.
+/// Cheap ops (constants, create_mask, step) are cloned per use site instead.
+/// The caller is responsible for lowering the inserted `to_layout` ops.
///
-/// To resolve the conflict, the analysis chooses one of the layouts, say
-/// L, and inserts a resolution operation to convert the other layout to L.
+/// Example — conflict resolution:
///
-/// %0 = vector.transfer_read
-/// |
-/// --> anchored to layout L
-/// %1 = vector.transfer_read
-/// |
-/// --> anchored to layout L'
-/// %resolved = iree_vector_ext.to_layout %1
-/// |
-/// --> inferred to layout L
-/// arith.addf %0, %resolved
-///
-/// The analysis itself will not try to resolve the conflict, but instead
-/// will leave it as a to_layout op, which can be rewritten by the caller.
-LogicalResult propagateVectorLayoutInfo(
+/// %a = ... to_layout ... to layout(L)
+/// %b = ... to_layout ... to layout(L')
+/// // %a and %b have different layouts but feed into the same op.
+/// // The analysis inserts a conversion on %b:
+/// %resolved = iree_vector_ext.to_layout %b to layout(L)
+/// arith.addf %a, %resolved --> layout L
+void propagateVectorLayoutInfo(
Operation *root,
llvm::MapVector<Value, IREE::VectorExt::VectorLayoutInterface> &layouts);
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 417f6e3..d4b539e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -11,6 +11,7 @@
#include <cassert>
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
@@ -25,41 +26,140 @@
using namespace IREE::VectorExt;
-struct LayoutInfo {
- /// Given a value, propagate its layout information forward through its
- /// users.
- void propagateLayoutForward(Value val);
- /// Given a value, propagate its layout information backward through its
- /// defining operation.
- void propagateLayoutBackward(Value val);
+/// Maximum number of candidate layouts tracked per value. Kept small to bound
+/// analysis cost; most values see 1-2 candidates in practice.
+static constexpr int kMaxCandidatesPerValue = 4;
- void setLayoutIfUnset(Value val, VectorLayoutInterface layout) {
- if (!isa<ShapedType>(val.getType())) {
- // Don't set layouts on non-shaped types. This would anyway be an empty
- // layout.
- return;
- }
- if (hasLayout(val)) {
- return;
- }
- layouts[val] = layout;
- forward.push(val);
- backward.push(val);
- }
- void setLayoutOrClone(OpOperand *val, VectorLayoutInterface layout);
- VectorLayoutInterface getLayout(Value val) const {
- return layouts.lookup(val);
- }
- bool hasLayout(Value val) const { return layouts.contains(val); }
+//===----------------------------------------------------------------------===//
+// Layout Analysis
+//
+// Phase 1: Forward propagation with multi-candidate tracking. Seeds from
+// ToLayoutOp anchors, propagates forward through uses. Each value accumulates
+// up to kMaxCandidatesPerValue candidate layouts. No IR mutation.
+//
+// Resolve: Pick first candidate for each value (first-wins). The multi-
+// candidate data structure is ready for a cost model later.
+//
+// Phase 2: Backward fixup. Walks operations in reverse program order. For each
+// op, determines operand layouts from resolved result/operand layouts.
+// Assigns missing layouts, clones cheap ops, or inserts to_layout
+// conversions.
+//
+// The forward analysis is the main driver. The reason for this is that for a
+// program to be well-formed for vector distribution, there must be some way
+// for the final store/return to get a layout. Otherwise, there is not enough
+// information in the program to determine how distribution should be done.
+// The forward analysis ensures that the final return/store gets a layout in a
+// well-formed program. The rest of the program gets its layouts from backward
+// propagation, since everything in the program must eventually reach the
+// store/return.
+//===----------------------------------------------------------------------===//
- llvm::MapVector<Value, VectorLayoutInterface> layouts;
+namespace {
+
+struct LayoutAnalysis {
+ /// Multiple candidate layouts per value (Phase 1).
+ llvm::MapVector<Value, llvm::SmallSetVector<VectorLayoutInterface, 4>>
+ candidates;
+ /// Resolved layouts: single layout per value (after resolve, used by fixup).
+ llvm::MapVector<Value, VectorLayoutInterface> resolved;
+ /// Forward worklist (Phase 1 only).
std::queue<Value> forward;
- std::queue<Value> backward;
+
+ //===--- Phase 1: Forward propagation ---===//
+
+ bool addCandidate(Value val, VectorLayoutInterface layout);
+ VectorLayoutInterface getFirstCandidate(Value val) const;
+ void seed(Operation *root);
+ void propagateForward(Value val);
+ void propagateOneForward(Value val, VectorLayoutInterface layout);
+ void runForward();
+
+ //===--- Resolve ---===//
+
+ void resolve();
+
+ //===--- Phase 2: Backward fixup ---===//
+
+ VectorLayoutInterface getResolvedLayout(Value val) const {
+ return resolved.lookup(val);
+ }
+ bool hasResolvedLayout(Value val) const { return resolved.contains(val); }
+
+ void fixupRegion(Region ®ion);
+ void fixupOp(Operation *op);
+ void setLayoutOrClone(OpOperand *val, VectorLayoutInterface layout);
};
-void LayoutInfo::propagateLayoutForward(Value val) {
- LDBG() << "Propagating layout forward for value: " << val << "\n";
- VectorLayoutInterface layout = getLayout(val);
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Phase 1: Forward Propagation (no IR mutation)
+//===----------------------------------------------------------------------===//
+
+/// Add a candidate layout for a value. Returns true if a new candidate was
+/// added, in which case the value is scheduled for forward propagation.
+bool LayoutAnalysis::addCandidate(Value val, VectorLayoutInterface layout) {
+ if (!layout) {
+ return false;
+ }
+ if (!isa<ShapedType>(val.getType())) {
+ return false;
+ }
+ auto &set = candidates[val];
+ if (set.size() >= kMaxCandidatesPerValue) {
+ return false;
+ }
+ if (!set.insert(layout)) {
+ return false;
+ }
+ forward.push(val);
+ return true;
+}
+
+/// Return the first candidate layout for a value, or null.
+VectorLayoutInterface LayoutAnalysis::getFirstCandidate(Value val) const {
+ auto it = candidates.find(val);
+ if (it == candidates.end() || it->second.empty()) {
+ return {};
+ }
+ return it->second.front();
+}
+
+/// Seed anchors from ToLayoutOps.
+void LayoutAnalysis::seed(Operation *root) {
+ root->walk([&](ToLayoutOp toLayout) {
+ LDBG() << "Seeding layout from to_layout op: " << toLayout << "\n";
+ addCandidate(toLayout.getResult(), toLayout.getLayout());
+ });
+}
+
+/// Propagate all candidates for a value forward through its users.
+void LayoutAnalysis::propagateForward(Value val) {
+ LDBG() << "Propagating forward for value: " << val << "\n";
+ auto it = candidates.find(val);
+ if (it == candidates.end()) {
+ return;
+ }
+ for (VectorLayoutInterface layout : it->second) {
+ propagateOneForward(val, layout);
+ }
+}
+
+/// Run Phase 1: drain the forward queue. Convergence is guaranteed because
+/// each value holds at most kMaxCandidatesPerValue candidates, and
+/// addCandidate only enqueues when a genuinely new candidate is inserted.
+void LayoutAnalysis::runForward() {
+ while (!forward.empty()) {
+ Value val = forward.front();
+ forward.pop();
+ propagateForward(val);
+ }
+}
+
+/// Propagate a single layout forward through all users of a value.
+void LayoutAnalysis::propagateOneForward(Value val,
+ VectorLayoutInterface layout) {
for (OpOperand &use : val.getUses()) {
unsigned operandIdx = use.getOperandNumber();
Operation *user = use.getOwner();
@@ -67,8 +167,8 @@
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
Value arg = forOp.getTiedLoopRegionIterArg(&use);
Value result = forOp.getTiedLoopResult(&use);
- setLayoutIfUnset(arg, layout);
- setLayoutIfUnset(result, layout);
+ addCandidate(arg, layout);
+ addCandidate(result, layout);
continue;
}
@@ -77,17 +177,17 @@
if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
Value arg = forOp.getRegionIterArg(operandIdx);
Value result = forOp->getResult(operandIdx);
- setLayoutIfUnset(arg, layout);
- setLayoutIfUnset(result, layout);
+ addCandidate(arg, layout);
+ addCandidate(result, layout);
continue;
}
if (auto ifOp = dyn_cast<scf::IfOp>(parentOp)) {
Value thenArg = ifOp.getThenRegion().getArgument(operandIdx);
Value elseArg = ifOp.getElseRegion().getArgument(operandIdx);
Value result = ifOp->getResult(operandIdx);
- setLayoutIfUnset(thenArg, layout);
- setLayoutIfUnset(elseArg, layout);
- setLayoutIfUnset(result, layout);
+ addCandidate(thenArg, layout);
+ addCandidate(elseArg, layout);
+ addCandidate(result, layout);
continue;
}
}
@@ -95,13 +195,13 @@
if (auto yieldOp = dyn_cast<vector::YieldOp>(user)) {
Operation *parentOp = cast<vector::MaskOp>(yieldOp->getParentOp());
Value result = parentOp->getResult(operandIdx);
- setLayoutIfUnset(result, layout);
+ addCandidate(result, layout);
continue;
}
if (OpTrait::hasElementwiseMappableTraits(user)) {
for (OpResult result : user->getOpResults()) {
- setLayoutIfUnset(result, layout);
+ addCandidate(result, layout);
}
continue;
}
@@ -110,139 +210,169 @@
if (multiReduce.getSource() == val) {
if (auto maskOp =
dyn_cast<vector::MaskOp>(multiReduce->getParentOp())) {
- // We shouldn't have to do this... but vector.mask is badly designed
- // and there is no mapping from the mask operand to the operation.
- // TODO: Open vector.mask before vector distribute.
- setLayoutOrClone(&maskOp.getMaskMutable(), layout);
+ addCandidate(maskOp.getMask(), layout);
}
SmallVector<bool> reductionMask = multiReduce.getReductionMask();
VectorLayoutInterface reduceLayout = layout.project(reductionMask);
- setLayoutIfUnset(multiReduce.getResult(), reduceLayout);
+ addCandidate(multiReduce.getResult(), reduceLayout);
continue;
}
if (multiReduce.getAcc() == val) {
- setLayoutIfUnset(multiReduce.getResult(), layout);
+ addCandidate(multiReduce.getResult(), layout);
continue;
}
}
if (auto transpose = dyn_cast<vector::TransposeOp>(user)) {
if (transpose.getVector() == val) {
- setLayoutIfUnset(transpose.getResult(),
- layout.permute(transpose.getPermutation()));
+ addCandidate(transpose.getResult(),
+ layout.permute(transpose.getPermutation()));
continue;
}
}
if (auto contract = dyn_cast<vector::ContractionOp>(user)) {
if (contract.getAcc() == val) {
- setLayoutIfUnset(contract.getResult(), layout);
+ addCandidate(contract.getResult(), layout);
continue;
}
if (contract.getLhs() == val || contract.getRhs() == val) {
- if (contract->hasAttr("iree.gpu.mma")) {
- // Intrinsic ops have fixed layouts, do not try to infer them through
- // maps.
- // TODO: Move to iree_gpu.multi_mma ops.
- continue;
- }
if (auto maskOp = dyn_cast<vector::MaskOp>(contract->getParentOp())) {
- // We shouldn't have to do this... but vector.mask is badly designed
- // and there is no mapping from the mask operand to the operation.
- // TODO: Open vector.mask before vector distribute.
AffineMap map = contract.getMatchingIndexingMap(&use);
if (map.isPermutation()) {
- setLayoutOrClone(&maskOp.getMaskMutable(),
- layout.apply(inversePermutation(map)));
+ addCandidate(maskOp.getMask(),
+ layout.apply(inversePermutation(map)));
}
}
- // If lhs, rhs layout is known, infer result layout.
- VectorLayoutInterface lhsLayout = getLayout(contract.getLhs());
- VectorLayoutInterface rhsLayout = getLayout(contract.getRhs());
+ // Uses first candidate for each operand; first-wins avoids
+ // combinatorial explosion over candidate pairings.
+ // TODO: Consider all candidate combinations with a cost model.
+ VectorLayoutInterface lhsLayout = getFirstCandidate(contract.getLhs());
+ VectorLayoutInterface rhsLayout = getFirstCandidate(contract.getRhs());
if (lhsLayout && rhsLayout) {
AffineMap lhsMap = contract.getIndexingMapsArray()[0];
AffineMap rhsMap = contract.getIndexingMapsArray()[1];
AffineMap resMap = contract.getIndexingMapsArray()[2];
VectorLayoutInterface resLayout = lhsLayout.getRecombinedLayout(
{lhsLayout, rhsLayout}, {lhsMap, rhsMap}, resMap);
- setLayoutIfUnset(contract.getResult(), resLayout);
+ addCandidate(contract.getResult(), resLayout);
}
continue;
}
}
if (auto gather = dyn_cast<vector::GatherOp>(user)) {
- setLayoutIfUnset(gather.getResult(), layout);
- continue;
- }
-
- if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
- if (!write.getMask()) {
- continue;
- }
- OpOperand &mask = write.getMaskMutable()[0];
- AffineMap maskMap =
- inversePermutation(compressUnusedDims(write.getPermutationMap()));
- setLayoutOrClone(&mask, layout.apply(maskMap));
+ addCandidate(gather.getResult(), layout);
continue;
}
if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(user)) {
- setLayoutIfUnset(
- shapeCast.getResult(),
- layout.reshape(shapeCast.getResultVectorType().getShape()));
+ addCandidate(shapeCast.getResult(),
+ layout.reshape(shapeCast.getResultVectorType().getShape()));
continue;
}
}
}
-void LayoutInfo::propagateLayoutBackward(Value val) {
- LDBG() << "Propagating layout backward for value: " << val << "\n";
- VectorLayoutInterface layout = getLayout(val);
- if (auto blockArg = dyn_cast<BlockArgument>(val)) {
- Operation *parent = val.getParentBlock()->getParentOp();
- if (auto forOp = dyn_cast<scf::ForOp>(parent)) {
- OpOperand *yielded = forOp.getTiedLoopYieldedValue(blockArg);
- OpOperand *init = forOp.getTiedLoopInit(blockArg);
- setLayoutOrClone(yielded, layout);
- setLayoutOrClone(init, layout);
+//===----------------------------------------------------------------------===//
+// Resolve
+//===----------------------------------------------------------------------===//
+
+/// Pick first candidate for each value.
+void LayoutAnalysis::resolve() {
+ for (auto &[val, candidateSet] : candidates) {
+ if (!candidateSet.empty()) {
+ resolved[val] = candidateSet.front();
}
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Phase 2: Backward Fixup (mutates IR)
+//===----------------------------------------------------------------------===//
+
+/// Walk operations in reverse within a region, fixing up operand layouts.
+/// Ops are collected upfront so that newly inserted to_layout ops (from
+/// setLayoutOrClone) are not visited by the walk.
+void LayoutAnalysis::fixupRegion(Region ®ion) {
+ for (Block &block : region.getBlocks()) {
+ SmallVector<Operation *> ops;
+ for (Operation &op : llvm::reverse(block.getOperations())) {
+ ops.push_back(&op);
+ }
+ for (Operation *op : ops) {
+ fixupOp(op);
+ }
+ }
+}
+
+/// Fix up operand layouts for a single operation. Result layouts are fixed
+/// (from resolve); this determines what operand layouts should be.
+void LayoutAnalysis::fixupOp(Operation *op) {
+ // transfer_write: vector operand layout -> derive mask layout.
+ if (auto write = dyn_cast<vector::TransferWriteOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(write.getVector());
+ if (!write.getMask()) {
+ return;
+ }
+ AffineMap maskMap =
+ inversePermutation(compressUnusedDims(write.getPermutationMap()));
+ setLayoutOrClone(&write.getMaskMutable()[0], layout.apply(maskMap));
return;
}
- Operation *defOp = val.getDefiningOp();
- if (OpTrait::hasElementwiseMappableTraits(defOp)) {
- for (OpOperand &operand : defOp->getOpOperands()) {
+ // transfer_read: result layout -> derive mask layout.
+ if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(read.getResult());
+ if (!read.getMask()) {
+ return;
+ }
+ AffineMap maskMap =
+ inversePermutation(compressUnusedDims(read.getPermutationMap()));
+ setLayoutOrClone(&read.getMaskMutable()[0], layout.apply(maskMap));
+ return;
+ }
+
+ // elementwise: result layout -> all operands get same layout.
+ if (OpTrait::hasElementwiseMappableTraits(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(op->getResult(0));
+ for (OpOperand &operand : op->getOpOperands()) {
setLayoutOrClone(&operand, layout);
}
return;
}
- if (auto toLayout = dyn_cast<ToLayoutOp>(defOp)) {
+ // to_layout: result layout -> input gets same layout.
+ if (auto toLayout = dyn_cast<ToLayoutOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(toLayout.getResult());
setLayoutOrClone(&toLayout.getInputMutable(), layout);
return;
}
- if (auto multiReduce = dyn_cast<vector::MultiDimReductionOp>(defOp)) {
+ // multi_dim_reduction: result layout -> acc gets same layout.
+ if (auto multiReduce = dyn_cast<vector::MultiDimReductionOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(multiReduce.getResult());
setLayoutOrClone(&multiReduce.getAccMutable(), layout);
return;
}
- if (auto transpose = dyn_cast<vector::TransposeOp>(defOp)) {
+ // transpose: result layout -> input gets inverse-permuted layout.
+ if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(transpose.getResult());
setLayoutOrClone(
&transpose.getVectorMutable(),
layout.permute(invertPermutationVector(transpose.getPermutation())));
return;
}
- if (auto broadcast = dyn_cast<vector::BroadcastOp>(defOp)) {
- // Ensure that there are no broadcasted unit dims as we do not know how to
- // handle them as of now.
- assert(broadcast.computeBroadcastedUnitDims().empty() &&
- "Stretching in broadcasting not implemented yet.");
+ // broadcast: result layout -> source gets projected layout.
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(broadcast.getResult());
if (!isa<VectorType>(broadcast.getSourceType())) {
return;
}
+ assert(broadcast.computeBroadcastedUnitDims().empty() &&
+ "Stretching in broadcasting not implemented yet.");
int64_t numBroadcastedDims =
broadcast.getResultVectorType().getRank() -
cast<VectorType>(broadcast.getSourceType()).getRank();
@@ -254,39 +384,27 @@
return;
}
- if (auto contract = dyn_cast<vector::ContractionOp>(defOp)) {
- // TODO: We could determine lhs/rhs layout if we know one of them, but
- // NYI for now.
+ // contract: result layout -> acc gets same layout.
+ if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(contract.getResult());
setLayoutOrClone(&contract.getAccMutable(), layout);
return;
}
- if (auto gather = dyn_cast<vector::GatherOp>(defOp)) {
- OpOperand &indices = gather.getIndicesMutable();
- OpOperand &mask = gather.getMaskMutable();
- OpOperand &passthru = gather.getPassThruMutable();
- setLayoutOrClone(&indices, layout);
- setLayoutOrClone(&mask, layout);
- setLayoutOrClone(&passthru, layout);
+ // gather: result layout -> indices, mask, passthru get same layout.
+ if (auto gather = dyn_cast<vector::GatherOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(gather.getResult());
+ setLayoutOrClone(&gather.getIndicesMutable(), layout);
+ setLayoutOrClone(&gather.getMaskMutable(), layout);
+ setLayoutOrClone(&gather.getPassThruMutable(), layout);
return;
}
- if (auto read = dyn_cast<vector::TransferReadOp>(defOp)) {
- if (!read.getMask()) {
- return;
- }
- OpOperand &mask = read.getMaskMutable()[0];
- AffineMap maskMap =
- inversePermutation(compressUnusedDims(read.getPermutationMap()));
- setLayoutOrClone(&mask, layout.apply(maskMap));
- return;
- }
-
- if (auto gather = dyn_cast<TransferGatherOp>(defOp)) {
+ // transfer_gather: result layout -> index vecs + mask get projected layouts.
+ if (auto gather = dyn_cast<TransferGatherOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(gather.getResult());
SmallVector<AffineMap> maps = gather.getIndexingMapsArray();
int64_t numIndexVecs = gather.getIndexVecs().size();
- // Index vec maps are maps[1..numIndexVecs]. They only use dim exprs,
- // so strip symbols before applying to layout.
for (auto [i, operand] : llvm::enumerate(gather.getIndexVecsMutable())) {
AffineMap indexVecMap = maps[1 + i];
AffineMap projected =
@@ -304,28 +422,60 @@
return;
}
- if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(defOp)) {
+ // shape_cast: result layout -> source gets reshaped layout.
+ if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
+ VectorLayoutInterface layout = getResolvedLayout(shapeCast.getResult());
setLayoutOrClone(
&shapeCast.getSourceMutable(),
layout.reshape(shapeCast.getSourceVectorType().getShape()));
return;
}
+
+ // scf.for: fix init_args/yield from result layouts, then recurse into body.
+ if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+ for (auto [i, result] : llvm::enumerate(forOp.getResults())) {
+ VectorLayoutInterface layout = getResolvedLayout(result);
+ setLayoutOrClone(&yieldOp->getOpOperand(i), layout);
+ setLayoutOrClone(&forOp.getInitArgsMutable()[i], layout);
+ }
+ fixupRegion(forOp.getBodyRegion());
+ return;
+ }
+
+ // Default: recurse into nested regions for ops we don't explicitly handle
+ // (e.g. scf.forall, scf.if, vector.mask).
+ for (Region ®ion : op->getRegions()) {
+ fixupRegion(region);
+ }
}
-void LayoutInfo::setLayoutOrClone(OpOperand *val,
- VectorLayoutInterface layout) {
+/// Assign a layout to an operand, cloning cheap ops or inserting conversions
+/// on conflict.
+void LayoutAnalysis::setLayoutOrClone(OpOperand *val,
+ VectorLayoutInterface layout) {
if (!layout) {
- // No layout to set.
return;
}
if (!isa<ShapedType>(val->get().getType())) {
- // Don't set layouts on non-shaped types. This would anyway be an empty
- // layout.
return;
}
- // Always clone constant like ops and set the layout on them.
+
+ // No layout yet -- assign.
+ if (!hasResolvedLayout(val->get())) {
+ resolved[val->get()] = layout;
+ return;
+ }
+
+ // Same layout -- nothing to do.
+ if (getResolvedLayout(val->get()) == layout) {
+ return;
+ }
+
+ // Different layout -- clone cheap ops or insert to_layout conversion.
OpBuilder b(val->getOwner());
if (Operation *defOp = val->get().getDefiningOp()) {
+ // Clone constant-like and duplicatable ops per use site.
bool isConstantLike = defOp->hasTrait<OpTrait::ConstantLike>();
bool isDuplicatable =
isa<vector::StepOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
@@ -334,54 +484,39 @@
b.setInsertionPoint(defOp);
Operation *cloned = b.clone(*defOp);
val->set(cloned->getResult(0));
- layouts[cloned->getResult(0)] = layout;
+ resolved[cloned->getResult(0)] = layout;
return;
}
}
- if (!hasLayout(val->get())) {
- layouts[val->get()] = layout;
- forward.push(val->get());
- backward.push(val->get());
- return;
- }
-
- if (getLayout(val->get()) != layout) {
- // Create `to_layout` op to change layout if it's not the same as the
- // existing.
- Value v = val->get();
- Value layourtedV = ToLayoutOp::create(b, v.getLoc(), v, layout);
- val->set(layourtedV);
- layouts[layourtedV] = layout;
- return;
- }
+ // Non-cheap op -- insert to_layout conversion.
+ Value v = val->get();
+ Value converted = ToLayoutOp::create(b, v.getLoc(), v, layout);
+ val->set(converted);
+ resolved[converted] = layout;
}
-LogicalResult propagateVectorLayoutInfo(
+//===----------------------------------------------------------------------===//
+// Entry Point
+//===----------------------------------------------------------------------===//
+
+void propagateVectorLayoutInfo(
Operation *root, llvm::MapVector<Value, VectorLayoutInterface> &layouts) {
- LayoutInfo info;
- // Initialize propagation info with to_layout operations;
- root->walk([&](ToLayoutOp toLayout) {
- LDBG() << "Initializing layout from to_layout op: " << toLayout << "\n";
- info.setLayoutIfUnset(toLayout.getResult(), toLayout.getLayout());
- });
- // Propagate all layout information until fixpoint. Give priority to
- // forward propagation and only do backward propagation when there is no
- // forward propagation work left.
- while (!info.forward.empty() || !info.backward.empty()) {
- SmallVector<Value> changed;
- if (!info.forward.empty()) {
- Value val = info.forward.front();
- info.forward.pop();
- info.propagateLayoutForward(val);
- } else {
- Value val = info.backward.front();
- info.backward.pop();
- info.propagateLayoutBackward(val);
- }
+ LayoutAnalysis analysis;
+
+ // Phase 1: Seed anchors and forward propagation (no IR mutation).
+ analysis.seed(root);
+ analysis.runForward();
+
+ // Resolve: pick first candidate for each value.
+ analysis.resolve();
+
+ // Phase 2: Backward fixup (mutates IR).
+ for (Region ®ion : root->getRegions()) {
+ analysis.fixupRegion(region);
}
- layouts = std::move(info.layouts);
- return success();
+
+ layouts = std::move(analysis.resolved);
}
#define GEN_PASS_DEF_TESTVECTORLAYOUTANALYSISPASS
@@ -392,10 +527,7 @@
void runOnOperation() override {
Operation *root = getOperation();
llvm::MapVector<Value, VectorLayoutInterface> layouts;
- if (failed(propagateVectorLayoutInfo(root, layouts))) {
- root->emitError("Layout Analysis Failed");
- return signalPassFailure();
- }
+ propagateVectorLayoutInfo(root, layouts);
root->walk([&](Operation *op) {
if (isa<ToLayoutOp>(op)) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
index af3e4fc..15f8732 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-codegen-test-vector-layout-analysis --split-input-file %s --verify-diagnostics
+// RUN: iree-opt --pass-pipeline="builtin.module(any(iree-codegen-test-vector-layout-analysis))" --split-input-file %s --verify-diagnostics
#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
@@ -572,17 +572,20 @@
thread_strides = [0]
>
+// The to_layout anchors are placed after the arithmetic so that forward
+// propagation seeds from them and backward fixup assigns the operand layouts.
func.func @handle_multiuse_step(%lhs: vector<64xindex>, %rhs: vector<64xindex>) -> (vector<64xindex>, vector<64xindex>) {
- %l_lhs = iree_vector_ext.to_layout %lhs to layout(#layoutA) : vector<64xindex>
- %r_lhs = iree_vector_ext.to_layout %rhs to layout(#layoutB) : vector<64xindex>
+ // Two remarks: vector.step is cloned (one per use site with different layout).
%cst = vector.step : vector<64xindex>
- // expected-remark @above {{element_tile = [1]}}
// expected-remark @above {{element_tile = [64]}}
+ // expected-remark @above {{element_tile = [1]}}
%scaled_lhs = arith.muli %cst, %lhs : vector<64xindex>
// expected-remark @above {{element_tile = [64]}}
+ %l = iree_vector_ext.to_layout %scaled_lhs to layout(#layoutA) : vector<64xindex>
%scaled_rhs = arith.muli %cst, %rhs : vector<64xindex>
// expected-remark @above {{element_tile = [1]}}
- func.return %scaled_lhs, %scaled_rhs : vector<64xindex>, vector<64xindex>
+ %r = iree_vector_ext.to_layout %scaled_rhs to layout(#layoutB) : vector<64xindex>
+ func.return %l, %r : vector<64xindex>, vector<64xindex>
}
// -----
@@ -809,3 +812,122 @@
// expected-remark @above {{subgroup_tile = [2], batch_tile = [4], outer_tile = [1], thread_tile = [4], element_tile = [4], subgroup_strides = [1], thread_strides = [1]}}
func.return %reshape : vector<128xf16>
}
+
+// -----
+
+// Test multi-candidate accumulation with first-wins resolution.
+// %c receives candidates from both %a (layoutA) and %b (layoutB).
+// First-wins means %c should get layoutA (first candidate in walk order).
+
+#layoutA = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [1, 1],
+ element_tile = [16, 16],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [0, 0]
+>
+
+#layoutB = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [2, 2],
+ outer_tile = [1, 1],
+ thread_tile = [1, 1],
+ element_tile = [8, 8],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [0, 0]
+>
+
+func.func @multi_anchor_first_wins(%input: vector<16x16xf16>) -> vector<16x16xf16> {
+ %a = iree_vector_ext.to_layout %input to layout(#layoutA) : vector<16x16xf16>
+ %b = iree_vector_ext.to_layout %input to layout(#layoutB) : vector<16x16xf16>
+ // %c gets layoutA from %a (first anchor in walk order) via forward propagation.
+ %c = arith.addf %a, %b : vector<16x16xf16>
+ // expected-remark @above {{element_tile = [16, 16]}}
+ func.return %c : vector<16x16xf16>
+}
+
+// -----
+
+// Test backward fixup for scf.for: to_layout anchor after the loop causes
+// result layout to flow to init_args and yield via Phase 2 backward fixup.
+
+#layout = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [1, 1],
+ element_tile = [16, 16],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [0, 0]
+>
+
+func.func @scffor_backward_fixup(%arr: memref<16x16xf16>, %init: vector<16x16xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %cst_0 = arith.constant 0.0 : f16
+ %out = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg1 = %init) -> (vector<16x16xf16>) {
+ // expected-remark @above {{element_tile = [16, 16]}}
+ %val = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ // expected-remark @above {{element_tile = [16, 16]}}
+ %sum = arith.addf %arg1, %val : vector<16x16xf16>
+ // expected-remark @above {{element_tile = [16, 16]}}
+ scf.yield %sum : vector<16x16xf16>
+ }
+ %outl = iree_vector_ext.to_layout %out to layout(#layout) : vector<16x16xf16>
+ func.return %outl : vector<16x16xf16>
+}
+
+// -----
+
+// Test that a shared create_mask gets cloned per use site when two
+// transfer_write ops have different data layouts (and thus need different
+// mask layouts). The backward fixup processes transfer_write ops in reverse
+// program order, deriving mask layouts from the vector operand's resolved
+// layout. The second write assigns the mask layout; the first write finds
+// a conflict and clones the create_mask.
+
+#layoutA = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [1, 1],
+ element_tile = [16, 16],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [0, 0]
+>
+
+#layoutB = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [2, 2],
+ outer_tile = [1, 1],
+ thread_tile = [1, 1],
+ element_tile = [8, 8],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [0, 0]
+>
+
+func.func @clone_shared_mask_on_layout_conflict(
+ %arr: memref<16x16xf16>,
+ %a: vector<16x16xf16>,
+ %b: vector<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ // The mask gets cloned per transfer_write use site, each with the
+ // correct layout derived from the written value's layout.
+ %mask = vector.create_mask %c12, %c12 : vector<16x16xi1>
+ // expected-remark @above {{element_tile = [8, 8]}}
+ // expected-remark @above {{element_tile = [16, 16]}}
+ %al = iree_vector_ext.to_layout %a to layout(#layoutA) : vector<16x16xf16>
+ %bl = iree_vector_ext.to_layout %b to layout(#layoutB) : vector<16x16xf16>
+ vector.transfer_write %al, %arr[%c0, %c0], %mask {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ vector.transfer_write %bl, %arr[%c0, %c0], %mask {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ func.return
+}