[LinalgExt] Scatter fusion by expansion 3/3 (#19588)
Implements reshape fusion by expansion for `LinalgExt::ScatterOp`. Both directions are handled: a producer `tensor.collapse_shape` is folded into the scatter, and a consumer `tensor.expand_shape` is folded by expanding the scatter and collapsing its result back when needed.
See the main issue: https://github.com/iree-org/iree/issues/19091
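
For example (a sketch abridged from the new lit tests; the scatter region and the dynamic-size computation for `%d0` are elided), a producer `tensor.collapse_shape` on the updates operand is folded away by expanding the indices operand instead:

```mlir
// Before: the updates are collapsed only to satisfy the scatter's type.
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4], [5]]
    : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
       ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>)
       outs(%arg2 : tensor<?x2x16x4x128xf16>) {...} -> tensor<?x2x16x4x128xf16>

// After: the scatter operates on the expanded updates directly, with the
// indices expanded to match the now multi-dim batch space.
%expanded = tensor.expand_shape %arg1 [[0, 1], [2]] output_shape [4, %d0, 1]
    : tensor<?x1xi32> into tensor<4x?x1xi32>
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
       ins(%arg0, %expanded : tensor<4x?x2x16x4x128xf16>, tensor<4x?x1xi32>)
       outs(%arg2 : tensor<?x2x16x4x128xf16>) {...} -> tensor<?x2x16x4x128xf16>
```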
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 901288b..6ad28e5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -14,14 +14,50 @@
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
+static bool
+isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
+ return llvm::all_of(indices, [](const ReassociationIndices &index) {
+ return index.size() == 1;
+ });
+}
+
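+/// Compute reassociation indices from a shape map: each original dim's group
+/// covers the consecutive expanded dims produced for that dim.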
+static SmallVector<ReassociationIndices>
+computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+ SmallVector<ReassociationIndices> reassoc;
+ int64_t dimCount = 0;
+ for (auto &shape : shapeMap) {
+ reassoc.emplace_back(
+ llvm::to_vector(llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
+ dimCount += shape.size();
+ }
+ return reassoc;
+}
+
namespace {
+/// Helper struct that supports fusing reshapes with operands when not all of
+/// the operand's shape dims map to the iteration space.
+struct ReshapeOperandInfo {
+ static constexpr int64_t kNoMapping = -1;
+
+ // Original shape of this operand.
+ ArrayRef<int64_t> originalShape;
+
+ // Similar to the results of the operand's `AffineMap`, except an entry is
+ // `kNoMapping` when that dim does not map to the iteration space (e.g. the
+ // indexed dimensions of a `LinalgExt::ScatterOp`).
+ SmallVector<int64_t> operandToIterationSpace;
+};
+
/// Information needed to expand an operation to fold the reshape with
/// it.
class ExpansionInfo {
@@ -30,32 +66,78 @@
- // of the expanded op given the `indexingMap` of the fused operand/result of
- // the op, the `reassocationMaps` of the reshape op and the shape of
- // the expanded op.
+ // of the expanded op given the reshape info of each operand, the loop
+ // ranges of the op, the reassociation of the fused operand, and the shape
+ // of the expanded op.
- template <typename OpTy>
- LogicalResult compute(OpTy op, OpOperand *fusableOpOperand,
- ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
- PatternRewriter &rewriter);
- unsigned getOrigOpNumDims() const { return reassociation.size(); }
- unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
- ReassociationIndicesRef getExpandedDims(unsigned i) const {
- return reassociation[i];
+ LogicalResult compute(SmallVector<ReshapeOperandInfo> infos,
+ SmallVector<int64_t> loopRanges,
+ OpOperand *fusableOpOperand,
+ ArrayRef<ReassociationIndices> operandReassoc,
+ ArrayRef<int64_t> expandedShape);
+
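+ /// Return the expanded operand value, creating a `tensor.expand_shape` when
+ /// the operand's reassociation is not the identity. Returns std::nullopt if
+ /// the expanded shape is incompatible with the operand's shape.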
+ std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
+ RewriterBase &rewriter) {
+ auto shapeMap = this->getShapeMap(operand);
+ auto reassoc = computeReassocFromShapeMap(shapeMap);
+ if (isIdentityReassoc(reassoc)) {
+ return operand->get();
+ }
+ SmallVector<int64_t> flattenedArray;
+ for (auto &shape : shapeMap) {
+ flattenedArray.append(shape.begin(), shape.end());
+ }
+ auto oldType = cast<ShapedType>(operand->get().getType());
+ auto newType =
+ RankedTensorType::get(flattenedArray, oldType.getElementType());
+ if (failed(reshapeLikeShapesAreCompatible(
+ [&](const Twine &msg) {
+ return rewriter.notifyMatchFailure(loc, msg);
+ },
+ oldType.getShape(), newType.getShape(), reassoc,
+ /*isExpandingReshape=*/true))) {
+ return {};
+ }
+ return rewriter.create<tensor::ExpandShapeOp>(loc, newType, operand->get(),
+ reassoc);
+ }
+
+ /// Get the shape map for the operand: the expanded shape of each of its
+ /// original dims (a singleton for dims outside the iteration space).
+ SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
+ auto info = reshapeInfos[operand->getOperandNumber()];
+ SmallVector<SmallVector<int64_t>> shapeMap;
+ for (auto [operandIdx, loopIdx] :
+ llvm::enumerate(info.operandToIterationSpace)) {
+ if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+ shapeMap.push_back(
+ SmallVector<int64_t>{info.originalShape[operandIdx]});
+ } else {
+ shapeMap.push_back(loopShapeMap[loopIdx]);
+ }
+ }
+ return shapeMap;
}
- ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
- return expandedShapeMap[i];
+
+ SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
+ auto shapeMap = this->getShapeMap(operand);
+ return computeReassocFromShapeMap(shapeMap);
}
- ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+
+ unsigned getOrigNumLoops() const { return loopReassoc.size(); }
+ unsigned getExpandedNumLoops() const { return expandedOpNumDims; }
+ ReassociationIndicesRef getExpandedLoops(unsigned i) const {
+ return loopReassoc[i];
+ }
+ ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
+ return loopShapeMap[i];
+ }
private:
- /// Reassociation from the dimensions in the original operation to the
- /// dimension of the expanded operation.
- SmallVector<ReassociationIndices> reassociation;
+ /// Extent of the iteration space in the original operation.
+ SmallVector<int64_t> loopRanges;
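+ /// Reassociation from the original loops to the loops of the expanded op.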
+ SmallVector<ReassociationIndices> loopReassoc;
- /// Mapping from extent of loops in the original operation, to the extent of
- /// loops in the expanded operation.
+ /// Mapping from each loop in the original operation to the extents of the
+ /// corresponding loops in the expanded operation.
- SmallVector<SmallVector<int64_t>> expandedShapeMap;
- /// Extent of the loop in the original operation.
- SmallVector<int64_t> originalLoopExtent;
+ SmallVector<SmallVector<int64_t>> loopShapeMap;
unsigned expandedOpNumDims;
+ /// Info about the reassociation and original shape for each operand.
+ SmallVector<ReshapeOperandInfo> reshapeInfos;
};
class CollapsingInfo {
@@ -109,50 +191,46 @@
} // namespace
-template <typename OpTy>
-LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand,
- ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
- PatternRewriter &rewriter) {
- if (reassociationMaps.empty())
+LogicalResult ExpansionInfo::compute(
+ SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
+ OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
+ ArrayRef<int64_t> expandedShape) {
+ if (operandReassoc.empty())
return failure();
- AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand);
- FailureOr<SmallVector<int64_t>> originalLoopRange = op.getStaticLoopRanges();
- if (failed(originalLoopRange)) {
+
+ int64_t operandNum = fusableOpOperand->getOperandNumber();
+ ReshapeOperandInfo &fusionOperandInfo = infos[operandNum];
+ this->loopShapeMap.clear();
+ this->loopShapeMap.resize(loopRanges.size());
+ for (auto [operandIdx, loopIdx] :
+ llvm::enumerate(fusionOperandInfo.operandToIterationSpace)) {
+ if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+ continue;
+ }
+
+ // Compute the shape map entry for loop `loopIdx`.
+ ReassociationIndicesRef indices = operandReassoc[operandIdx];
+ for (int64_t shapeIdx : indices) {
+ this->loopShapeMap[loopIdx].push_back(expandedShape[shapeIdx]);
+ }
+ }
+
+ // Fill in the remaining loops with their original ranges and count the
+ // total number of dims in the expanded op.
+ this->expandedOpNumDims = 0;
+ for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
+ if (shapeMap.empty()) {
+ this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
+ }
+ this->expandedOpNumDims += shapeMap.size();
+ }
+
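+ // If every loop still maps to exactly one dim, the reshape does not expand
+ // the iteration space and there is nothing to fuse.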
+ if (llvm::all_of(this->loopShapeMap,
+ [](const auto &shapes) { return shapes.size() == 1; })) {
return failure();
}
- originalLoopExtent.assign(originalLoopRange->begin(),
- originalLoopRange->end());
-
- reassociation.clear();
- expandedShapeMap.clear();
- // Compute the number of dimension in the expanded op that correspond to each
- // dimension of the original op.
- SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
- expandedShapeMap.resize(fusedIndexMap.getNumDims());
- for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
- unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
- AffineMap foldedDims = reassociationMaps[resultExpr.index()];
- numExpandedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape =
- expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
- expandedShapeMap[pos].assign(shape.begin(), shape.end());
- }
- // The remaining dimensions remain the same.
- for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
- if (expandedShapeMap[i].empty())
- expandedShapeMap[i] = {originalLoopExtent[i]};
-
- // Compute reassociation map from the original op to the expanded op.
- unsigned sum = 0;
- reassociation.reserve(fusedIndexMap.getNumDims());
- for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
- auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
- reassociation.emplace_back(seq.begin(), seq.end());
- sum += numFoldedDim.value();
- }
- expandedOpNumDims = sum;
+ this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
+ this->reshapeInfos = std::move(infos);
+ this->loopRanges = std::move(loopRanges);
return success();
}
@@ -201,6 +279,77 @@
return success();
}
+static SmallVector<ReshapeOperandInfo>
+getAttentionReshapeInfo(LinalgExt::AttentionOp attentionOp) {
+ return llvm::map_to_vector(
+ attentionOp->getOpOperands(), [&](OpOperand &opOperand) {
+ ReshapeOperandInfo operandInfo;
+ auto operandType = dyn_cast<ShapedType>(opOperand.get().getType());
+ if (!operandType) {
+ assert(
+ attentionOp.getMatchingIndexingMap(&opOperand).getNumResults() ==
+ 0 &&
+ "expected non-shaped type to have no results in indexing map");
+ return operandInfo;
+ }
+
+ operandInfo.originalShape = operandType.getShape();
+ for (auto result :
+ attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
+ operandInfo.operandToIterationSpace.push_back(
+ cast<AffineDimExpr>(result).getPosition());
+ }
+ return operandInfo;
+ });
+}
+
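+/// Build the `ReshapeOperandInfo` for each scatter operand. The iteration
+/// space is the batch dims followed by the dims of the contiguous slice; the
+/// indexed dims of `original` (and the index column of `indices`) map to no
+/// loop.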
+static SmallVector<ReshapeOperandInfo>
+getScatterReshapeInfo(LinalgExt::ScatterOp scatterOp) {
+ SmallVector<ReshapeOperandInfo> infos;
+ auto rankOfContiguousSlice =
+ scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
+ auto updateRank = scatterOp.getUpdateType().getRank();
+
+ // Operand #0 Updates
+ {
+ ReshapeOperandInfo updateInfo;
+ updateInfo.originalShape = scatterOp.getUpdateType().getShape();
+ llvm::append_range(updateInfo.operandToIterationSpace,
+ llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+ updateInfo.operandToIterationSpace.append(
+ updateRank - (rankOfContiguousSlice + scatterOp.getBatchRank()),
+ ReshapeOperandInfo::kNoMapping);
+ llvm::append_range(
+ updateInfo.operandToIterationSpace,
+ llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
+ infos.push_back(std::move(updateInfo));
+ }
+
+ // Operand #1 Indices
+ {
+ ReshapeOperandInfo indicesInfo;
+ indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
+ llvm::append_range(indicesInfo.operandToIterationSpace,
+ llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+ indicesInfo.operandToIterationSpace.push_back(
+ ReshapeOperandInfo::kNoMapping);
+ infos.push_back(std::move(indicesInfo));
+ }
+
+ // Operand #2 Original
+ {
+ ReshapeOperandInfo originalInfo;
+ originalInfo.originalShape = scatterOp.getOriginalType().getShape();
+ originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
+ ReshapeOperandInfo::kNoMapping);
+ llvm::append_range(
+ originalInfo.operandToIterationSpace,
+ llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
+ infos.push_back(std::move(originalInfo));
+ }
+ return infos;
+}
+
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
@@ -208,46 +357,17 @@
for (AffineExpr expr : indexingMap.getResults()) {
unsigned pos = cast<AffineDimExpr>(expr).getPosition();
auto expandedExprs = llvm::to_vector_of<AffineExpr, 6>(
- llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
+ llvm::map_range(expansionInfo.getExpandedLoops(pos), [&](int64_t v) {
return builder.getAffineDimExpr(static_cast<unsigned>(v));
}));
newExprs.append(expandedExprs.begin(), expandedExprs.end());
}
- return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
+ return AffineMap::get(expansionInfo.getExpandedNumLoops(),
indexingMap.getNumSymbols(), newExprs,
builder.getContext());
}
-static RankedTensorType getExpandedType(RankedTensorType originalType,
- AffineMap indexingMap,
- const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t> expandedShape;
- for (AffineExpr expr : indexingMap.getResults()) {
- unsigned dim = cast<AffineDimExpr>(expr).getPosition();
- auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
- expandedShape.append(dimExpansion.begin(), dimExpansion.end());
- }
- return RankedTensorType::get(expandedShape, originalType.getElementType());
-}
-
-static SmallVector<ReassociationIndices>
-getReassociationForExpansion(AffineMap indexingMap,
- const ExpansionInfo &expansionInfo) {
- SmallVector<ReassociationIndices> reassociation;
- unsigned numReshapeDims = 0;
- for (AffineExpr expr : indexingMap.getResults()) {
- unsigned dim = cast<AffineDimExpr>(expr).getPosition();
- auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
- SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
- llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
- reassociation.emplace_back(std::move(indices));
- numReshapeDims += numExpandedDims;
- }
- return reassociation;
-}
-
-template <typename OpTy>
-static bool isFusableWithReshapeByDimExpansion(OpTy op,
+static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
OpOperand *fusableOpOperand) {
// Is fusable only if:
// - All the indexing maps for operands and results are projected
@@ -256,10 +376,11 @@
// - All the loops for the reshaped operand are parallel loops.
SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand);
- return op.hasPureTensorSemantics() &&
- llvm::all_of(
- op.getIndexingMapsArray(),
- [](AffineMap map) { return map.isProjectedPermutation(); }) &&
+ return operandMap && op.hasPureTensorSemantics() &&
+ llvm::all_of(op.getIndexingMapsArray(),
+ [](AffineMap map) {
+ return map && map.isProjectedPermutation();
+ }) &&
operandMap.getNumResults() > 0;
}
@@ -277,16 +398,13 @@
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
- RankedTensorType collapsedType = isExpanding
- ? expandingReshapeOp.getSrcType()
- : collapsingReshapeOp.getResultType();
-
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
- attentionOp, fusableOpOperand,
- isExpanding ? expandingReshapeOp.getReassociationMaps()
- : collapsingReshapeOp.getReassociationMaps(),
- expandedType.getShape(), collapsedType.getShape(), rewriter)))
+ getAttentionReshapeInfo(attentionOp),
+ attentionOp.getStaticLoopRanges().value(), fusableOpOperand,
+ isExpanding ? expandingReshapeOp.getReassociationIndices()
+ : collapsingReshapeOp.getReassociationIndices(),
+ expandedType.getShape())))
return std::nullopt;
auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap, 6>(
llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) {
@@ -305,54 +423,23 @@
: collapsingReshapeOp.getSrc());
continue;
}
- if (auto opOperandType =
- dyn_cast<RankedTensorType>(opOperand->get().getType())) {
- AffineMap indexingMap = attentionOp.getMatchingIndexingMap(opOperand);
- RankedTensorType expandedOperandType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
- if (expandedOperandType != opOperand->get().getType()) {
- // Reshape the operand to get the right type.
- SmallVector<ReassociationIndices> reassociation =
- getReassociationForExpansion(indexingMap, expansionInfo);
- if (failed(reshapeLikeShapesAreCompatible(
- [&](const Twine &msg) {
- return rewriter.notifyMatchFailure(attentionOp, msg);
- },
- opOperandType.getShape(), expandedOperandType.getShape(),
- reassociation,
- /*isExpandingReshape=*/true)))
- return std::nullopt;
- expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOperandType, opOperand->get(), reassociation));
- continue;
- }
- }
- expandedOpOperands.push_back(opOperand->get());
+ // Reshape the operand to get the right type.
+ std::optional<Value> expanded =
+ expansionInfo.getOrCreateExpanded(loc, opOperand, rewriter);
+ if (!expanded)
+ return std::nullopt;
+ expandedOpOperands.push_back(*expanded);
+ continue;
}
Value output;
OpOperand &outOperand = attentionOp.getOutputMutable();
- AffineMap indexingMap = attentionOp.getMatchingIndexingMap(&outOperand);
- auto opOperandType = cast<RankedTensorType>(outOperand.get().getType());
- RankedTensorType expandedOutputType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
- if (expandedOutputType != outOperand.get().getType()) {
- SmallVector<ReassociationIndices> reassociation =
- getReassociationForExpansion(indexingMap, expansionInfo);
- if (failed(reshapeLikeShapesAreCompatible(
- [&](const Twine &msg) {
- return rewriter.notifyMatchFailure(attentionOp, msg);
- },
- opOperandType.getShape(), expandedOutputType.getShape(),
- reassociation,
- /*isExpandingReshape=*/true)))
- return std::nullopt;
- output = rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOutputType, outOperand.get(), reassociation);
- } else {
- output = outOperand.get();
- }
+ std::optional<Value> maybeOutput =
+ expansionInfo.getOrCreateExpanded(loc, &outOperand, rewriter);
+ if (!maybeOutput)
+ return std::nullopt;
+ output = *maybeOutput;
Value maskOperand;
if (expandedOpOperands.size() > 4) {
@@ -377,9 +464,7 @@
int64_t resultNumber = opResult.getResultNumber();
if (resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
- getReassociationForExpansion(
- attentionOp.getIndexingMapsForResults()[resultNumber],
- expansionInfo);
+ expansionInfo.getReassoc(attentionOp.getTiedOpOperand(opResult));
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
attentionOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
@@ -391,6 +476,64 @@
return resultVals;
}
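+
+/// Expand `scatterOp` so that the reshape on `fusableOpOperand` (a producer
+/// `tensor.collapse_shape` or a consumer `tensor.expand_shape`) folds away,
+/// expanding the remaining operands to match and collapsing the result back
+/// to the original shape when needed.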
+static std::optional<Value>
+fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
+ OpOperand *fusableOpOperand,
+ PatternRewriter &rewriter) {
+ Location loc = scatterOp.getLoc();
+ // Check if reshape is expanding or collapsing.
+ auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
+ bool isExpanding = (expandingReshapeOp != nullptr);
+ RankedTensorType expandedType = isExpanding
+ ? expandingReshapeOp.getResultType()
+ : collapsingReshapeOp.getSrcType();
+ ExpansionInfo info;
+ if (failed(info.compute(
+ getScatterReshapeInfo(scatterOp),
+ scatterOp.getStaticLoopRanges().value(), fusableOpOperand,
+ isExpanding ? expandingReshapeOp.getReassociationIndices()
+ : collapsingReshapeOp.getReassociationIndices(),
+ expandedType.getShape()))) {
+ return std::nullopt;
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(scatterOp);
+
+ OpOperand *update = scatterOp.getDpsInputOperand(0);
+ OpOperand *indices = scatterOp.getDpsInputOperand(1);
+ OpOperand *original = scatterOp.getDpsInitOperand(0);
+ std::optional<Value> newUpdates =
+ info.getOrCreateExpanded(loc, update, rewriter);
+ std::optional<Value> newIndices =
+ info.getOrCreateExpanded(loc, indices, rewriter);
+ std::optional<Value> newOriginal =
+ info.getOrCreateExpanded(loc, original, rewriter);
+ // `getOrCreateExpanded` fails when the reshape shapes are incompatible;
+ // propagate the failure instead of aborting on an empty optional.
+ if (!newUpdates || !newIndices || !newOriginal)
+ return std::nullopt;
+
+ auto newScatter = rewriter.create<ScatterOp>(
+ loc, newOriginal->getType(), ValueRange{*newUpdates, *newIndices},
+ ValueRange{*newOriginal}, scatterOp.getDimensionMap(),
+ scatterOp.getUniqueIndices());
+ rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
+ newScatter.getRegion().begin());
+
+ SmallVector<ReassociationIndices> originalReassoc =
+ info.getReassoc(original);
+
+ // Return the result directly if it was not expanded; otherwise collapse it
+ // back to the original shape.
+ if (isIdentityReassoc(originalReassoc)) {
+ return {newScatter.getResult(0)};
+ }
+ auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
+ loc, scatterOp.getOriginalType(), newScatter.getResult(0),
+ originalReassoc);
+
+ return {newCollapse};
+}
+
+//===----------------------------------------------------------------------===//
+// Fuse By Expansion Patterns
+//===----------------------------------------------------------------------===//
+
namespace {
// Fold attention with its consumer expand_shape op.
@@ -553,6 +696,74 @@
linalg::ControlDropUnitDims options;
};
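+// Fold scatter with a producer collapse_shape op by expanding the scatter.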
+struct FoldScatterWithProducerReshapeByExpansion final
+ : public OpRewritePattern<ScatterOp> {
+ FoldScatterWithProducerReshapeByExpansion(
+ MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<ScatterOp>(context, benefit),
+ controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+ LogicalResult matchAndRewrite(ScatterOp scatterOp,
+ PatternRewriter &rewriter) const override {
+ for (OpOperand &opOperand : scatterOp->getOpOperands()) {
+ tensor::CollapseShapeOp reshapeOp =
+ opOperand.get().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!reshapeOp)
+ continue;
+ if (!controlFoldingReshapes(&opOperand))
+ continue;
+
+ std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+ scatterOp, reshapeOp, &opOperand, rewriter);
+ if (!replacementValue)
+ return failure();
+ rewriter.replaceOp(scatterOp, *replacementValue);
+ return success();
+ }
+ return failure();
+ }
+
+ linalg::ControlFusionFn controlFoldingReshapes;
+};
+
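+// Fold scatter with a consumer expand_shape op by expanding the scatter.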
+struct FoldScatterWithConsumerReshapeByExpansion final
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+ FoldScatterWithConsumerReshapeByExpansion(
+ MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ auto producerResult = dyn_cast<OpResult>(expandOp.getSrc());
+ if (!producerResult) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "source not produced by an operation");
+ }
+
+ auto scatterOp = producerResult.getDefiningOp<LinalgExt::ScatterOp>();
+ if (!scatterOp) {
+ return failure();
+ }
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return failure();
+ }
+
+ std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+ scatterOp, expandOp, scatterOp.getTiedOpOperand(producerResult),
+ rewriter);
+ if (!replacementValue)
+ return failure();
+ rewriter.replaceOp(scatterOp, *replacementValue);
+ return success();
+ }
+
+ linalg::ControlFusionFn controlFoldingReshapes;
+};
+
} // namespace
/// Return the `reassociation` indices to use to collapse the operand when the
@@ -773,6 +984,10 @@
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldAttentionWithProducerReshapeByExpansion>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldScatterWithProducerReshapeByExpansion>(
+ patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldScatterWithConsumerReshapeByExpansion>(
+ patterns.getContext(), controlFoldingReshapes);
}
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 25e1553..b3fdd90 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -481,3 +481,204 @@
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] :
// CHECK: util.return %[[ATTENTION]]
+
+
+// -----
+
+util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0[[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+ ^bb0(%arg7: f16, %arg8: f16):
+ iree_linalg_ext.yield %arg7 : f16
+ } -> tensor<?x2x16x4x128xf16>
+ util.return %1 : tensor<?x2x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_updates
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME: tensor<?x1xi32> into tensor<4x?x1xi32>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0[[0, 1], [2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x4x16x4x128xf16>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
+ ^bb0(%arg7: f16, %arg8: f16):
+ iree_linalg_ext.yield %arg7 : f16
+ } -> tensor<10x16x4x128xf16>
+ util.return %1 : tensor<10x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_updates_partial
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x1xi32>
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.collapse_shape %[[ARG0]] {{.*}} tensor<4x?x2x2x16x4x128xf16> into tensor<4x?x4x16x4x128xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_original(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+ %collapsed = tensor.collapse_shape %arg2 [[0], [1, 2], [3, 4], [5, 6]] : tensor<?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+ ^bb0(%arg6: f16, %arg7: f16):
+ iree_linalg_ext.yield %arg6 : f16
+ } -> tensor<?x32x8x128xf16>
+ util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_original
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]]
+// CHECK: util.return %[[COLLAPSE]]
+
+// -----
+
+util.func @scatter_original_noop(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x80x2x32x8x128xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+ %collapsed = tensor.collapse_shape %arg2 [[0, 1, 2], [3], [4], [5]] : tensor<?x80x2x32x8x128xf16> into tensor<?x32x8x128xf16>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+ ^bb0(%arg6: f16, %arg7: f16):
+ iree_linalg_ext.yield %arg6 : f16
+ } -> tensor<?x32x8x128xf16>
+ util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_original_noop
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG2]]
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[COLLAPSE]]
+// CHECK: util.return %[[SCATTER]]
+
+
+// -----
+
+util.func @scatter_collapse_original_partial(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<5x?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+ %collapsed = tensor.collapse_shape %arg2 [[0, 1], [2, 3], [4, 5], [6, 7]] : tensor<5x?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+ ^bb0(%arg6: f16, %arg7: f16):
+ iree_linalg_ext.yield %arg6 : f16
+ } -> tensor<?x32x8x128xf16>
+ util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_original_partial
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16>
+// TODO(IanWood1): fix this so the collapse folds with the expand
+// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape {{.*}} tensor<?x32x8x128xf16> into tensor<?x2x16x4x2x64x2xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]]
+// CHECK: util.return %[[COLLAPSE]]
+
+// -----
+
+util.func @scatter_collapse_indices(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
+ %collapsed = tensor.collapse_shape %arg1[[0, 1], [2]] : tensor<4x?x1xi32> into tensor<?x1xi32>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+ ^bb0(%arg7: f16, %arg8: f16):
+ iree_linalg_ext.yield %arg7 : f16
+ } -> tensor<?x2x16x4x128xf16>
+ util.return %1 : tensor<?x2x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_indices
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_indices_partial(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
+ %collapsed = tensor.collapse_shape %arg1[[0, 1], [2, 3]] : tensor<4x?x1x1xi32> into tensor<?x1xi32>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+ ^bb0(%arg7: f16, %arg8: f16):
+ iree_linalg_ext.yield %arg7 : f16
+ } -> tensor<?x2x16x4x128xf16>
+ util.return %1 : tensor<?x2x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_indices_partial
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16>
+// CHECK-DAG: %[[INDICES:.+]] = tensor.collapse_shape {{.*}} tensor<4x?x1x1xi32> into tensor<4x?x1xi32>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func public @scatter_collapse(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x4x32xf16> {
+ %c0 = arith.constant 0 : index
+ %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+ ^bb0(%arg3: f16, %arg4: f16):
+ iree_linalg_ext.yield %arg3 : f16
+ } -> tensor<?x2x16x4x128xf16>
+ %dim = tensor.dim %arg0, %c0 : tensor<?x2x16x4x128xf16>
+ %expanded = tensor.expand_shape %0 [[0], [1], [2], [3], [4, 5]] output_shape [%dim, 2, 16, 4, 4, 32] : tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+ util.return %expanded : tensor<?x2x16x4x4x32xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape %[[ARG2]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: util.return %[[SCATTER]]
+
+// -----
+
+util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x1xi32>, %arg2: tensor<128xf16>) -> tensor<4x4x4x2xf16> {
+ %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<10xf16>, tensor<10x1xi32>) outs(%arg2 : tensor<128xf16>) {
+ ^bb0(%arg3: f16, %arg4: f16):
+ iree_linalg_ext.yield %arg3 : f16
+ } -> tensor<128xf16>
+ %expanded = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape[4, 4, 4, 2] : tensor<128xf16> into tensor<4x4x4x2xf16>
+ util.return %expanded : tensor<4x4x4x2xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse_noop
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SCATTER]]
+// CHECK: util.return %[[EXPANDED]]