| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h" |
| |
| #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" |
| #include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h" |
| #include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| class ModuleOpConversion : public OpConversionPattern<ModuleOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| ModuleOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // Do not attempt to convert the top level module. |
| // This mechanism can only support rewriting non top-level modules. |
| if (!srcOp->getParentOp() || !isa<ModuleOp>(srcOp->getParentOp())) { |
| return failure(); |
| } |
| |
| StringRef name = srcOp.getName() ? *srcOp.getName() : "module"; |
| auto newModuleOp = |
| rewriter.create<IREE::VM::ModuleOp>(srcOp.getLoc(), name); |
| assert(!newModuleOp.getBodyRegion().empty()); |
| Block *firstCreatedBlock = &newModuleOp.getBodyRegion().front(); |
| rewriter.inlineRegionBefore(srcOp.getBodyRegion(), firstCreatedBlock); |
| auto blockRange = llvm::make_range(Region::iterator(firstCreatedBlock), |
| newModuleOp.getBodyRegion().end()); |
| for (Block &block : llvm::make_early_inc_range(blockRange)) { |
| rewriter.eraseBlock(&block); |
| } |
| rewriter.replaceOp(srcOp, {}); |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPointToEnd(&newModuleOp.getBodyRegion().front()); |
| rewriter.create<IREE::VM::ModuleTerminatorOp>(srcOp.getLoc()); |
| return success(); |
| } |
| }; |
| |
// Names of the function attributes that survive the conversion to vm.func.
// Any attribute whose name is not listed here is not copied to the new op.
constexpr const char *kRetainedAttributes[] = {
    "iree.reflection",
    "sym_visibility",
    "noinline",
};
| |
| class FuncOpConversion : public OpConversionPattern<FuncOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| FuncOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| FunctionType srcFuncType = srcOp.getType(); |
| TypeConverter::SignatureConversion signatureConversion( |
| srcOp.getNumArguments()); |
| |
| // Convert function arguments. |
| for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) { |
| if (failed(getTypeConverter()->convertSignatureArg( |
| i, srcFuncType.getInput(i), signatureConversion))) { |
| return rewriter.notifyMatchFailure(srcOp, "argument failed to convert"); |
| } |
| } |
| |
| // Convert function results. |
| SmallVector<Type, 1> convertedResultTypes; |
| if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(), |
| convertedResultTypes))) { |
| return rewriter.notifyMatchFailure(srcOp, "results failed to convert"); |
| } |
| |
| // Create new function with converted argument and result types. |
| // Note that attributes are dropped. Consider preserving some if needed. |
| auto newFuncType = mlir::FunctionType::get( |
| srcOp.getContext(), signatureConversion.getConvertedTypes(), |
| convertedResultTypes); |
| auto newFuncOp = rewriter.create<IREE::VM::FuncOp>( |
| srcOp.getLoc(), srcOp.getName(), newFuncType); |
| rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(), |
| newFuncOp.end()); |
| |
| // Retain function attributes in the allowlist. |
| auto retainedAttributes = ArrayRef<const char *>( |
| kRetainedAttributes, |
| sizeof(kRetainedAttributes) / sizeof(kRetainedAttributes[0])); |
| for (auto retainAttrName : retainedAttributes) { |
| StringRef attrName(retainAttrName); |
| Attribute attr = srcOp->getAttr(attrName); |
| if (attr) { |
| newFuncOp->setAttr(attrName, attr); |
| } |
| } |
| |
| // Tell the rewriter to convert the region signature. |
| TypeConverter &typeConverter = *getTypeConverter(); |
| if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, |
| &signatureConversion))) { |
| return failure(); |
| } |
| |
| // Also add an export for the "raw" form of this function, which operates |
| // on low level VM types and does no verification. A later pass will |
| // materialize high level API-friendly wrappers. |
| if (auto exportAttr = srcOp->getAttr("iree.module.export")) { |
| StringRef exportName = newFuncOp.getName(); |
| if (auto exportStrAttr = exportAttr.dyn_cast<StringAttr>()) { |
| exportName = exportStrAttr.getValue(); |
| } else { |
| assert(exportAttr.isa<UnitAttr>()); |
| } |
| |
| rewriter.create<IREE::VM::ExportOp>(srcOp.getLoc(), newFuncOp, |
| exportName); |
| } |
| |
| rewriter.replaceOp(srcOp, llvm::None); |
| return success(); |
| } |
| }; |
| |
| class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mlir::ReturnOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, operands); |
| return success(); |
| } |
| }; |
| |
| class ConstantOpConversion : public OpConversionPattern<ConstantOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| ConstantOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // TODO(#2878): use getTypeConverter() when we pass it upon creation. |
| IREE::VM::TypeConverter typeConverter( |
| IREE::VM::getTargetOptionsFromFlags()); |
| auto targetType = typeConverter.convertType(srcOp.getType()); |
| if (targetType.isa<IntegerType>()) { |
| auto integerAttr = srcOp.getValue().dyn_cast<IntegerAttr>(); |
| if (!integerAttr) { |
| return srcOp.emitRemark() << "unsupported const type for dialect"; |
| } |
| switch (targetType.getIntOrFloatBitWidth()) { |
| case 1: |
| case 32: |
| if (integerAttr.getInt()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>( |
| srcOp, integerAttr.getInt()); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp); |
| } |
| break; |
| case 64: |
| if (integerAttr.getInt()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI64Op>( |
| srcOp, integerAttr.getInt()); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI64ZeroOp>(srcOp); |
| } |
| break; |
| default: |
| return srcOp.emitRemark() |
| << "unsupported const integer bit width for dialect"; |
| } |
| } else if (targetType.isa<FloatType>()) { |
| auto floatAttr = srcOp.getValue().dyn_cast<FloatAttr>(); |
| if (!floatAttr) { |
| return srcOp.emitRemark() << "unsupported const type for dialect"; |
| } |
| switch (targetType.getIntOrFloatBitWidth()) { |
| case 32: |
| if (floatAttr.getValue().isZero()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF32ZeroOp>(srcOp); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF32Op>(srcOp, floatAttr); |
| } |
| break; |
| case 64: |
| if (floatAttr.getValue().isZero()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF64ZeroOp>(srcOp); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF64Op>(srcOp, floatAttr); |
| } |
| break; |
| default: |
| return srcOp.emitRemark() |
| << "unsupported const floating-point bit width for dialect"; |
| } |
| } |
| return success(); |
| } |
| }; |
| |
| class CmpIOpConversion : public OpConversionPattern<CmpIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CmpIOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| CmpIOp::Adaptor srcAdaptor(operands); |
| auto returnType = rewriter.getIntegerType(32); |
| switch (srcOp.getPredicate()) { |
| case CmpIPredicate::eq: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpEQI32Op>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ne: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpNEI32Op>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::slt: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::sle: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::sgt: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::sge: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ult: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ule: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ugt: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::uge: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| default: |
| return failure(); |
| } |
| } |
| }; |
| |
| template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy> |
| class BinaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SrcOpTy srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| typename SrcOpTy::Adaptor srcAdaptor(operands); |
| switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) { |
| case 32: |
| rewriter.replaceOpWithNewOp<Dst32OpTy>( |
| srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(), |
| srcAdaptor.rhs()); |
| break; |
| case 64: |
| rewriter.replaceOpWithNewOp<Dst64OpTy>( |
| srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(), |
| srcAdaptor.rhs()); |
| break; |
| default: |
| llvm_unreachable("invalid target type"); |
| } |
| return success(); |
| } |
| }; |
| |
| template <typename SrcOpTy, typename DstOpTy, unsigned kBits = 32> |
| class ShiftArithmeticOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SrcOpTy srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| typename SrcOpTy::Adaptor srcAdaptor(operands); |
| auto type = srcOp.getType(); |
| if (!type.isSignlessInteger() || type.getIntOrFloatBitWidth() != kBits) { |
| return failure(); |
| } |
| APInt amount; |
| if (!matchPattern(srcAdaptor.rhs(), m_ConstantInt(&amount))) { |
| return failure(); |
| } |
| uint64_t amountRaw = amount.getZExtValue(); |
| if (amountRaw > kBits) return failure(); |
| IntegerAttr amountAttr = |
| IntegerAttr::get(IntegerType::get(srcOp.getContext(), 8), amountRaw); |
| rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), |
| srcAdaptor.lhs(), amountAttr); |
| return success(); |
| } |
| }; |
| |
| template <typename StdOp> |
| class CastingOpConversion : public OpConversionPattern<StdOp> { |
| using OpConversionPattern<StdOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| StdOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp(srcOp, operands); |
| return success(); |
| } |
| }; |
| |
| class SelectOpConversion : public OpConversionPattern<SelectOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| SelectOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| SelectOp::Adaptor srcAdaptor(operands); |
| auto valueType = srcAdaptor.true_value().getType(); |
| if (valueType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectI64Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectF32Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectF64Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else { |
| return rewriter.notifyMatchFailure(srcOp, |
| "unsupported select element type"); |
| } |
| } |
| }; |
| |
| class BranchOpConversion : public OpConversionPattern<BranchOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| BranchOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(srcOp, srcOp.getDest(), |
| operands); |
| return success(); |
| } |
| }; |
| |
| class CondBranchOpConversion : public OpConversionPattern<CondBranchOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CondBranchOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| Block *trueDest = srcOp.getTrueDest(); |
| rewriter.replaceOpWithNewOp<IREE::VM::CondBranchOp>( |
| srcOp, operands[0], trueDest, |
| operands.slice(1, trueDest->getNumArguments()), srcOp.getFalseDest(), |
| operands.slice(1 + trueDest->getNumArguments())); |
| return success(); |
| } |
| }; |
| |
| class CallOpConversion : public OpConversionPattern<CallOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CallOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| CallOp::Adaptor srcAdaptor(operands); |
| // Convert function result types. The conversion framework will ensure |
| // that the callee has been equivalently converted. |
| SmallVector<Type, 4> resultTypes; |
| for (auto resultType : srcOp.getResultTypes()) { |
| resultType = getTypeConverter()->convertType(resultType); |
| if (!resultType) { |
| return failure(); |
| } |
| resultTypes.push_back(resultType); |
| } |
| rewriter.replaceOpWithNewOp<IREE::VM::CallOp>( |
| srcOp, srcOp.getCallee(), resultTypes, srcAdaptor.operands()); |
| |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void populateStandardToVMPatterns(MLIRContext *context, |
| TypeConverter &typeConverter, |
| OwningRewritePatternList &patterns) { |
| patterns.insert<BranchOpConversion, CallOpConversion, CmpIOpConversion, |
| CondBranchOpConversion, ModuleOpConversion, FuncOpConversion, |
| ReturnOpConversion, CastingOpConversion<IndexCastOp>, |
| CastingOpConversion<TruncateIOp>, SelectOpConversion>( |
| typeConverter, context); |
| // TODO(#2878): pass typeConverter here. |
| patterns.insert<ConstantOpConversion>(context); |
| |
| // Integer arithmetic ops. |
| patterns.insert< |
| BinaryArithmeticOpConversion<AddIOp, IREE::VM::AddI32Op, |
| IREE::VM::AddI64Op>, |
| BinaryArithmeticOpConversion<SignedDivIOp, IREE::VM::DivI32SOp, |
| IREE::VM::DivI64SOp>, |
| BinaryArithmeticOpConversion<UnsignedDivIOp, IREE::VM::DivI32UOp, |
| IREE::VM::DivI64UOp>, |
| BinaryArithmeticOpConversion<MulIOp, IREE::VM::MulI32Op, |
| IREE::VM::MulI64Op>, |
| BinaryArithmeticOpConversion<SignedRemIOp, IREE::VM::RemI32SOp, |
| IREE::VM::RemI64SOp>, |
| BinaryArithmeticOpConversion<UnsignedRemIOp, IREE::VM::RemI32UOp, |
| IREE::VM::RemI64UOp>, |
| BinaryArithmeticOpConversion<SubIOp, IREE::VM::SubI32Op, |
| IREE::VM::SubI64Op>, |
| BinaryArithmeticOpConversion<AndOp, IREE::VM::AndI32Op, |
| IREE::VM::AndI64Op>, |
| BinaryArithmeticOpConversion<OrOp, IREE::VM::OrI32Op, IREE::VM::OrI64Op>, |
| BinaryArithmeticOpConversion<XOrOp, IREE::VM::XorI32Op, |
| IREE::VM::XorI64Op>>(typeConverter, context); |
| |
| // Floating-point arithmetic ops. |
| patterns.insert<BinaryArithmeticOpConversion<AddFOp, IREE::VM::AddF32Op, |
| IREE::VM::AddF64Op>, |
| BinaryArithmeticOpConversion<DivFOp, IREE::VM::DivF32Op, |
| IREE::VM::DivF64Op>, |
| BinaryArithmeticOpConversion<MulFOp, IREE::VM::MulF32Op, |
| IREE::VM::MulF64Op>, |
| BinaryArithmeticOpConversion<RemFOp, IREE::VM::RemF32Op, |
| IREE::VM::RemF64Op>, |
| BinaryArithmeticOpConversion<SubFOp, IREE::VM::SubF32Op, |
| IREE::VM::SubF64Op>>( |
| typeConverter, context); |
| |
| // Shift ops |
| // TODO(laurenzo): The standard dialect is missing shr ops. Add once in place. |
| patterns.insert<ShiftArithmeticOpConversion<ShiftLeftOp, IREE::VM::ShlI32Op>>( |
| typeConverter, context); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |