blob: 48ee4d4d80c138096f8488d9a795e582fb9f4634 [file]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===--- BubbleExpandShapes.cpp --- Pass to propagate expand shapes op up -===//
//
// This pass propagates expand_shape operations up the program (and conversely)
// sinks the collapse_shape operations down the program to get the elementwise
// operations into higher dimensionality to get better fusion.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-dispatch-creation-bubble-up-expand-shapes"
namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_BUBBLEUPEXPANDSHAPESPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
struct BubbleUpExpandShapesPass final
: public impl::BubbleUpExpandShapesPassBase<BubbleUpExpandShapesPass> {
void runOnOperation() override;
};
/// Bubbles a `tensor.expand_shape` op through a `tensor.extract_slice` op. This
/// pattern only gets applied when the `extract_slice` doesn't modify dimensions
/// that are expanded by the `expand_shape` and when the `extract_slice` is
/// completely static.
/// TODO: move this upstream with other tensor bubbling patterns.
struct BubbleExpandThroughExtract final
: public OpRewritePattern<tensor::ExpandShapeOp> {
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
auto extractOp = expandOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp) {
return failure();
}
auto srcType = extractOp.getSourceType();
auto extractedType = extractOp.getType();
auto expandedType = expandOp.getType();
if (srcType.getRank() != extractedType.getRank()) {
return rewriter.notifyMatchFailure(
extractOp, "Rank reducing extract_slice not supported");
}
if (!srcType.hasStaticShape() || !extractedType.hasStaticShape() ||
!expandedType.hasStaticShape()) {
return failure();
}
auto reassoc = expandOp.getReassociationIndices();
for (auto i : llvm::seq<uint64_t>(0, extractedType.getRank())) {
if (reassoc[i].size() == 1) {
continue;
}
if (srcType.getShape()[i] != extractedType.getShape()[i]) {
return rewriter.notifyMatchFailure(
extractOp, "Extract modifies the expanded dimension");
}
}
SmallVector<int64_t> newExpandShape;
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
for (auto [inDim, outDims] : llvm::enumerate(reassoc)) {
if (outDims.size() == 1) {
newExpandShape.push_back(srcType.getShape()[inDim]);
offsets.push_back(extractOp.getStaticOffsets()[inDim]);
sizes.push_back(extractOp.getStaticSizes()[inDim]);
strides.push_back(extractOp.getStaticStrides()[inDim]);
} else {
for (auto outDim : outDims) {
newExpandShape.push_back(expandedType.getShape()[outDim]);
offsets.push_back(0);
sizes.push_back(expandedType.getShape()[outDim]);
strides.push_back(1);
}
}
}
Type newExpandType =
RankedTensorType::get(newExpandShape, expandedType.getElementType());
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
expandOp.getLoc(), newExpandType, extractOp.getSource(), reassoc);
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
expandOp, expandedType, newExpand, ValueRange{}, ValueRange{},
ValueRange{}, offsets, sizes, strides);
return success();
}
};
} // namespace
void BubbleUpExpandShapesPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet bubbleExpandShapePatterns(context);
linalg::ControlFusionFn bubbleUpExpansionControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}
// Do not fuse by expand if consumer is dequant.
if (IREE::LinalgExt::isBitExtendOp(consumer)) {
return false;
}
// Do not fuse producer generic op if it has more than one user
// or any reduction iterators.
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
return producerGenericOp->hasOneUse() &&
llvm::all_of(producerGenericOp.getIteratorTypesArray(),
linalg::isParallelIterator);
}
// Do not fuse with any producer linalg named ops for now.
if (isa<linalg::LinalgOp>(producer)) {
return false;
}
// Do not fuse with consumer linalg named ops or reductions.
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
return isa<linalg::GenericOp>(consumerLinalgOp) &&
llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
linalg::isParallelIterator);
}
// Fuse in all other cases.
return true;
};
linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
bubbleUpExpansionControlFn);
// TODO(#19263): Temporary fix to prevent compilation failures when the k2
// dims get expanded to unit dimensions. This adds the constraint to
// `bubbleUpExpansionControlFn` that the k2 dimensions cannot be expanded by
// the reshape fusion.
linalg::ControlFusionFn linalgExtExpansionFn = [&](OpOperand *fusedOperand) {
if (!bubbleUpExpansionControlFn(fusedOperand)) {
return false;
}
// There is no need to handle `expand_shape` ops because they would be the
// producer and therefore are unable to expand the k2 dims.
auto collapseOp =
dyn_cast<tensor::CollapseShapeOp>(fusedOperand->get().getDefiningOp());
auto attentionOp =
dyn_cast<IREE::LinalgExt::AttentionOp>(fusedOperand->getOwner());
if (!collapseOp || !attentionOp) {
return true;
}
SmallVector<ReassociationIndices> reassoc =
collapseOp.getReassociationIndices();
auto opDetail = IREE::LinalgExt::AttentionOpDetail::get(
attentionOp.getQueryMap(), attentionOp.getKeyMap(),
attentionOp.getValueMap(), attentionOp.getOutputMap());
// Don't sink the `collapse_shape` op if it is collapsing into any of the k2
// dimensions.
AffineMap operandMap = attentionOp.getMatchingIndexingMap(fusedOperand);
for (auto dim : opDetail->getK2Dims()) {
auto dimExpr = getAffineDimExpr(dim, operandMap.getContext());
if (std::optional<int64_t> maybeDim =
operandMap.getResultPosition(dimExpr);
maybeDim && reassoc[maybeDim.value()].size() > 1) {
return false;
}
}
return true;
};
IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
bubbleExpandShapePatterns, linalgExtExpansionFn);
// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
bubbleExpandShapePatterns.insert<BubbleExpandThroughExtract>(context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
context);
GreedyRewriteConfig rewriteConfig;
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsGreedily(getOperation(),
std::move(bubbleExpandShapePatterns),
rewriteConfig))) {
getOperation()->emitOpError("Failed to perform elementwise operations");
return signalPassFailure();
}
}
} // namespace mlir::iree_compiler::DispatchCreation