| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "third_party/llvm/llvm/include/llvm/ADT/DenseSet.h" |
| #include "third_party/llvm/llvm/include/llvm/ADT/SmallVector.h" |
| #include "third_party/llvm/llvm/include/llvm/Support/Allocator.h" |
| #include "third_party/llvm/llvm/include/llvm/Support/Casting.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Dialect/StandardOps/Ops.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Attributes.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BlockAndValueMapping.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Location.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/MLIRContext.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/PatternMatch.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/StandardTypes.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/PassRegistry.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/DialectConversion.h" |
| #include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Transforms/Utils.h" |
| #include "third_party/mlir_edge/iree/compiler/IR/Dialect.h" |
| #include "third_party/mlir_edge/iree/compiler/IR/Interpreter/HLDialect.h" |
| #include "third_party/mlir_edge/iree/compiler/IR/Interpreter/HLOps.h" |
| #include "third_party/mlir_edge/iree/compiler/IR/Interpreter/LLDialect.h" |
| #include "third_party/mlir_edge/iree/compiler/IR/Interpreter/LLOps.h" |
| #include "third_party/mlir_edge/iree/compiler/IR/Ops.h" |
| #include "third_party/mlir_edge/iree/compiler/Serialization/BytecodeTables.h" |
| #include "third_party/mlir_edge/iree/schemas/bytecode/interpreter_bytecode_v0.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| struct LowerBranchOpPattern |
| : public OpRewritePattern<IREEInterp::HL::BranchOp> { |
| using OpRewritePattern<IREEInterp::HL::BranchOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(IREEInterp::HL::BranchOp op, |
| PatternRewriter &rewriter) const { |
| SmallVector<Value *, 8> operands{op.getOperation()->getOperands()}; |
| |
| rewriter.replaceOpWithNewOp<IREEInterp::LL::BranchOp>(op, op.getDest(), |
| operands); |
| return matchSuccess(); |
| } |
| }; |
| |
| struct LowerCondCondBranchOpPattern |
| : public OpRewritePattern<IREEInterp::HL::CondBranchOp> { |
| using OpRewritePattern<IREEInterp::HL::CondBranchOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(IREEInterp::HL::CondBranchOp op, |
| PatternRewriter &rewriter) const { |
| SmallVector<Value *, 8> trueOperands{op.getTrueOperands()}; |
| SmallVector<Value *, 8> falseOperands{op.getFalseOperands()}; |
| |
| rewriter.replaceOpWithNewOp<IREEInterp::LL::CondBranchOp>( |
| op, op.getCondition(), op.getTrueDest(), trueOperands, |
| op.getFalseDest(), falseOperands); |
| return matchSuccess(); |
| } |
| }; |
| |
| // Returns true if the op defined by |opName| (like 'iree_ll_interp.reshape') |
| // uses output operands for results (like iree_ll_interp.add_i) or returns real |
| // results. |
| bool opTakesOutputOperands(llvm::StringRef opName) { |
| if (!opName.consume_front("iree_ll_interp.")) { |
| assert(false && "op not part of IREE LL Interpreter dialect"); |
| return false; |
| } |
| auto opcode = GetInterpreterOpcodeByName(opName.str()); |
| assert(opcode.hasValue() && "op has no corresponding opcode"); |
| const auto &info = GetInterpreterOpcodeInfo(opcode.getValue()); |
| for (auto &operand : info.operands) { |
| if (operand == iree::OperandEncoding::kOutputSlot || |
| operand == iree::OperandEncoding::kVariadicOutputSlots) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| template <typename SrcOp, typename DstOp> |
| class SimpleOpLowering : public OpRewritePattern<SrcOp> { |
| using OpRewritePattern<SrcOp>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(SrcOp op, |
| PatternRewriter &rewriter) const { |
| SmallVector<Value *, 8> operands{op.getOperation()->getOperands()}; |
| |
| // Most ops take results as output operands to populate during execution. |
| // Certain ops, like reshape, return references to existing memrefs and |
| // should still retain their results. |
| if (!opTakesOutputOperands(DstOp::getOperationName())) { |
| SmallVector<Type, 8> resultTypes{op.getOperation()->getResultTypes()}; |
| |
| rewriter.replaceOpWithNewOp<DstOp>(op, resultTypes, operands, |
| op.getAttrs()); |
| return this->matchSuccess(); |
| } |
| |
| for (Value *result : op.getOperation()->getResults()) { |
| auto memRefType = result->getType().cast<MemRefType>(); |
| if (!memRefType.hasStaticShape()) { |
| // TODO(benvanik): real thing here - dynamic shaping required. |
| // This should emit a shape calculation based on the operation. Most |
| // are likely simple and by running DCE after this we can clean up |
| // parts that are static or unused. |
| op.emitOpError() << "uses unsupported dynamic shapes"; |
| return this->matchFailure(); |
| } |
| ArrayRef<Value *> dim_pieces; |
| auto allocOp = rewriter.create<IREEInterp::LL::AllocHeapOp>( |
| op.getLoc(), memRefType, dim_pieces); |
| operands.push_back(allocOp); |
| result->replaceAllUsesWith(allocOp); |
| } |
| ArrayRef<Type> resultTypes; |
| rewriter.create<DstOp>(op.getLoc(), resultTypes, operands, op.getAttrs()); |
| op.erase(); |
| return this->matchSuccess(); |
| } |
| }; |
| |
| } // namespace |
| |
| class LowerInterpreterDialectPass |
| : public FunctionPass<LowerInterpreterDialectPass> { |
| public: |
| void runOnFunction() override { |
| OwningRewritePatternList patterns; |
| patterns.insert<LowerBranchOpPattern, LowerCondCondBranchOpPattern>( |
| &getContext()); |
| patterns.insert< |
| SimpleOpLowering<IREE::ConstantOp, IREEInterp::LL::ConstantOp>, |
| SimpleOpLowering<IREEInterp::HL::CopyOp, IREEInterp::LL::DynamicCopyOp>, |
| SimpleOpLowering<IREEInterp::HL::SliceOp, |
| IREEInterp::LL::DynamicSliceOp>>(&getContext()); |
| #define SAME_NAME_SIMPLE_PATTERN(op_name) \ |
| SimpleOpLowering<IREEInterp::HL::op_name, IREEInterp::LL::op_name> |
| // clang-format off |
| patterns.insert< |
| SAME_NAME_SIMPLE_PATTERN(AssignOp), |
| SAME_NAME_SIMPLE_PATTERN(AbsFOp), |
| SAME_NAME_SIMPLE_PATTERN(AbsIOp), |
| SAME_NAME_SIMPLE_PATTERN(AddFOp), |
| SAME_NAME_SIMPLE_PATTERN(AddIOp), |
| SAME_NAME_SIMPLE_PATTERN(AllocHeapOp), |
| SAME_NAME_SIMPLE_PATTERN(AndOp), |
| SAME_NAME_SIMPLE_PATTERN(Atan2FOp), |
| SAME_NAME_SIMPLE_PATTERN(BreakOp), |
| SAME_NAME_SIMPLE_PATTERN(BroadcastOp), |
| SAME_NAME_SIMPLE_PATTERN(CallOp), |
| SAME_NAME_SIMPLE_PATTERN(CallIndirectOp), |
| SAME_NAME_SIMPLE_PATTERN(CeilFOp), |
| SAME_NAME_SIMPLE_PATTERN(ClampFOp), |
| SAME_NAME_SIMPLE_PATTERN(CloneOp), |
| SAME_NAME_SIMPLE_PATTERN(CmpFOp), |
| SAME_NAME_SIMPLE_PATTERN(CmpIOp), |
| SAME_NAME_SIMPLE_PATTERN(CondAssignOp), |
| SAME_NAME_SIMPLE_PATTERN(ConvertSSOp), |
| SAME_NAME_SIMPLE_PATTERN(ConvertUUOp), |
| SAME_NAME_SIMPLE_PATTERN(ConvertSUOp), |
| SAME_NAME_SIMPLE_PATTERN(ConvertUSOp), |
| SAME_NAME_SIMPLE_PATTERN(CondBreakOp), |
| SAME_NAME_SIMPLE_PATTERN(CosFOp), |
| SAME_NAME_SIMPLE_PATTERN(DimOp), |
| SAME_NAME_SIMPLE_PATTERN(DivFOp), |
| SAME_NAME_SIMPLE_PATTERN(DivISOp), |
| SAME_NAME_SIMPLE_PATTERN(DivIUOp), |
| SAME_NAME_SIMPLE_PATTERN(ExpFOp), |
| SAME_NAME_SIMPLE_PATTERN(LogFOp), |
| SAME_NAME_SIMPLE_PATTERN(RsqrtFOp), |
| SAME_NAME_SIMPLE_PATTERN(FloorFOp), |
| SAME_NAME_SIMPLE_PATTERN(LengthOp), |
| SAME_NAME_SIMPLE_PATTERN(MatMulFOp), |
| SAME_NAME_SIMPLE_PATTERN(MatMulIOp), |
| SAME_NAME_SIMPLE_PATTERN(MaxFOp), |
| SAME_NAME_SIMPLE_PATTERN(MaxISOp), |
| SAME_NAME_SIMPLE_PATTERN(MaxIUOp), |
| SAME_NAME_SIMPLE_PATTERN(MinFOp), |
| SAME_NAME_SIMPLE_PATTERN(MinISOp), |
| SAME_NAME_SIMPLE_PATTERN(MinIUOp), |
| SAME_NAME_SIMPLE_PATTERN(MulAddFOp), |
| SAME_NAME_SIMPLE_PATTERN(MulAddIOp), |
| SAME_NAME_SIMPLE_PATTERN(MulFOp), |
| SAME_NAME_SIMPLE_PATTERN(MulIOp), |
| SAME_NAME_SIMPLE_PATTERN(NotOp), |
| SAME_NAME_SIMPLE_PATTERN(OrOp), |
| SAME_NAME_SIMPLE_PATTERN(PadOp), |
| SAME_NAME_SIMPLE_PATTERN(RankOp), |
| SAME_NAME_SIMPLE_PATTERN(ReduceSumIOp), |
| SAME_NAME_SIMPLE_PATTERN(ReduceSumFOp), |
| SAME_NAME_SIMPLE_PATTERN(ReduceMinIOp), |
| SAME_NAME_SIMPLE_PATTERN(ReduceMinFOp), |
| SAME_NAME_SIMPLE_PATTERN(ReduceMaxIOp), |
| SAME_NAME_SIMPLE_PATTERN(ReduceMaxFOp), |
| SAME_NAME_SIMPLE_PATTERN(ReshapeOp), |
| SAME_NAME_SIMPLE_PATTERN(ReturnOp), |
| SAME_NAME_SIMPLE_PATTERN(SelectOp), |
| SAME_NAME_SIMPLE_PATTERN(ShapeOp), |
| SAME_NAME_SIMPLE_PATTERN(ShiftLeftOp), |
| SAME_NAME_SIMPLE_PATTERN(ShiftRightArithmeticOp), |
| SAME_NAME_SIMPLE_PATTERN(ShiftRightLogicalOp), |
| SAME_NAME_SIMPLE_PATTERN(SinFOp), |
| SAME_NAME_SIMPLE_PATTERN(SplitOp), |
| SAME_NAME_SIMPLE_PATTERN(SubFOp), |
| SAME_NAME_SIMPLE_PATTERN(SubIOp), |
| SAME_NAME_SIMPLE_PATTERN(TanhFOp), |
| SAME_NAME_SIMPLE_PATTERN(TileOp), |
| SAME_NAME_SIMPLE_PATTERN(TraceOp), |
| SAME_NAME_SIMPLE_PATTERN(TransposeOp), |
| SAME_NAME_SIMPLE_PATTERN(ReverseOp), |
| SAME_NAME_SIMPLE_PATTERN(XorOp)>(&getContext()); |
| // clang-format on |
| #undef SAME_NAME_SIMPLE_PATTERN |
| |
| ConversionTarget target(getContext()); |
| target.addLegalDialect<IREELLInterpreterDialect>(); |
| target.addLegalOp<FuncOp, IREE::ReturnOp>(); |
| if (failed(applyFullConversion(getFunction(), target, patterns))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| std::unique_ptr<OpPassBase<FuncOp>> createLowerInterpreterDialectPass() { |
| return std::make_unique<LowerInterpreterDialectPass>(); |
| } |
| |
| static PassRegistration<LowerInterpreterDialectPass> pass( |
| "lower-iree-interpreter-hl-to-ll", "Lowers IREE HL ops to IREE LL ops"); |
| |
| } // namespace iree_compiler |
| } // namespace mlir |