blob: b49f1e2698d457df89ab4ff9873c71abb1df6084 [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//
// The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and
// modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`.
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include <cstdint>
#include <optional>
namespace mlir::iree_compiler::IREE::LinalgExt {
static bool
isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
for (auto &index : indices) {
if (index.size() != 1)
return false;
}
return true;
};
static SmallVector<ReassociationIndices>
computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
SmallVector<ReassociationIndices> reassoc;
int64_t dimCount = 0;
for (auto &shape : shapeMap) {
reassoc.emplace_back(
llvm::to_vector(llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
dimCount += shape.size();
}
return reassoc;
}
namespace {
/// Helper class that supports fusing reshapes with operands when not all of the
/// shape dims map to the iteration space.
struct ReshapeOperandInfo {
static constexpr int64_t kNoMapping = -1;
// Original shape of this operand.
ArrayRef<int64_t> originalShape;
// Similar to the results of the operand's `AffineMap` except `kNoMapping` if
// that dim doesn't map to the iteration space. For example, the indexed
// dimensions in a LinalgExt::ScatterOp.
SmallVector<int64_t> operandToIterationSpace;
};
/// Information needed to expand an operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
// Computes the mapping from original dimensions of the op to the dimensions
// of the expanded op given the `indexingMap` of the fused operand/result of
// the op, the `reassocationMaps` of the reshape op and the shape of
// the expanded op.
LogicalResult compute(SmallVector<ReshapeOperandInfo> infos,
SmallVector<int64_t> loopRanges,
OpOperand *fusableOpOperand,
ArrayRef<ReassociationIndices> operandReassoc,
ArrayRef<int64_t> expandedShape);
std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
RewriterBase &rewriter) {
auto shapeMap = this->getShapeMap(operand);
auto reassoc = computeReassocFromShapeMap(shapeMap);
if (isIdentityReassoc(reassoc)) {
return operand->get();
}
SmallVector<int64_t> flattenedArray;
for (auto &shape : shapeMap) {
flattenedArray.append(shape.begin(), shape.end());
}
auto oldType = cast<ShapedType>(operand->get().getType());
auto newType =
RankedTensorType::get(flattenedArray, oldType.getElementType());
if (failed(reshapeLikeShapesAreCompatible(
[&](const Twine &msg) {
return rewriter.notifyMatchFailure(loc, msg);
},
oldType.getShape(), newType.getShape(), reassoc,
/*isExpandingReshape=*/true))) {
return {};
}
return rewriter.create<tensor::ExpandShapeOp>(loc, newType, operand->get(),
reassoc);
};
/// Get the shape map for the operand.
SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
auto info = reshapeInfos[operand->getOperandNumber()];
SmallVector<SmallVector<int64_t>> shapeMap;
for (auto [operandIdx, loopIdx] :
llvm::enumerate(info.operandToIterationSpace)) {
if (loopIdx == ReshapeOperandInfo::kNoMapping) {
shapeMap.push_back(
SmallVector<int64_t>{info.originalShape[operandIdx]});
} else {
shapeMap.push_back(loopShapeMap[loopIdx]);
}
}
return shapeMap;
}
SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
auto shapeMap = this->getShapeMap(operand);
return computeReassocFromShapeMap(shapeMap);
}
unsigned getOrigNumLoops() const { return loopReassoc.size(); }
unsigned getExpandedNumLoops() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedLoops(unsigned i) const {
return loopReassoc[i];
}
ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
return loopShapeMap[i];
}
private:
/// Extent of the iteration space in the original operation.
SmallVector<int64_t> loopRanges;
SmallVector<ReassociationIndices> loopReassoc;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
SmallVector<SmallVector<int64_t>> loopShapeMap;
unsigned expandedOpNumDims;
/// Info about the reassociation and original shape for each operand.
SmallVector<ReshapeOperandInfo> reshapeInfos;
};
class CollapsingInfo {
public:
LogicalResult initialize(unsigned origNumLoops,
ArrayRef<ReassociationIndices> foldedIterationDims);
/// Return mapping from collapsed loop domain to original loop domain.
ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
return collapsedOpToOrigOpIterationDim;
}
/// Return mapping from original loop domain to collapsed loop domain. The
/// mapping is a pair. First value is the dimension in the collapsed loop that
/// the original loop is mapped to. Second is the relative position in folded
/// list of this domain. For example if the original loop domain is 3D, and
/// the collapsed loop domain is folding all of it, i.e.
///
/// ```
/// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]`
/// ```
///
/// then
///
/// ```
/// origOpToCollapsedOpMapping[0] = {0, 0};
/// origOpToCollapsedOpMapping[1] = {0, 1};
/// origOpToCollapsedOpMapping[2] = {0, 2};
/// origOpToCollapsedOpMapping[3] = {1, 0};
/// origOpToCollapsedOpMapping[4] = {1, 1};
/// ```
///
ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
return origOpToCollapsedOpIterationDim;
}
/// Return the collapsed op iteration domain rank.
unsigned getCollapsedOpIterationRank() const {
return collapsedOpToOrigOpIterationDim.size();
}
private:
/// Map from the iteration domain index in collapsed op to the iteration
/// domain indices in the original op.
SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
/// Map from iteration domain index in the original op to the iteration domain
/// index in the collapsed op.
SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace
LogicalResult ExpansionInfo::compute(
SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
ArrayRef<int64_t> expandedShape) {
if (operandReassoc.empty())
return failure();
int64_t operandNum = fusableOpOperand->getOperandNumber();
ReshapeOperandInfo &fusionOperandInfo = infos[operandNum];
this->loopShapeMap.clear();
this->loopShapeMap.resize(loopRanges.size());
for (auto [operandIdx, loopIdx] :
llvm::enumerate(fusionOperandInfo.operandToIterationSpace)) {
if (loopIdx == ReshapeOperandInfo::kNoMapping) {
continue;
}
// Compute the shape map at element `loopIdx`
ReassociationIndicesRef indices = operandReassoc[operandIdx];
for (auto [dimIdx, shapeIdx] : llvm::enumerate(indices)) {
this->loopShapeMap[loopIdx].push_back(expandedShape[shapeIdx]);
}
}
// Fill in the remaining elements with `loopRanges`
this->expandedOpNumDims = 0;
for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
if (shapeMap.empty()) {
this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
}
this->expandedOpNumDims += shapeMap.size();
}
if (llvm::all_of(this->loopShapeMap,
[&](auto vec) { return vec.size() == 1; })) {
return failure();
}
this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
this->reshapeInfos = std::move(infos);
this->loopRanges = std::move(loopRanges);
return success();
}
LogicalResult
CollapsingInfo::initialize(unsigned origNumLoops,
ArrayRef<ReassociationIndices> foldedIterationDims) {
llvm::SmallDenseSet<int64_t, 4> processedDims;
// Find all the dims that are folded.
for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
if (foldedIterationDim.empty())
continue;
// If the folded dims contain dims already folded, that's illegal
// specification. Repetition within a list is also illegal.
for (auto dim : foldedIterationDim) {
if (dim >= origNumLoops)
return failure();
if (processedDims.count(dim))
return failure();
processedDims.insert(dim);
}
collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
foldedIterationDim.end());
}
if (processedDims.size() > origNumLoops)
return failure();
// Add all the preserved dims of the original op as single
// elements to `collapsedOpToOrigOpIterationDim`.
for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
if (processedDims.count(dim))
continue;
collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
}
llvm::sort(collapsedOpToOrigOpIterationDim,
[&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
return lhs[0] < rhs[0];
});
origOpToCollapsedOpIterationDim.resize(origNumLoops);
for (const auto &foldedDims :
llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
for (const auto &dim : enumerate(foldedDims.value()))
origOpToCollapsedOpIterationDim[dim.value()] =
std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
}
return success();
}
static SmallVector<ReshapeOperandInfo>
getAttentionReshapeInfo(LinalgExt::AttentionOp attentionOp) {
return llvm::map_to_vector(
attentionOp->getOpOperands(), [&](OpOperand &opOperand) {
ReshapeOperandInfo operandInfo;
auto operandType = dyn_cast<ShapedType>(opOperand.get().getType());
if (!operandType) {
assert(
attentionOp.getMatchingIndexingMap(&opOperand).getNumResults() ==
0 &&
"expected non-shaped type to have no results in indexing map");
return operandInfo;
}
operandInfo.originalShape = operandType.getShape();
for (auto result :
attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
operandInfo.operandToIterationSpace.push_back(
cast<AffineDimExpr>(result).getPosition());
}
return operandInfo;
});
}
static SmallVector<ReshapeOperandInfo>
getScatterReshapeInfo(LinalgExt::ScatterOp scatterOp) {
SmallVector<ReshapeOperandInfo> infos;
auto rankOfContiguousSlice =
scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
auto updateRank = scatterOp.getUpdateType().getRank();
// Operand #0 Updates
{
ReshapeOperandInfo updateInfo;
updateInfo.originalShape = scatterOp.getUpdateType().getShape();
llvm::append_range(updateInfo.operandToIterationSpace,
llvm::seq<int64_t>(0, updateRank));
infos.push_back(std::move(updateInfo));
}
// Operand#1 Indices
{
ReshapeOperandInfo indicesInfo;
indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
llvm::append_range(indicesInfo.operandToIterationSpace,
llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank())
indicesInfo.operandToIterationSpace.push_back(
ReshapeOperandInfo::kNoMapping);
infos.push_back(std::move(indicesInfo));
}
// Operand #2 Original
{
ReshapeOperandInfo originalInfo;
originalInfo.originalShape = scatterOp.getOriginalType().getShape();
originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
ReshapeOperandInfo::kNoMapping);
llvm::append_range(
originalInfo.operandToIterationSpace,
llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
infos.push_back(std::move(originalInfo));
}
return infos;
};
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<AffineExpr> newExprs;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned pos = cast<AffineDimExpr>(expr).getPosition();
auto expandedExprs = llvm::to_vector_of<AffineExpr, 6>(
llvm::map_range(expansionInfo.getExpandedLoops(pos), [&](int64_t v) {
return builder.getAffineDimExpr(static_cast<unsigned>(v));
}));
newExprs.append(expandedExprs.begin(), expandedExprs.end());
}
return AffineMap::get(expansionInfo.getExpandedNumLoops(),
indexingMap.getNumSymbols(), newExprs,
builder.getContext());
}
static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
OpOperand *fusableOpOperand) {
// Is fusable only if:
// - All the indexing maps for operands and results are projected
// permutations.
// - The fused tensor is not a scalar.
// - All the loops for the reshaped operand are parallel loops.
SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand);
return operandMap && op.hasPureTensorSemantics() &&
llvm::all_of(op.getIndexingMapsArray(),
[](AffineMap map) {
return map && map.isProjectedPermutation();
}) &&
operandMap.getNumResults() > 0;
}
static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion(
AttentionOp attentionOp, Operation *reshapeOp, OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(attentionOp, fusableOpOperand) &&
"preconditions for fuse operation failed");
Location loc = attentionOp.getLoc();
// Check if reshape is expanding or collapsing.
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
bool isExpanding = (expandingReshapeOp != nullptr);
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
getAttentionReshapeInfo(attentionOp),
attentionOp.getStaticLoopRanges().value(), fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices(),
expandedType.getShape())))
return std::nullopt;
auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap, 6>(
llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) {
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
}));
// Set insertion point to the attention op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(attentionOp);
SmallVector<Value> expandedOpOperands;
expandedOpOperands.reserve(attentionOp.getNumDpsInputs());
for (OpOperand *opOperand : attentionOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
: collapsingReshapeOp.getSrc());
continue;
}
// Reshape the operand to get the right type.
std::optional<Value> expanded =
expansionInfo.getOrCreateExpanded(loc, opOperand, rewriter);
if (!expanded)
return std::nullopt;
expandedOpOperands.push_back(*expanded);
continue;
}
Value output;
OpOperand &outOperand = attentionOp.getOutputMutable();
std::optional<Value> maybeOutput =
expansionInfo.getOrCreateExpanded(loc, &outOperand, rewriter);
if (!maybeOutput)
return std::nullopt;
output = *maybeOutput;
Value maskOperand;
if (expandedOpOperands.size() > 4) {
maskOperand = expandedOpOperands[4];
}
// Create a new `AttentionOp` that has the computed operands/indexing maps.
TypeRange resultTypes = ValueRange(output).getTypes();
auto fusedOp = rewriter.create<AttentionOp>(
attentionOp.getLoc(), resultTypes, expandedOpOperands[0],
expandedOpOperands[1], expandedOpOperands[2], expandedOpOperands[3],
output, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps),
maskOperand);
rewriter.inlineRegionBefore(attentionOp.getRegion(), fusedOp.getRegion(),
fusedOp.getRegion().begin());
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
for (OpResult opResult : attentionOp->getOpResults()) {
int64_t resultNumber = opResult.getResultNumber();
if (resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
expansionInfo.getReassoc(attentionOp.getTiedOpOperand(opResult));
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
attentionOp.getLoc(), opResult.getType(),
fusedOp->getResult(resultNumber), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(resultNumber));
}
}
// Assuming a single result.
return resultVals;
}
static std::optional<Value>
fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
OpOperand *fusableOpOperand,
PatternRewriter &rewriter) {
Location loc = scatterOp.getLoc();
// Check if reshape is expanding or collapsing.
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
bool isExpanding = (expandingReshapeOp != nullptr);
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
ExpansionInfo info;
if (failed(info.compute(
getScatterReshapeInfo(scatterOp),
scatterOp.getStaticLoopRanges().value(), fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices(),
expandedType.getShape()))) {
return std::nullopt;
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(scatterOp);
OpOperand *update = scatterOp.getDpsInputOperand(0);
OpOperand *indices = scatterOp.getDpsInputOperand(1);
OpOperand *original = scatterOp.getDpsInitOperand(0);
auto newUpdates = info.getOrCreateExpanded(loc, update, rewriter).value();
auto newIndices = info.getOrCreateExpanded(loc, indices, rewriter).value();
auto newOriginal = info.getOrCreateExpanded(loc, original, rewriter).value();
auto newScatter = rewriter.create<ScatterOp>(
loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
ValueRange{newOriginal}, scatterOp.getDimensionMap(),
scatterOp.getUniqueIndices());
rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
newScatter.getRegion().begin());
auto originalShapeMap = info.getShapeMap(original);
SmallVector<ReassociationIndices> originalReassoc =
computeReassocFromShapeMap(originalShapeMap);
// Collapse back to original shape.
if (isIdentityReassoc(originalReassoc)) {
return {newScatter.getResult(0)};
}
auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
loc, scatterOp.getOriginalType(), newScatter.getResult(0),
originalReassoc);
return {newCollapse};
}
//===----------------------------------------------------------------------===//
// Fuse By Expansion Patterns
//===----------------------------------------------------------------------===//
namespace {
// Fold attention with its consumer expand_shape op.
struct FoldAttentionWithConsumerReshapeByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {
FoldAttentionWithConsumerReshapeByExpansion(
MLIRContext *context, linalg::ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
if (!producerResult) {
return rewriter.notifyMatchFailure(reshapeOp,
"source not produced by an operation");
}
auto producer = dyn_cast<AttentionOp>(producerResult.getOwner());
if (!producer) {
return rewriter.notifyMatchFailure(reshapeOp,
"producer is not an attention op");
}
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion blocked by control function");
}
// Note: expand_shape can always be fused with attention, it is not checked
// as a precondition. It is asserted in `fuseWithReshapeByExpansion`.
std::optional<SmallVector<Value>> replacementValues =
fuseAttentionWithReshapeByExpansion(
producer, reshapeOp,
producer.getDpsInitOperand(producerResult.getResultNumber()),
rewriter);
if (!replacementValues) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion by expansion failed");
}
Value reshapeReplacement =
(*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
.getResultNumber()];
if (auto collapseOp =
reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
reshapeReplacement = collapseOp.getSrc();
}
rewriter.replaceOp(reshapeOp, reshapeReplacement);
rewriter.replaceOp(producer, *replacementValues);
return success();
}
linalg::ControlFusionFn controlFoldingReshapes;
};
// Fold a collapse_shape op with its consumer attention op.
// class FoldWith
struct FoldAttentionWithProducerReshapeByExpansion final
: public OpRewritePattern<AttentionOp> {
FoldAttentionWithProducerReshapeByExpansion(
MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<AttentionOp>(context, benefit),
controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
LogicalResult matchAndRewrite(AttentionOp attentionOp,
PatternRewriter &rewriter) const override {
for (OpOperand *opOperand : attentionOp.getDpsInputOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
// Fold only if
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (!isFusableWithReshapeByDimExpansion(attentionOp, opOperand) ||
(!controlFoldingReshapes(opOperand)))
continue;
std::optional<SmallVector<Value>> replacementValues =
fuseAttentionWithReshapeByExpansion(attentionOp, reshapeOp, opOperand,
rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(attentionOp, *replacementValues);
return success();
}
return failure();
}
linalg::ControlFusionFn controlFoldingReshapes;
};
struct DropScatterUnitIndexDepth final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern<ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp scatterOp,
PatternRewriter &rewriter) const override {
llvm::ArrayRef<int64_t> indicesShape =
scatterOp.getIndicesType().getShape();
int64_t batchRank = scatterOp.getBatchRank();
if (indicesShape.size() == batchRank || indicesShape.back() != 1) {
return failure();
}
SmallVector<ReassociationIndices> reassoc;
reassoc.reserve(indicesShape.size());
for (auto i : llvm::seq<int64_t>(0, batchRank - 1)) {
reassoc.emplace_back(1, i);
}
reassoc.push_back(ReassociationIndices{batchRank - 1, batchRank});
auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
scatterOp.getLoc(), scatterOp.getIndices(), reassoc);
rewriter.modifyOpInPlace(scatterOp, [&]() {
scatterOp.setOperand(ScatterOp::kIndicesOpNum, collapseOp.getResult());
});
return success();
}
};
struct FoldScatterWithProducerReshapeByExpansion final
: public OpRewritePattern<ScatterOp> {
FoldScatterWithProducerReshapeByExpansion(
MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<ScatterOp>(context, benefit),
controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
LogicalResult matchAndRewrite(ScatterOp scatterOp,
PatternRewriter &rewriter) const override {
for (OpOperand &opOperand : scatterOp->getOpOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand.get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
continue;
if (!controlFoldingReshapes(&opOperand))
continue;
std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
scatterOp, reshapeOp, &opOperand, rewriter);
if (!replacementValue)
return failure();
rewriter.replaceOp(scatterOp, *replacementValue);
return success();
}
return failure();
}
linalg::ControlFusionFn controlFoldingReshapes;
};
struct FoldScatterWithConsumerReshapeByExpansion final
: public OpRewritePattern<tensor::ExpandShapeOp> {
FoldScatterWithConsumerReshapeByExpansion(
MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
auto producerResult = dyn_cast<OpResult>(expandOp.getSrc());
if (!producerResult) {
return rewriter.notifyMatchFailure(expandOp,
"source not produced by an operation");
}
auto scatterOp = producerResult.getDefiningOp<LinalgExt::ScatterOp>();
if (!scatterOp) {
return failure();
}
if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
return failure();
}
std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
scatterOp, expandOp, scatterOp.getTiedOpOperand(producerResult),
rewriter);
if (!replacementValue)
return failure();
rewriter.replaceOp(scatterOp, *replacementValue);
return success();
}
linalg::ControlFusionFn controlFoldingReshapes;
};
} // namespace
/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
const CollapsingInfo &collapsingInfo) {
unsigned counter = 0;
SmallVector<ReassociationIndices> operandReassociation;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
auto collapsedOpToOrigOpMapping =
collapsingInfo.getCollapsedOpToOrigOpMapping();
while (counter < indexingMap.getNumResults()) {
unsigned dim =
cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
// This is the start of a collapsed dimensions of the iteration that
// is gauranteed to be preserved in the indexing map. The number of folded
// dims is obtained from the collapsed op to original op mapping.
unsigned numFoldedDims =
collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
.size();
if (origOpToCollapsedOpMapping[dim].second == 0) {
auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
operandReassociation.emplace_back(range.begin(), range.end());
}
counter += numFoldedDims;
}
return operandReassociation;
}
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
static Value getCollapsedOpOperand(Location loc, AttentionOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);
// If the number of entries in the reassociation for the operand is same as
// the number of results of the indexing map, then nothing to do for this
// operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;
// Insert a reshape to collapse the dimensions.
if (isa<MemRefType>(operand.getType())) {
return builder
.create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
.getResult();
}
return builder
.create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
.getResult();
}
static void collapseOperandsAndResults(AttentionOp op,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter,
SmallVectorImpl<Value> &inputOperands,
SmallVectorImpl<Value> &outputOperands,
SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();
inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
});
// Get the output operands and result types.
resultTypes.reserve(op.getNumDpsInits());
outputOperands.reserve(op.getNumDpsInits());
for (OpOperand &output : op.getDpsInitsMutable()) {
Value newOutput =
getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
if (!op.hasPureBufferSemantics())
resultTypes.push_back(newOutput.getType());
}
}
/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
const CollapsingInfo &collapsingInfo) {
MLIRContext *context = indexingMap.getContext();
assert(indexingMap.isProjectedPermutation() &&
"expected indexing map to be projected permutation");
SmallVector<AffineExpr> resultExprs;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
for (auto expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
// If the dim is not the first of the collapsed dim, do nothing.
if (origOpToCollapsedOpMapping[dim].second != 0)
continue;
// The next n-dims are guaranteed to be collapsed. So just use the
// iteration dimension of the collapsed op.
resultExprs.push_back(
getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
}
return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
resultExprs, context);
}
/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
const CollapsingInfo &collapsingInfo) {
SmallVector<utils::IteratorType> collapsedIteratorTypes;
for (ReassociationIndicesRef foldedIterDims :
collapsingInfo.getCollapsedOpToOrigOpMapping()) {
assert(!foldedIterDims.empty() &&
"reassociation indices expected to have non-empty sets");
// Just pick the iterator type of the first folded dim. Pre-condition checks
// expected to have checked that iterator types of all folded dimensions are
// the same.
collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
}
return collapsedIteratorTypes;
}
/// Returns a copy of `attentionOp` with collapsed iteration dimensions.
static Operation *createCollapsedOp(AttentionOp origOp,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
SmallVector<Value> inputOperands, outputOperands;
SmallVector<Type> resultTypes;
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
outputOperands, resultTypes);
SmallVector<AffineMap> indexingMaps(
llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));
SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
origOp.getLoopIteratorTypes(), collapsingInfo));
Value maskOperand;
if (inputOperands.size() > 4) {
maskOperand = inputOperands[4];
}
auto collapsedOp = rewriter.create<AttentionOp>(
origOp.getLoc(), resultTypes, inputOperands[0], inputOperands[1],
inputOperands[2], inputOperands[3], outputOperands[0],
rewriter.getAffineMapArrayAttr(indexingMaps), maskOperand);
rewriter.inlineRegionBefore(origOp.getRegion(), collapsedOp.getRegion(),
collapsedOp.getRegion().begin());
return collapsedOp;
}
FailureOr<CollapseResult>
collapseOpIterationDims(AttentionOp op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();
FailureOr<SmallVector<int64_t>> staticLoops = op.getStaticLoopRanges();
if (failed(staticLoops) ||
llvm::any_of(staticLoops.value(), ShapedType::isDynamic)) {
return failure();
}
CollapsingInfo collapsingInfo;
if (failed(
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
op, "illegal to collapse specified dimensions");
}
Operation *collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
auto loc = op.getLoc();
SmallVector<Value> results;
for (const auto &originalResult : llvm::enumerate(op->getResults())) {
Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
MemRefType expandShapeResultType = MemRefType::get(
originalResultType.getShape(), originalResultType.getElementType());
result = rewriter.create<memref::ExpandShapeOp>(
loc, expandShapeResultType, collapsedOpResult, reassociation);
} else {
result = rewriter.create<tensor::ExpandShapeOp>(
loc, originalResultType, collapsedOpResult, reassociation);
}
results.push_back(result);
} else {
results.push_back(collapsedOpResult);
}
}
return CollapseResult{results, collapsedOp};
}
void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldAttentionWithConsumerReshapeByExpansion>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldAttentionWithProducerReshapeByExpansion>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldScatterWithProducerReshapeByExpansion>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldScatterWithConsumerReshapeByExpansion>(
patterns.getContext(), controlFoldingReshapes);
}
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
auto fusionOp = cast<LinalgFusionOpInterface>(op);
return llvm::to_vector(llvm::seq<unsigned>(0, fusionOp.getNumLoops()));
}
void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
patterns.add<DropScatterUnitIndexDepth>(patterns.getContext());
}
} // namespace mlir::iree_compiler::IREE::LinalgExt