| // Copyright 2024 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir::iree_compiler::DispatchCreation { |
| |
| #define GEN_PASS_DEF_BUBBLEUPEXTRACTSLICESPASS |
| #include "iree/compiler/DispatchCreation/Passes.h.inc" |
| |
| namespace { |
| |
| // Convert extract_slice(dequant) to dequant(extract_slice) |
| // |
| // Because `extract_slice` ops and dequantize-like ops get cloned into regions |
| // later, it's okay to bubble up through multi-use dequant ops. |
| struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> { |
| using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
| PatternRewriter &rewriter) const final { |
| Value source = sliceOp.getSource(); |
| auto genericOp = source.getDefiningOp<linalg::GenericOp>(); |
| if (!genericOp || genericOp->getNumResults() != 1) { |
| return rewriter.notifyMatchFailure( |
| sliceOp, "expected source to implement `linalg::LinalgOp` and have a " |
| "single result"); |
| } |
| |
| if (!IREE::LinalgExt::isBitExtendOp(genericOp) && !genericOp->hasOneUse()) { |
| return rewriter.notifyMatchFailure( |
| sliceOp, |
| "expected source to be dequantize-like op or have a single use"); |
| } |
| |
| if (!sliceOp.hasUnitStride()) { |
| return rewriter.notifyMatchFailure(sliceOp, "expected unit stride"); |
| } |
| |
| if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) { |
| return map.isProjectedPermutation(); |
| })) { |
| return rewriter.notifyMatchFailure( |
| genericOp, |
| "expected generic op to have all projected permutation maps"); |
| } |
| |
| if (genericOp.hasIndexSemantics()) { |
| return rewriter.notifyMatchFailure( |
| genericOp, "pattern doesn't support index semantics"); |
| } |
| |
| Value replacement; |
| linalg::GenericOp swappedOp; |
| { |
| FailureOr<TilingResult> tilingResult = |
| tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp, |
| genericOp->getResult(0)); |
| assert(succeeded(tilingResult) && "failed to swap extract_slice with op"); |
| assert(tilingResult->tiledOps.size() == 1); |
| replacement = tilingResult->tiledValues[0]; |
| swappedOp = cast<linalg::GenericOp>(tilingResult->tiledOps[0]); |
| } |
| |
| // Check if this is a rank-reducing slice, if so we need to fold the unit |
| // dimensions of the op. |
| // This is necessary because `replaceExtractSliceWithTiledProducer` does not |
| // take into account the `extract_slice`'s implicit rank reduction. The |
| // operations generated by that function will have any unit dims that were |
| // removed by the original `extract_slice`. Folding them away ensures that |
| // the types match. |
| if (sliceOp.getSourceType().getRank() != |
| sliceOp.getResultType().getRank()) { |
| |
| llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims(); |
| // Get the indexing map for the result. |
| AffineMap resultMap = |
| swappedOp.getIndexingMapMatchingResult(swappedOp->getResult(0)); |
| linalg::ControlDropUnitDims options; |
| options.rankReductionStrategy = linalg::ControlDropUnitDims:: |
| RankReductionStrategy::ExtractInsertSlice; |
| options.controlFn = [&](Operation *op) -> SmallVector<unsigned> { |
| SmallVector<unsigned> droppedDimsVec; |
| for (auto [index, expr] : llvm::enumerate(resultMap.getResults())) { |
| if (!droppedDims.test(index)) { |
| continue; |
| } |
| auto dimExpr = cast<AffineDimExpr>(expr); |
| droppedDimsVec.push_back(dimExpr.getPosition()); |
| } |
| return droppedDimsVec; |
| }; |
| FailureOr<linalg::DropUnitDimsResult> dropUnitDims = |
| linalg::dropUnitDims(rewriter, swappedOp, options); |
| assert(succeeded(dropUnitDims) && |
| "failed to drop unit dims of produced operation"); |
| swappedOp = dropUnitDims->resultOp; |
| replacement = swappedOp->getResult(0); |
| } |
| rewriter.replaceOp(sliceOp, replacement); |
| return success(); |
| } |
| }; |
| |
| /// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst, |
| /// tensor.extract_slice(%init)) even when the linalg.fill has multiple users. |
| /// Bubbles up tensor.extract_slice when encountered with linalg.fill and the |
| /// former can be folded away. |
| struct SwapExtractSliceOfFill final |
| : public OpRewritePattern<tensor::ExtractSliceOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, |
| PatternRewriter &rewriter) const override { |
| auto fillOp = extractOp.getSource().getDefiningOp<linalg::FillOp>(); |
| if (!fillOp) |
| return failure(); |
| |
| auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>( |
| extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], |
| extractOp.getMixedOffsets(), extractOp.getMixedSizes(), |
| extractOp.getMixedStrides()); |
| rewriter.replaceOpWithNewOp<linalg::FillOp>( |
| extractOp, fillOp.getInputs(), ValueRange{newExtractOp.getResult()}); |
| return success(); |
| } |
| }; |
| |
| struct BubbleUpExtractSlicesPass |
| : impl::BubbleUpExtractSlicesPassBase<BubbleUpExtractSlicesPass> { |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| { |
| RewritePatternSet patterns(context); |
| patterns.insert<BubbleUpExtract>(context); |
| patterns.insert<SwapExtractSliceOfFill>(context); |
| tensor::populateFoldTensorEmptyPatterns(patterns, false); |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| } |
| }; |
| } // namespace |
| |
| } // namespace mlir::iree_compiler::DispatchCreation |