blob: 4a629c91b68c5507928a29d1bc66fe9f3b6b9273 [file]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- PropagateLinalgTranspose.cpp - Pass to propagate transposes ---------==//
//
// The pass is to propagate linalg.transpose operations through a restricted
// set of operations based on a set of local propagation decisions.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-global-opt-propagate-linalg-transpose"
namespace mlir::iree_compiler::GlobalOptimization {
#define GEN_PASS_DEF_PROPAGATELINALGTRANSPOSEPASS
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
//===----------------------------------------------------------------------===//
// Transpose permutation helpers
//===----------------------------------------------------------------------===//
// Returns true iff every entry of `perm` equals its own position, i.e.
// perm[i] == i for all i. An empty permutation is reported as the identity.
static bool isIdentityPermutation(ArrayRef<int64_t> perm) {
  for (size_t position = 0, rank = perm.size(); position < rank; ++position) {
    if (static_cast<size_t>(perm[position]) != position) {
      return false;
    }
  }
  return true;
}
// Creates a `tensor.empty` whose sizes are the sizes of `source` permuted by
// `perm` and whose element type matches `source`. Only the init tensor is
// built here; no transpose op is created.
static Value createTransposeInit(OpBuilder &builder, Value source,
                                 ArrayRef<int64_t> perm) {
  Location loc = source.getLoc();
  SmallVector<OpFoldResult> permutedSizes =
      tensor::getMixedSizes(builder, loc, source);
  applyPermutationToVector(permutedSizes, perm);
  auto sourceType = cast<RankedTensorType>(source.getType());
  auto emptyOp = builder.create<tensor::EmptyOp>(loc, permutedSizes,
                                                 sourceType.getElementType());
  return emptyOp.getResult();
}
// Returns `source` transposed by `perm`. When `source` is produced by a
// `tensor.empty`, no data needs to move, so a new `tensor.empty` with the
// permuted sizes is returned instead. Otherwise a `linalg.transpose` writing
// into a freshly created init is built.
static Value createTranspose(OpBuilder &builder, Value source,
                             ArrayRef<int64_t> perm) {
  auto emptyProducer = source.getDefiningOp<tensor::EmptyOp>();
  if (!emptyProducer) {
    Value init = createTransposeInit(builder, source, perm);
    auto transposeOp = builder.create<linalg::TransposeOp>(source.getLoc(),
                                                           source, init, perm);
    return transposeOp->getResult(0);
  }

  // Transposing an empty tensor only requires permuting its sizes.
  Type elementType = emptyProducer.getType().getElementType();
  SmallVector<OpFoldResult> permutedSizes = emptyProducer.getMixedSizes();
  applyPermutationToVector(permutedSizes, perm);
  return builder.create<tensor::EmptyOp>(emptyProducer.getLoc(), permutedSizes,
                                         elementType);
}
// Returns a ranked tensor type with the element type of `type` and the shape
// of `type` permuted by `perm` (result dim i has the size of dim perm[i]).
// The permutation is taken as a non-owning view so callers holding a
// SmallVector do not pay for a copy on every call.
static RankedTensorType getPermutedTensorType(RankedTensorType type,
                                              ArrayRef<int64_t> perm) {
  SmallVector<int64_t> permutedShape = applyPermutation(type.getShape(), perm);
  return RankedTensorType::get(permutedShape, type.getElementType());
}
//===----------------------------------------------------------------------===//
// Transpose specialization
//===----------------------------------------------------------------------===//
// Indicates whether the given linalg op represents a transpose. In particular,
// it requires:
//   - all loops to be parallel,
//   - exactly one input and one init,
//   - both indexing maps to be full permutations and non-equal,
//   - a body consisting solely of the terminator, which forwards the input
//     block argument unchanged.
static bool isaTransposeOpInterface(linalg::LinalgOp linalgOp) {
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return false;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return false;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
    return false;
  }
  Block *body = linalgOp.getBlock();
  if (!llvm::hasSingleElement(body->getOperations()))
    return false;
  // A body with a single op can still yield the init block argument or a
  // value captured from above; neither is a transpose of the input. Require
  // the only op to yield exactly the input block argument (argument 0).
  Operation &onlyOp = body->front();
  if (onlyOp.getNumOperands() != 1)
    return false;
  auto yieldedArg = dyn_cast<BlockArgument>(onlyOp.getOperand(0));
  return yieldedArg && yieldedArg.getOwner() == body &&
         yieldedArg.getArgNumber() == 0;
}
// Replaces `genericOp` with an equivalent linalg.transpose when
// isaTransposeOpInterface reports that it is a transpose of its single input.
// Ops that do not qualify are left untouched.
static void specializeGenericTransposeOp(RewriterBase &rewriter,
                                         linalg::GenericOp genericOp) {
  bool isTranspose =
      mlir::iree_compiler::GlobalOptimization::isaTransposeOpInterface(
          genericOp);
  if (!isTranspose) {
    return;
  }

  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  AffineMap inputMap = indexingMaps.front();
  AffineMap outputMap = indexingMaps.back();

  // For every output dimension, record which input dimension is indexed by
  // the same loop; that sequence is the transpose permutation.
  SmallVector<int64_t> permutation;
  permutation.reserve(outputMap.getNumResults());
  for (AffineExpr outputExpr : outputMap.getResults()) {
    permutation.push_back(*inputMap.getResultPosition(outputExpr));
  }

  Value input = genericOp.getDpsInputs()[0];
  Value init = genericOp.getDpsInits()[0];
  rewriter.replaceOpWithNewOp<linalg::TransposeOp>(genericOp, input, init,
                                                   permutation);
}
//===----------------------------------------------------------------------===//
// Other pattern helpers
//===----------------------------------------------------------------------===//
/// Returns `op` as a linalg::GenericOp when it already is one. Otherwise, if
/// `allowGeneralizing` is set and `op` implements the contraction interface,
/// generalizes it (inserting at the position of `op`) and returns the result
/// of the generalization. Fails for every other op without modifying the IR.
static FailureOr<linalg::GenericOp>
getGenericOpOrGeneralizeContraction(RewriterBase &rewriter, Operation *op,
                                    bool allowGeneralizing) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp) {
    return failure();
  }

  // Generic ops can be used directly.
  if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
    return genericOp;
  }

  // TODO: Right now this is restricted to contractions due to fragility around
  // handling of convolutions.
  if (!allowGeneralizing || !linalg::isaContractionOpInterface(linalgOp)) {
    return failure();
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  return linalg::generalizeNamedOp(rewriter, linalgOp);
}
//===----------------------------------------------------------------------===//
// Transpose Bubbling Patterns
//===----------------------------------------------------------------------===//
namespace {
// Fuses a transpose with the init of a linalg.generic op or contraction op.
// Contraction ops are generalized and then treated as a generic. For example,
//
// linalg.generic {
//   indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
//                    affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
//   ins(%0 : tensor<2x7x5>) outs(%1 : tensor<7x2x5>)
//
// %2 = linalg.transpose ... permutation = [0, 2, 1] :
//        tensor<7x2x5> -> tensor<7x5x2>
// Becomes
//
// linalg.generic {
//   indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0, d1)>,
//                    affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
//   ins(%0 : tensor<2x7x5>) outs(%3 : tensor<7x5x2>)
//
// Producers with several results are supported: only the init feeding the
// transpose is transposed, all other inits are forwarded unchanged and the
// remaining results are rewired to the new generic.
class FuseTransposeWithProducerLinalgOp
    : public OpRewritePattern<linalg::TransposeOp> {
public:
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
  // `aggressiveProp` controls whether contraction producers may be
  // generalized to linalg.generic in order to fuse the transpose.
  FuseTransposeWithProducerLinalgOp(MLIRContext *ctx, bool aggressiveProp,
                                    PatternBenefit b = 1)
      : OpRewritePattern<linalg::TransposeOp>(ctx, b),
        allowGeneralizing(aggressiveProp) {}

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(transposeOp)) {
      return failure();
    }
    OpResult result = dyn_cast<OpResult>(transposeOp.getInput());
    if (!result) {
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose input defined by block argument");
    }
    if (!result.hasOneUse()) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "multi use transpose input");
    }
    auto linalgOp = dyn_cast<linalg::LinalgOp>(result.getOwner());
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(
          transposeOp, "non-linalg op producer for transpose input");
    }
    int64_t resultIndex = result.getResultNumber();
    auto maybeGenericOp = getGenericOpOrGeneralizeContraction(
        rewriter, result.getOwner(), allowGeneralizing);
    if (failed(maybeGenericOp)) {
      return rewriter.notifyMatchFailure(
          transposeOp, "linalg op producer is not generic or contraction");
    }
    auto genericOp = maybeGenericOp.value();
    // The producer may have been replaced by its generalized form; re-fetch
    // the result from the op we are actually going to rewrite.
    result = genericOp->getOpResult(resultIndex);
    ArrayRef<int64_t> perm = transposeOp.getPermutation();

    // 1. Get the transposed init of the generic. All other inits are kept.
    Value init = genericOp.getDpsInits()[resultIndex];
    SmallVector<Value> inits = genericOp.getDpsInits();
    Value newInit = createTranspose(rewriter, init, perm);
    inits[resultIndex] = newInit;
    SmallVector<Type> resultTypes(genericOp->getResultTypes());
    resultTypes[resultIndex] = newInit.getType();

    // 2. Update the indexing map of the transposed init operand by permuting
    // the results of the map.
    SmallVector<AffineMap> newIndexingMaps = genericOp.getIndexingMapsArray();
    AffineMap resultMap =
        newIndexingMaps[genericOp.getNumDpsInputs() + resultIndex];
    SmallVector<AffineExpr> newExprs =
        applyPermutation(resultMap.getResults(), perm);
    AffineMap transposedMap =
        AffineMap::get(resultMap.getNumDims(), resultMap.getNumSymbols(),
                       newExprs, rewriter.getContext());
    newIndexingMaps[genericOp.getNumDpsInputs() + resultIndex] = transposedMap;

    // 3. Create the new generic with the same iteration order. The full list
    // of inits must be passed: `resultTypes` and `newIndexingMaps` describe
    // every init of the producer, not just the transposed one.
    // NOTE(review): the new op is created at the transpose; if other results
    // of the producer have users located before the transpose this may need
    // an explicit insertion point at the producer — confirm.
    auto newGenericOp = rewriter.create<linalg::GenericOp>(
        genericOp.getLoc(), resultTypes, genericOp.getDpsInputs(), inits,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
    rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                               newGenericOp.getRegion().begin());

    // 4. Remap iteration space of the generic to match the dimension order of
    // the output.
    if (newGenericOp.getNumResults() == 1) {
      SmallVector<unsigned int> interchange;
      int64_t permIdx = 0;
      for (int i = 0, e = transposedMap.getNumDims(); i < e; ++i) {
        if (transposedMap.isFunctionOfDim(i)) {
          interchange.push_back(
              llvm::cast<AffineDimExpr>(transposedMap.getResult(permIdx))
                  .getPosition());
          permIdx++;
          continue;
        }
        // Loops that do not index the output keep their position.
        interchange.push_back(i);
      }
      auto interchangedGenericOp =
          linalg::interchangeGenericOp(rewriter, newGenericOp, interchange);
      // Interchange only fails if interchangeGenericOpPrecondition fails, which
      // only fails if the interchange vector is not invertible or doesn't match
      // the number of loops in the generic, both of which are guaranteed by
      // the fact that the output map must be a projection in the above
      // construction.
      assert(succeeded(interchangedGenericOp) &&
             "failed to interchange transposed generic");
      newGenericOp = *interchangedGenericOp;
    }

    // 5. Replace the result of the transpose with the transposed init, and
    // forward every other result of the old generic to the new one.
    rewriter.replaceOp(transposeOp, newGenericOp->getResult(resultIndex));
    for (auto [oldRes, newRes] :
         llvm::zip_equal(genericOp.getResults(), newGenericOp->getResults())) {
      if (oldRes.getResultNumber() == resultIndex)
        continue;
      rewriter.replaceAllUsesWith(oldRes, newRes);
    }
    return success();
  }

private:
  // Whether contraction producers may be generalized to enable the fusion.
  bool allowGeneralizing = false;
};
// Bubbles a transpose through a tensor.collapse_shape, i.e. rewrites
// transpose(collapse_shape(x)) into collapse_shape(transpose(x)).
class BubbleTransposeThroughCollapseShape
    : public OpRewritePattern<linalg::TransposeOp> {
public:
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(transposeOp)) {
      return failure();
    }

    Value transposeInput = transposeOp.getDpsInputOperand(0)->get();
    auto collapseOp = transposeInput.getDefiningOp<tensor::CollapseShapeOp>();
    // Only propagate when the collapse_shape has a single use. With several
    // users the reshape would have to stay, and propagating would end up
    // duplicating work instead of being free.
    bool isSingleUseCollapse = collapseOp && collapseOp->hasOneUse();
    if (!isSingleUseCollapse) {
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose input is not a single-use collapse shape");
    }

    // Because we are doing transpose(collapse_shape), all expanded groups are
    // transposed together. Permuting the reassociation groups and flattening
    // them yields the permutation of the new transpose, while renumbering the
    // flattened dimensions yields the new reassociation. For example,
    //
    //   reassociation_map = [[0, 1, 2], [3], [4, 5]]
    //   permutation = [1, 2, 0]
    //
    // Becomes
    //
    //   permutation = [3, 4, 5, 0, 1, 2]
    //   reassociation_map = [[0], [1, 2], [3, 4, 5]]
    SmallVector<ReassociationIndices> permutedGroups =
        collapseOp.getReassociationIndices();
    applyPermutationToVector(permutedGroups, transposeOp.getPermutation());

    SmallVector<int64_t> expandedPerm;
    SmallVector<ReassociationIndices> renumberedGroups;
    int64_t nextExpandedDim = 0;
    for (const ReassociationIndices &group : permutedGroups) {
      ReassociationIndices renumbered;
      for (int64_t sourceDim : group) {
        expandedPerm.push_back(sourceDim);
        renumbered.push_back(nextExpandedDim);
        ++nextExpandedDim;
      }
      renumberedGroups.push_back(renumbered);
    }

    Value transposedSource =
        createTranspose(rewriter, collapseOp.getSrc(), expandedPerm);
    Type collapsedType = transposeOp.getResultTypes()[0];
    Value collapsed = rewriter.create<tensor::CollapseShapeOp>(
        collapseOp.getLoc(), collapsedType, transposedSource,
        renumberedGroups);
    rewriter.replaceOp(transposeOp, collapsed);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Transpose Sinking Patterns
//===----------------------------------------------------------------------===//
namespace {
// Folds transpose(transpose(x)) into a single transpose of x, or into x
// itself when the two permutations cancel out. Fusion should already cancel
// inverse transposes, but composing here exposes further propagation
// opportunities and simplifies the analysis in fusion/later passes.
class ComposeTransposes : public OpRewritePattern<linalg::TransposeOp> {
public:
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp consumer,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(consumer)) {
      return failure();
    }

    auto producer = consumer.getInput().getDefiningOp<linalg::TransposeOp>();
    if (!producer) {
      return failure();
    }

    // Result dim i of the consumer reads dim consumerPerm[i] of the producer
    // result, which in turn reads dim producerPerm[consumerPerm[i]] of the
    // original source.
    SmallVector<int64_t> combinedPerm =
        applyPermutation(producer.getPermutation(), consumer.getPermutation());

    Value replacement = producer.getInput();
    if (!isIdentityPermutation(combinedPerm)) {
      replacement = createTranspose(rewriter, replacement, combinedPerm);
    }
    rewriter.replaceOp(consumer, replacement);
    return success();
  }
};
// Sinks a transpose through a tensor.extract_slice iff the transpose turns
// the extracted slice into a contiguous slice. Rank-reducing slices are
// supported: the transpose re-created on the slice only permutes the
// dimensions that survive the rank reduction.
class SinkTransposeThroughExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
public:
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(extractOp)) {
      return failure();
    }
    Value source = extractOp.getSource();
    auto transposeOp = source.getDefiningOp<linalg::TransposeOp>();
    if (!transposeOp) {
      return failure();
    }
    // Applying `perm` takes a list from the pre-transpose ordering to the
    // post-transpose ordering: post[i] = pre[perm[i]].
    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    // Applying `invPerm` takes a list from the post-transpose ordering to the
    // pre-transpose ordering: pre[j] = post[invPerm[j]]. Sinking a transpose
    // through an op is largely a matter of rewriting it in pre-transpose
    // space, and thus just applying the inverse permutation.
    auto invPerm = invertPermutationVector(perm);
    SmallVector<OpFoldResult> offsets = extractOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = extractOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = extractOp.getMixedStrides();
    ArrayRef<int64_t> srcShape = extractOp.getSourceType().getShape();
    // Permute the offsets, sizes, and strides to pre-transpose ordering.
    applyPermutationToVector(offsets, invPerm);
    applyPermutationToVector(sizes, invPerm);
    applyPermutationToVector(strides, invPerm);
    SmallVector<int64_t> baseShape = applyPermutation(srcShape, invPerm);
    // Check if the resulting offsets, sizes, and strides correspond to a
    // contiguous slice and can thus be mappable to a `flow.tensor.update` op.
    // This should always be worth doing because this can remove a dispatch for
    // the slice, and the transpose is on the slice rather than the full tensor.
    if (!IREE::Flow::isOffsetSizeAndStrideMappableToFlow(offsets, sizes,
                                                         strides, baseShape)) {
      return rewriter.notifyMatchFailure(
          extractOp, "transposed slice not mappable to flow ops");
    }
    // The mask is computed from the original slice and therefore holds
    // dimension indices in the post-transpose (transposed) domain.
    ArrayRef<int64_t> staticSizes = extractOp.getStaticSizes();
    ArrayRef<int64_t> sliceShape = extractOp.getResultType().getShape();
    std::optional<llvm::SmallDenseSet<unsigned>> maybeRankReducingMask =
        mlir::computeRankReductionMask(staticSizes, sliceShape);
    if (!maybeRankReducingMask) {
      return rewriter.notifyMatchFailure(
          extractOp, "failed to compute rank reducing mask");
    }
    llvm::SmallDenseSet<unsigned> rankReducingMask = *maybeRankReducingMask;

    // Map every surviving pre-transpose dimension to its index in the
    // rank-reduced pre-transpose slice. Pre-transpose dimension `preDim`
    // corresponds to post-transpose dimension `invPerm[preDim]`, which is the
    // index the mask has to be queried with.
    int64_t dim = 0;
    llvm::SmallDenseMap<int64_t, int64_t> rankReducedMap;
    for (int64_t preDim = 0, e = perm.size(); preDim < e; ++preDim) {
      if (!rankReducingMask.contains(invPerm[preDim])) {
        rankReducedMap[preDim] = dim++;
      }
    }

    // Compute the new permutation by dropping all rank-reduced dimensions.
    // Walk the result dimensions in post-transpose order; every surviving one
    // reads from pre-transpose dimension `perm[postDim]`, renumbered into the
    // rank-reduced space.
    SmallVector<int64_t> rankReducedPerm;
    for (int64_t postDim = 0, e = perm.size(); postDim < e; ++postDim) {
      if (!rankReducingMask.contains(postDim)) {
        rankReducedPerm.push_back(rankReducedMap.lookup(perm[postDim]));
      }
    }
    auto rankReducedInvPerm = invertPermutationVector(rankReducedPerm);

    RankedTensorType sliceType = getPermutedTensorType(
        cast<RankedTensorType>(extractOp.getType()), rankReducedInvPerm);
    Value slice = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), sliceType, transposeOp.getInput(), offsets, sizes,
        strides);
    // Transpose back to the original slice.
    if (!isIdentityPermutation(rankReducedPerm)) {
      slice = createTranspose(rewriter, slice, rankReducedPerm);
    }
    rewriter.replaceOp(extractOp, slice);
    return success();
  }
};
// Sinks a transpose through a tensor.expand_shape, i.e. rewrites
// expand_shape(transpose(x)) into transpose(expand_shape(x)).
class SinkTransposeThroughExpandShape
    : public OpRewritePattern<tensor::ExpandShapeOp> {
public:
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(expandOp)) {
      return failure();
    }

    auto transposeOp = expandOp.getSrc().getDefiningOp<linalg::TransposeOp>();
    // Do not propagate through reshapes if the transpose has multiple users, as
    // this could end up duplicating the transposes. We should only propagate
    // through reshape when it is free to do so.
    bool isSingleUseTranspose = transposeOp && transposeOp->hasOneUse();
    if (!isSingleUseTranspose) {
      return rewriter.notifyMatchFailure(
          expandOp, "expand shape input is not a single-use transpose");
    }

    // Because we are doing expand_shape(transpose), all expanded groups are
    // transposed together. Reordering the reassociation groups into
    // pre-transpose order and flattening them gives, for each expanded
    // dimension of the new reshape, the expanded dimension of the original
    // reshape it corresponds to, i.e. the inverse of the new permutation.
    // For example,
    //
    //   permutation = [0, 2, 1]
    //   reassociation_map = [[0, 1, 2], [3], [4, 5]]
    //
    // Becomes
    //
    //   reassociation_map = [[0, 1, 2], [3, 4], [5]]
    //   permutation = [0, 1, 2, 5, 3, 4]
    //
    // (the flattened groups are [0, 1, 2, 4, 5, 3], whose inverse is the
    // permutation above).
    SmallVector<int64_t> invPerm =
        invertPermutationVector(transposeOp.getPermutation());
    SmallVector<ReassociationIndices> preTransposeGroups =
        expandOp.getReassociationIndices();
    applyPermutationToVector(preTransposeGroups, invPerm);

    SmallVector<int64_t> expandedInvPerm;
    SmallVector<ReassociationIndices> renumberedGroups;
    int64_t nextExpandedDim = 0;
    for (const ReassociationIndices &group : preTransposeGroups) {
      ReassociationIndices renumbered;
      for (int64_t originalDim : group) {
        expandedInvPerm.push_back(originalDim);
        renumbered.push_back(nextExpandedDim);
        ++nextExpandedDim;
      }
      renumberedGroups.push_back(renumbered);
    }
    SmallVector<int64_t> expandedPerm =
        invertPermutationVector(expandedInvPerm);

    // The new reshape produces the original expanded type in pre-transpose
    // dimension order.
    auto originalType = cast<RankedTensorType>(expandOp.getType());
    RankedTensorType preTransposeType =
        getPermutedTensorType(originalType, expandedInvPerm);
    Value expanded = rewriter.create<tensor::ExpandShapeOp>(
        expandOp.getLoc(), preTransposeType, transposeOp.getInput(),
        renumberedGroups);
    Value replacement = createTranspose(rewriter, expanded, expandedPerm);
    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};
// Fuses a transpose with the input of a linalg.generic op or contraction op.
// Contraction ops are generalized and then treated as a generic. For example,
//
// %0 = linalg.transpose ... permutation = [0, 2, 1] :
//        tensor<2x5x7> -> tensor<2x7x5>
// linalg.generic {
//   indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
//                    affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
//   ins(%0 : tensor<2x7x5>)
//
// Becomes
//
// linalg.generic {
//   indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
//                    affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
//   ins(%0 : tensor<2x5x7>)
//
// This is considered just one way to model transpose propagation to generics,
// another option would be to interpret the transpose on the iterators of the
// generic, thus producing a transpose on the output and any other inputs to
// the generic. This has the potential to introduce more transposes/data
// movement and isn't the way this pass is modeled. Global data layout
// transformations like that are better suited for pack/unpack propagation
// rooted on specific operations.
//
// TODO: Rewrite this to use elementwise op fusion patterns.
class FuseTransposeWithLinalgOpConsumer
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
public:
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
  // `aggressiveProp` controls whether contraction consumers may be
  // generalized to linalg.generic in order to fuse the transpose.
  FuseTransposeWithLinalgOpConsumer(MLIRContext *ctx, bool aggressiveProp,
                                    PatternBenefit b = 1)
      : OpInterfaceRewritePattern<linalg::LinalgOp>(ctx, b),
        allowGeneralizing(aggressiveProp) {}

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) {
      return failure();
    }

    // Pick the first input that is produced by a single-use transpose.
    linalg::TransposeOp transposeOp;
    OpOperand *transposeOperand = nullptr;
    for (OpOperand *input : linalgOp.getDpsInputOperands()) {
      auto candidate = input->get().getDefiningOp<linalg::TransposeOp>();
      if (!candidate || !candidate->hasOneUse()) {
        continue;
      }
      transposeOp = candidate;
      transposeOperand = input;
      break;
    }
    if (!transposeOperand) {
      return rewriter.notifyMatchFailure(linalgOp, "no transpose operand");
    }

    int64_t inputIndex = transposeOperand->getOperandNumber();
    SmallVector<int64_t> invPerm =
        invertPermutationVector(transposeOp.getPermutation());

    // To do the fusion, we can simply apply the permutation of the transpose
    // to the results of the associated input's indexing map, and then forward
    // the input to the transpose to the consumer generic.
    FailureOr<linalg::GenericOp> maybeGenericOp =
        getGenericOpOrGeneralizeContraction(rewriter, linalgOp,
                                            allowGeneralizing);
    if (failed(maybeGenericOp)) {
      return failure();
    }
    linalg::GenericOp genericOp = maybeGenericOp.value();
    // The consumer may have been replaced by its generalized form; re-fetch
    // the operand from the op that is actually being modified.
    transposeOperand = genericOp.getDpsInputOperand(inputIndex);

    rewriter.startOpModification(genericOp);
    AffineMap originalMap = genericOp.getMatchingIndexingMap(transposeOperand);
    SmallVector<AffineExpr> permutedResults =
        applyPermutation(originalMap.getResults(), invPerm);
    SmallVector<AffineMap> updatedMaps = genericOp.getIndexingMapsArray();
    updatedMaps[inputIndex] =
        AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
                       permutedResults, rewriter.getContext());
    genericOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(updatedMaps));
    genericOp.setOperand(inputIndex, transposeOp.getInput());
    rewriter.finalizeOpModification(genericOp);
    return success();
  }

private:
  // Whether contraction consumers may be generalized to enable the fusion.
  bool allowGeneralizing = false;
};
// Returns true if permuting the iteration space with
// `iterationSpacePermutation` changes the relative order of the dimensions
// accessed by `indexingMap`. Every result of the map is cast to a dimension
// expression.
static bool isIndexingMapAffectedByTransposeMap(
    AffineMap indexingMap, ArrayRef<int64_t> iterationSpacePermutation) {
  int64_t lastPermutedDim = -1;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dimPosition = cast<AffineDimExpr>(expr).getPosition();
    int64_t permutedDim = iterationSpacePermutation[dimPosition];
    // The permuted positions must stay strictly increasing for the access
    // order of the map (and its operand) to be unaffected.
    if (permutedDim <= lastPermutedDim) {
      return true;
    }
    lastPermutedDim = permutedDim;
  }
  return false;
}
// Finds the unique DPS input operand of |genericOp| whose access order is
// affected by |iterationSpacePermutation|, i.e. the permutation changes the
// relative ordering of some of the dimensions of that input operand.
//
// For example, with permutation [1, 0, 2], affine map (d0, d1, d2) -> (d0, d1)
// is affected by the permutation because the first two dimensions are iterated
// in a different order while (d0, d1, d2) -> (d0, d2) is unaffected.
//
// Returns nullptr if no input operand is affected or if more than one input
// operand is affected.
static OpOperand *
getSingleTransposedInputOperand(linalg::GenericOp genericOp,
                                ArrayRef<int64_t> iterationSpacePermutation) {
  OpOperand *affectedOperand = nullptr;
  for (OpOperand *input : genericOp.getDpsInputOperands()) {
    AffineMap inputMap = genericOp.getMatchingIndexingMap(input);
    if (isIndexingMapAffectedByTransposeMap(inputMap,
                                            iterationSpacePermutation)) {
      // A second affected operand means there is no unique answer.
      if (affectedOperand) {
        return nullptr;
      }
      affectedOperand = input;
    }
  }
  return affectedOperand;
}
// Returns a copy of the indexing maps of |genericOp| in which the map of
// every DPS input, except the |transposedInputIdx|'th one, is composed with
// the iteration space permutation map |transposeMap|. The skipped operand is
// expected to have an explicit `linalg.transpose` op constructed for it so its
// map does not need to be updated. Maps of the inits are returned unchanged.
static SmallVector<AffineMap>
getTransposedIndexingMaps(linalg::GenericOp genericOp,
                          int64_t transposedInputIdx, AffineMap transposeMap) {
  SmallVector<AffineMap> updatedMaps = genericOp.getIndexingMapsArray();
  unsigned numInputs = genericOp.getNumDpsInputs();
  for (unsigned inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    if (inputIdx != transposedInputIdx) {
      updatedMaps[inputIdx] = updatedMaps[inputIdx].compose(transposeMap);
    }
  }
  return updatedMaps;
}
// Sinks a transpose through the input of an elementwise operation where the
// transposition of the iteration space only affects a single input operand.
// The generic is re-created on the un-transposed input and a transpose is
// placed on its result instead.
class SinkTransposeThroughUnaryElementwiseInput
    : public OpRewritePattern<linalg::GenericOp> {
public:
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(genericOp)) {
      return rewriter.notifyMatchFailure(genericOp, "pre-formed dispatch");
    }
    if (!linalg::isElementwise(genericOp)) {
      return rewriter.notifyMatchFailure(genericOp, "non-elementwise generic");
    }
    if (genericOp.getNumDpsInits() != 1) {
      return rewriter.notifyMatchFailure(genericOp,
                                         "unimplemented: multiple results");
    }
    OpOperand *initOperand = genericOp.getDpsInitOperand(0);
    AffineMap resultMap = genericOp.getMatchingIndexingMap(initOperand);
    if (!resultMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          genericOp, "unimplemented: non-identity result map");
    }

    linalg::TransposeOp transposeOp;
    OpOperand *inputOperand = nullptr;
    for (OpOperand *candidate : genericOp.getDpsInputOperands()) {
      // Skip broadcasted operands and transposed operands. If the input is
      // broadcasted then we would not want to propagate because that would
      // do the transpose on larger data, and if transposed we would rather
      // simply compose the transposes (handled in a separate pattern).
      if (genericOp.getMatchingIndexingMap(candidate) != resultMap) {
        continue;
      }
      // Only single-use transposes are candidates.
      auto producer = candidate->get().getDefiningOp<linalg::TransposeOp>();
      if (!producer || !producer->hasOneUse()) {
        continue;
      }
      // The candidate must be the only operand affected by the transpose.
      OpOperand *affected =
          getSingleTransposedInputOperand(genericOp, producer.getPermutation());
      if (affected != candidate) {
        continue;
      }
      transposeOp = producer;
      inputOperand = candidate;
      break;
    }
    if (!transposeOp) {
      return rewriter.notifyMatchFailure(genericOp,
                                         "no single use transpose operand");
    }

    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    SmallVector<int64_t> invPerm = invertPermutationVector(perm);

    // Create a new empty init for the transposed generic.
    Value newInit =
        createTransposeInit(rewriter, genericOp.getDpsInits()[0], invPerm);

    // We do not need to update iterator types because this is an elementwise
    // op. We just need to update the indexing maps of all other input operands
    // by composing the transpose map.
    unsigned transposedOperandNumber = inputOperand->getOperandNumber();
    AffineMap transposeMap =
        AffineMap::getPermutationMap(perm, rewriter.getContext());
    SmallVector<AffineMap> updatedMaps = getTransposedIndexingMaps(
        genericOp, transposedOperandNumber, transposeMap);

    SmallVector<Value> newOperands = genericOp->getOperands();
    newOperands[transposedOperandNumber] = transposeOp.getInput();
    newOperands[initOperand->getOperandNumber()] = newInit;

    auto newGenericOp =
        mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
    newGenericOp.setIndexingMapsAttr(
        rewriter.getAffineMapArrayAttr(updatedMaps));

    // Re-apply the original transpose on the result of the new generic.
    Value replacement =
        createTranspose(rewriter, newGenericOp->getResult(0), perm);
    rewriter.replaceOp(genericOp, replacement);
    return success();
  }
};
// Bubbles a transpose through the init of an elementwise operation where the
// transposition of the iteration space only affects a single input operand.
// The transpose is moved onto that input and the generic is re-created on the
// transposed iteration space.
class BubbleTransposeThroughUnaryElementwiseDpsInit
    : public OpRewritePattern<linalg::TransposeOp> {
public:
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp = transposeOp.getInput().getDefiningOp<linalg::GenericOp>();
    if (!genericOp) {
      return rewriter.notifyMatchFailure(transposeOp, "non-generic producer");
    }
    if (genericOp.getNumDpsInits() != 1) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "unimplemented: multiple results");
    }
    if (!IREE::Flow::isNonNullAndOutsideDispatch({genericOp, transposeOp})) {
      return failure();
    }
    OpOperand *initOperand = genericOp.getDpsInitOperand(0);
    bool isElementwiseWithIdentityResult =
        linalg::isElementwise(genericOp) &&
        genericOp.getMatchingIndexingMap(initOperand).isIdentity();
    if (!isElementwiseWithIdentityResult) {
      return rewriter.notifyMatchFailure(transposeOp, "not elementwise");
    }
    if (!genericOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(transposeOp, "not single user");
    }

    ArrayRef<int64_t> perm = transposeOp.getPermutation();
    SmallVector<int64_t> invPerm = invertPermutationVector(perm);

    // Exactly one input may be affected by the transposition, and it must be
    // accessed with an identity map so the transpose can be applied to it
    // directly.
    OpOperand *inputOperand =
        getSingleTransposedInputOperand(genericOp, invPerm);
    bool hasTransposableInput =
        inputOperand &&
        genericOp.getMatchingIndexingMap(inputOperand).isIdentity();
    if (!hasTransposableInput) {
      return rewriter.notifyMatchFailure(
          genericOp, "no single transposable input operand");
    }

    Value transposedInput =
        createTranspose(rewriter, inputOperand->get(), perm);
    // Create a new empty init for the transposed generic.
    Value newInit =
        createTransposeInit(rewriter, genericOp.getDpsInits()[0], perm);

    unsigned transposedOperandNumber = inputOperand->getOperandNumber();
    SmallVector<Value> newOperands = genericOp->getOperands();
    newOperands[transposedOperandNumber] = transposedInput;
    newOperands[initOperand->getOperandNumber()] = newInit;

    // We do not need to update iterator types because this is an elementwise
    // op. The explicitly transposed input and the init keep their maps; the
    // maps of all other input operands are updated by composing the transpose
    // map.
    AffineMap transposeMap =
        AffineMap::getPermutationMap(invPerm, rewriter.getContext());
    SmallVector<AffineMap> updatedMaps = getTransposedIndexingMaps(
        genericOp, transposedOperandNumber, transposeMap);

    auto newGenericOp =
        mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
    newGenericOp.setIndexingMapsAttr(
        rewriter.getAffineMapArrayAttr(updatedMaps));
    rewriter.replaceOp(transposeOp, newGenericOp);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Linalg Named Op -> Named Op Conversions
//===----------------------------------------------------------------------===//
namespace {
/// Rewrites a named linalg op `OpTy` into the named op `ReplTy` when the input
/// at position `inputIdx` is produced by a `linalg.transpose` whose
/// permutation equals the one supplied at construction. The replacement op
/// consumes the transpose's own input directly; all other inputs, the DPS
/// inits and the pruned attribute list are carried over unchanged.
template <typename OpTy, typename ReplTy, int64_t inputIdx>
class NamedOpConversion : public OpRewritePattern<OpTy> {
  static_assert(inputIdx >= 0, "inputIdx must be a valid input position");

public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// `perm` is the transpose permutation that must feed input `inputIdx` for
  /// this pattern to fire. It is taken by value and moved into the pattern.
  NamedOpConversion(MLIRContext *ctx, SmallVector<int64_t> perm,
                    PatternBenefit b = 1)
      : OpRewritePattern<OpTy>(ctx, b), permutation(std::move(perm)) {}

  LogicalResult matchAndRewrite(OpTy namedOp,
                                PatternRewriter &rewriter) const override {
    if (!IREE::Flow::isNonNullAndOutsideDispatch(namedOp)) {
      return failure();
    }

    Value input = namedOp.getInputs()[inputIdx];
    auto transpose = input.getDefiningOp<linalg::TransposeOp>();
    if (!transpose) {
      return rewriter.notifyMatchFailure(
          namedOp, "input is not produced by a linalg.transpose");
    }

    // Compare the permutations in place; no temporary vector is required.
    if (transpose.getPermutation() != ArrayRef<int64_t>(permutation)) {
      return rewriter.notifyMatchFailure(
          namedOp, "transpose permutation does not match target permutation");
    }

    // Build the replacement with the transposed input swapped for the
    // transpose's source value.
    SmallVector<NamedAttribute> attrs = getPrunedAttributeList(namedOp);
    SmallVector<Value> newInputs = namedOp.getInputs();
    newInputs[inputIdx] = transpose.getInput();
    rewriter.replaceOpWithNewOp<ReplTy>(namedOp, newInputs,
                                        namedOp.getDpsInits(), attrs);
    return success();
  }

private:
  // Non-type literal array template parameters are a C++20 feature, so instead
  // all the named op patterns pass their permutation explicitly as a
  // SmallVector.
  SmallVector<int64_t> permutation;
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//
namespace {
/// Pass object for transpose propagation. The rewrite logic lives in
/// `runOnOperation`, defined out of line.
struct PropagateLinalgTransposePass
    : public impl::PropagateLinalgTransposePassBase<
          PropagateLinalgTransposePass> {
  // Re-export the constructors of the generated base class.
  using impl::PropagateLinalgTransposePassBase<
      PropagateLinalgTransposePass>::PropagateLinalgTransposePassBase;

  /// Builds the pass with the aggressive-propagation option preset.
  explicit PropagateLinalgTransposePass(bool aggressive) {
    enableAggressivePropagation = aggressive;
  }

  /// Registers the linalg and tensor dialects with `registry`.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect>();
    registry.insert<tensor::TensorDialect>();
  }

  void runOnOperation() override;
};
} // namespace
/// Registers the `NamedOpConversion` patterns that convert between the plain
/// and transposed-operand variants of `linalg.matmul` and
/// `linalg.batch_matmul` when the relevant input is fed by a matching
/// `linalg.transpose`.
static void populateNamedOpSinkingPatterns(MLIRContext *context,
                                           RewritePatternSet &sinkingPatterns) {
  // Template arguments are <source op, replacement op, input index>.
  using MatmulToTransposeB =
      NamedOpConversion<linalg::MatmulOp, linalg::MatmulTransposeBOp, 1>;
  using MatmulToTransposeA =
      NamedOpConversion<linalg::MatmulOp, linalg::MatmulTransposeAOp, 0>;
  using TransposeBToMatmul =
      NamedOpConversion<linalg::MatmulTransposeBOp, linalg::MatmulOp, 1>;
  using TransposeAToMatmul =
      NamedOpConversion<linalg::MatmulTransposeAOp, linalg::MatmulOp, 0>;

  using BatchMatmulToTransposeB =
      NamedOpConversion<linalg::BatchMatmulOp, linalg::BatchMatmulTransposeBOp,
                        1>;
  using BatchMatmulToTransposeA =
      NamedOpConversion<linalg::BatchMatmulOp, linalg::BatchMatmulTransposeAOp,
                        0>;
  using BatchTransposeBToMatmul =
      NamedOpConversion<linalg::BatchMatmulTransposeBOp, linalg::BatchMatmulOp,
                        1>;
  using BatchTransposeAToMatmul =
      NamedOpConversion<linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulOp,
                        0>;

  // Non-batched variants match the permutation [1, 0].
  SmallVector<int64_t> matmulPerm{1, 0};
  sinkingPatterns.insert<MatmulToTransposeB, MatmulToTransposeA,
                         TransposeBToMatmul, TransposeAToMatmul>(context,
                                                                 matmulPerm);

  // Batched variants match the permutation [0, 2, 1].
  SmallVector<int64_t> batchMatmulPerm{0, 2, 1};
  sinkingPatterns.insert<BatchMatmulToTransposeB, BatchMatmulToTransposeA,
                         BatchTransposeBToMatmul, BatchTransposeAToMatmul>(
      context, batchMatmulPerm);
}
/// Adds the canonicalization patterns shared by every propagation phase of
/// this pass to `patterns`.
static void
populateCommonCanonicalizationPatterns(MLIRContext *context,
                                       RewritePatternSet &patterns) {
  // Canonicalizers provided by the ops themselves.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);

  // tensor.empty folding, not limited to single-use empties.
  const bool foldSingleUseOnly = false;
  tensor::populateFoldTensorEmptyPatterns(patterns, foldSingleUseOnly);
}
// Runs transpose propagation on the current operation. After specializing
// transposing generics, up to three greedy rewrite phases are applied in order:
// an initial sinking phase (skipped when `testBubblingOnly` is set), a bubbling
// phase (skipped when `testSinkingOnly` is set), and a final sinking phase
// (skipped when `testBubblingOnly` is set). A failure of any greedy rewrite
// emits an error on the operation and signals pass failure.
void PropagateLinalgTransposePass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
// First, specialize all transposes to `linalg.transpose`. This dramatically
// simplifies all subsequent propagation patterns, both in matching and
// rewriting.
{
// Collect the candidate generics (those outside of dispatches) during the
// walk, and only rewrite them once the walk has completed.
SmallVector<linalg::GenericOp> genericCandidates;
funcOp.walk([&](linalg::GenericOp genericOp) {
if (IREE::Flow::isNonNullAndOutsideDispatch(genericOp)) {
genericCandidates.push_back(genericOp);
}
});
IRRewriter rewriter(&getContext());
for (auto genericOp : genericCandidates) {
rewriter.setInsertionPoint(genericOp);
specializeGenericTransposeOp(rewriter, genericOp);
}
}
LLVM_DEBUG({
llvm::dbgs() << "\n--- After specializing transpose ops ---\n";
funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
// First try to fuse transposes with some consumer linalg named ops before
// any reshape propagation. Some transposes may be adjacent to named ops,
// and it is more canonical if we can fuse the ops into a new named op.
if (!testBubblingOnly) {
RewritePatternSet sinkingPatterns(context);
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
populateNamedOpSinkingPatterns(context, sinkingPatterns);
populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
context, /*benefit=*/2);
if (failed(applyPatternsGreedily(funcOp, std::move(sinkingPatterns)))) {
funcOp.emitError("Transpose initial sinking patterns failed");
return signalPassFailure();
}
}
LLVM_DEBUG({
llvm::dbgs() << "\n--- After canonicalizing transpose in place ---\n";
funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
// Propagate transposes upwards, and fuse with any producer generic ops. Also
// propagate reshapes upwards to open up more transpose fusion opportunities.
if (!testSinkingOnly) {
// Control function deciding whether a reshape consumer may be propagated up
// through the producer of `fusedOperand`.
linalg::ControlFusionFn reshapePropagationFn =
[&](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}
// Only reshape producers that are linalg ops with exactly one use.
auto producerLinalgOp = dyn_cast<linalg::LinalgOp>(producer);
if (!producerLinalgOp || !producerLinalgOp->hasOneUse()) {
return false;
}
// Only reshape generic ops, or any op if aggressive propagation is
// enabled.
if (!enableAggressivePropagation &&
!isa<linalg::GenericOp>(producerLinalgOp)) {
return false;
}
// Only propagate expand_shape ops up through producers because it
// is always possible to bubble a transpose through a collapse_shape
// and thus is handled separately.
if (!isa<tensor::ExpandShapeOp>(consumer)) {
return false;
}
// Only propagate if the immediate consumer of the reshape is a
// transpose.
return consumer->hasOneUse() &&
llvm::isa<linalg::TransposeOp>(*(consumer->user_begin()));
};
RewritePatternSet bubblingPatterns(context);
linalg::populateFoldReshapeOpsByExpansionPatterns(bubblingPatterns,
reshapePropagationFn);
linalg::FillOp::getCanonicalizationPatterns(bubblingPatterns, context);
// LinalgExt transpose bubbling is allowed whenever both the producer and the
// consumer are non-null and outside of a dispatch.
linalg::ControlFusionFn bubbleTransposeControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
};
IREE::LinalgExt::populateBubbleTransposeFromLinalgExtOps(
bubblingPatterns, bubbleTransposeControlFn);
bubblingPatterns.insert<FuseTransposeWithProducerLinalgOp>(
context, enableAggressivePropagation);
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(context);
bubblingPatterns.add<BubbleTransposeThroughUnaryElementwiseDpsInit>(
context, /*benefit=*/2);
bubblingPatterns.insert<ComposeTransposes>(context);
populateCommonCanonicalizationPatterns(context, bubblingPatterns);
// Do not cap the number of iterations of the greedy driver.
GreedyRewriteConfig config;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsGreedily(funcOp, std::move(bubblingPatterns),
config))) {
funcOp.emitError("Transpose bubbling patterns failed");
return signalPassFailure();
}
}
LLVM_DEBUG({
llvm::dbgs() << "\n--- After bubbling transpose ops up ---\n";
funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
// Propagate transposes downwards, and fuse with any non-unary generic ops
// or linalg named ops. Also propagate reshapes downwards to open up more
// transpose fusion opportunities.
if (!testBubblingOnly) {
RewritePatternSet sinkingPatterns(context);
// Control function deciding whether a reshape producer may be propagated
// down through the consumer of `fusedOperand`.
linalg::ControlFusionFn reshapePropagationFn =
[&](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}
// The consumer must be a linalg op.
auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
if (!consumerLinalgOp) {
return false;
}
// Only reshape generic ops, or any op if aggressive propagation is
// enabled.
if (!enableAggressivePropagation &&
!isa<linalg::GenericOp>(consumerLinalgOp)) {
return false;
}
// Only propagate collapse_shape ops down through consumers because it
// is always possible to sink a transpose through an expand_shape and
// thus is handled separately.
if (!isa<tensor::CollapseShapeOp>(producer)) {
return false;
}
// Require that the immediate producer of the reshape is a transpose.
return isa_and_nonnull<linalg::TransposeOp>(
producer->getOperand(0).getDefiningOp());
};
linalg::populateFoldReshapeOpsByExpansionPatterns(sinkingPatterns,
reshapePropagationFn);
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
sinkingPatterns.insert<FuseTransposeWithLinalgOpConsumer>(
context, enableAggressivePropagation);
sinkingPatterns.insert<ComposeTransposes>(context);
populateNamedOpSinkingPatterns(context, sinkingPatterns);
populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
context, /*benefit=*/2);
GreedyRewriteConfig config;
// TODO: This is inefficient. Consider rewriting this pass to use a
// worklist of just the transpose operations.
config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsGreedily(funcOp, std::move(sinkingPatterns),
config))) {
funcOp.emitError("Transpose sinking patterns failed");
return signalPassFailure();
}
}
LLVM_DEBUG({
llvm::dbgs() << "\n--- After sinking transpose ops down ---\n";
funcOp->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createPropagateLinalgTransposePass(bool enableAggressivePropagation) {
return std::make_unique<PropagateLinalgTransposePass>(
enableAggressivePropagation);
}
} // namespace mlir::iree_compiler::GlobalOptimization