| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| //===- BufferAllocViewCleanUpPass.cpp -------------------------------------===// |
| // |
| // This pass performs canonicalizations/cleanups related to HAL interface/buffer |
| // allocations and views. We need a dedicated pass because patterns here involve |
| // multiple dialects. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns); |
| |
| namespace { |
| |
| /// Folds tensor.expand/collapse_shape into the source |
| /// hal.interface.binding.subspan. |
| /// |
| /// For example, this matches the following pattern: |
| /// |
| /// %subspan = hal.interface.binding.subspan ... : |
| /// !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>> |
| /// %tensor = flow.dispatch.tensor.load %subspan : |
| /// !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>> -> |
| /// tensor<3x3x1x96xf32> |
| /// %0 = linalg.tensor_reshape %tensor [ |
| /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
| /// ] : tensor<3x3x1x96xf32> into tensor<864xf32> |
| /// |
| /// And turns it into: |
| /// |
| /// %subspan = hal.interface.binding.subspan ... : |
| /// !flow.dispatch.tensor<readonly:tensor<864xf32>> |
| /// %0 = flow.dispatch.tensor.load %subspan : |
| /// !flow.dispatch.tensor<readonly:tensor<864xf32>> -> tensor<864xf32> |
| template <typename TensorReshapeOp> |
| struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> { |
| using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, |
| PatternRewriter &rewriter) const override { |
| // TODO(antigainst): enable dynamic shape support once they are needed. |
| auto reshapeSrcType = |
| reshapeOp.getSrc().getType().template cast<ShapedType>(); |
| auto reshapeDstType = reshapeOp.getType().template cast<ShapedType>(); |
| if (!reshapeSrcType.hasStaticShape() || !reshapeDstType.hasStaticShape()) { |
| return failure(); |
| } |
| |
| auto loadOp = |
| reshapeOp.getSrc() |
| .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); |
| if (!loadOp) return failure(); |
| |
| // Make sure we are loading the full incoming subspan. Otherwise we cannot |
| // simply adjust the subspan's resultant type later. |
| if (!loadOp.offsets().empty() || !loadOp.sizes().empty() || |
| !loadOp.strides().empty()) |
| return failure(); |
| |
| auto subspanOp = |
| loadOp.getSource() |
| .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>(); |
| if (!subspanOp) return failure(); |
| assert(subspanOp.getDynamicDims().empty()); |
| |
| auto tensorAccess = subspanOp.getType() |
| .template cast<IREE::Flow::DispatchTensorType>() |
| .getAccess(); |
| auto newSubspanType = IREE::Flow::DispatchTensorType::get( |
| tensorAccess, reshapeOp.getResultType()); |
| |
| Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>( |
| subspanOp.getLoc(), newSubspanType, subspanOp.getSet(), |
| subspanOp.getBinding(), subspanOp.getDescriptorType(), |
| subspanOp.getByteOffset(), subspanOp.getDynamicDims(), |
| subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr()); |
| |
| rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>( |
| reshapeOp, reshapeOp.getResultType(), newSubspanOp, |
| loadOp.getSourceDims()); |
| |
| return success(); |
| } |
| }; |
| |
| // Removes operations with Allocate MemoryEffects but no uses. |
| struct RemoveDeadMemAllocs : RewritePattern { |
| RemoveDeadMemAllocs(MLIRContext *context, PatternBenefit benefit = 1) |
| : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto memEffect = dyn_cast<MemoryEffectOpInterface>(op); |
| if (!memEffect || !memEffect.hasEffect<MemoryEffects::Allocate>()) { |
| return failure(); |
| } |
| if (!op->use_empty()) return failure(); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// Runs canonicalization patterns on interface load/store ops. |
| struct CleanupBufferAllocViewPass |
| : public CleanupBufferAllocViewBase<CleanupBufferAllocViewPass> { |
| void runOnOperation() override { |
| RewritePatternSet patterns(&getContext()); |
| populateReshapeToInterfaceTensorPatterns(patterns); |
| patterns.insert<RemoveDeadMemAllocs>(&getContext()); |
| if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns) { |
| patterns.insert<FoldReshapeIntoInterfaceTensorLoad<tensor::CollapseShapeOp>, |
| FoldReshapeIntoInterfaceTensorLoad<tensor::ExpandShapeOp>>( |
| patterns.getContext()); |
| } |
| |
| std::unique_ptr<OperationPass<func::FuncOp>> |
| createCleanupBufferAllocViewPass() { |
| return std::make_unique<CleanupBufferAllocViewPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |