| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h" |
| |
| #include "iree/base/api.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h" |
| #include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| class ModuleOpConversion : public OpConversionPattern<ModuleOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| ModuleOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // Do not attempt to convert the top level module. |
| // This mechanism can only support rewriting non top-level modules. |
| if (!srcOp->getParentOp() || !isa<ModuleOp>(srcOp->getParentOp())) { |
| return failure(); |
| } |
| |
| StringRef name = srcOp.getName() ? *srcOp.getName() : "module"; |
| auto newModuleOp = |
| rewriter.create<IREE::VM::ModuleOp>(srcOp.getLoc(), name); |
| assert(!newModuleOp.getBodyRegion().empty()); |
| Block *firstCreatedBlock = &newModuleOp.getBodyRegion().front(); |
| rewriter.inlineRegionBefore(srcOp.getBodyRegion(), firstCreatedBlock); |
| auto blockRange = llvm::make_range(Region::iterator(firstCreatedBlock), |
| newModuleOp.getBodyRegion().end()); |
| for (Block &block : llvm::make_early_inc_range(blockRange)) { |
| rewriter.eraseBlock(&block); |
| } |
| rewriter.replaceOp(srcOp, {}); |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPointToEnd(&newModuleOp.getBodyRegion().front()); |
| rewriter.create<IREE::VM::ModuleTerminatorOp>(srcOp.getLoc()); |
| return success(); |
| } |
| }; |
| |
| // Allowlist of function attributes to retain when converting to vm.func. |
| constexpr const char *kRetainedAttributes[] = { |
| "iree.reflection", |
| "sym_visibility", |
| "noinline", |
| }; |
| |
| class FuncOpConversion : public OpConversionPattern<FuncOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| FuncOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| FunctionType srcFuncType = srcOp.getType(); |
| TypeConverter::SignatureConversion signatureConversion( |
| srcOp.getNumArguments()); |
| |
| // Convert function arguments. |
| for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) { |
| if (failed(getTypeConverter()->convertSignatureArg( |
| i, srcFuncType.getInput(i), signatureConversion))) { |
| return rewriter.notifyMatchFailure(srcOp, "argument failed to convert"); |
| } |
| } |
| |
| // Convert function results. |
| SmallVector<Type, 1> convertedResultTypes; |
| if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(), |
| convertedResultTypes))) { |
| return rewriter.notifyMatchFailure(srcOp, "results failed to convert"); |
| } |
| |
| // Create new function with converted argument and result types. |
| // Note that attributes are dropped. Consider preserving some if needed. |
| auto newFuncType = mlir::FunctionType::get( |
| srcOp.getContext(), signatureConversion.getConvertedTypes(), |
| convertedResultTypes); |
| auto newFuncOp = rewriter.create<IREE::VM::FuncOp>( |
| srcOp.getLoc(), srcOp.getName(), newFuncType); |
| rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(), |
| newFuncOp.end()); |
| |
| // Retain function attributes in the allowlist. |
| auto retainedAttributes = ArrayRef<const char *>( |
| kRetainedAttributes, |
| sizeof(kRetainedAttributes) / sizeof(kRetainedAttributes[0])); |
| for (auto retainAttrName : retainedAttributes) { |
| StringRef attrName(retainAttrName); |
| Attribute attr = srcOp->getAttr(attrName); |
| if (attr) { |
| newFuncOp->setAttr(attrName, attr); |
| } |
| } |
| |
| // Tell the rewriter to convert the region signature. |
| TypeConverter &typeConverter = *getTypeConverter(); |
| if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, |
| &signatureConversion))) { |
| return failure(); |
| } |
| |
| // Also add an export for the "raw" form of this function, which operates |
| // on low level VM types and does no verification. A later pass will |
| // materialize high level API-friendly wrappers. |
| if (srcOp.isPublic()) { |
| StringRef exportName = newFuncOp.getName(); |
| rewriter.create<IREE::VM::ExportOp>(srcOp.getLoc(), newFuncOp, |
| exportName); |
| } |
| // VM functions are private by default and exported via the dedicated |
| // vm.export ops. |
| newFuncOp.setPrivate(); |
| |
| rewriter.replaceOp(srcOp, llvm::None); |
| return success(); |
| } |
| }; |
| |
| class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| mlir::ReturnOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, operands); |
| return success(); |
| } |
| }; |
| |
| struct ConstantOpConversion : public OpConversionPattern<ConstantOp> { |
| ConstantOpConversion(MLIRContext *context, TypeConverter &typeConverter) |
| : OpConversionPattern(context), typeConverter(typeConverter) {} |
| |
| TypeConverter &typeConverter; |
| |
| LogicalResult matchAndRewrite( |
| ConstantOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto targetType = typeConverter.convertType(srcOp.getType()); |
| if (!targetType) { |
| return srcOp.emitError() << "could not convert type: " << srcOp.getType() |
| << " (check -iree-vm-target-* options)"; |
| } |
| if (targetType.isa<IntegerType>()) { |
| auto integerAttr = srcOp.getValue().dyn_cast<IntegerAttr>(); |
| if (!integerAttr) { |
| return srcOp.emitRemark() << "unsupported const type for dialect"; |
| } |
| switch (targetType.getIntOrFloatBitWidth()) { |
| case 1: |
| case 32: |
| if (integerAttr.getInt()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>( |
| srcOp, |
| integerAttr.getType().isInteger(1) ? 1 : integerAttr.getInt()); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp); |
| } |
| break; |
| case 64: |
| if (integerAttr.getInt()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI64Op>( |
| srcOp, integerAttr.getInt()); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI64ZeroOp>(srcOp); |
| } |
| break; |
| default: |
| return srcOp.emitRemark() |
| << "unsupported const integer bit width for dialect"; |
| } |
| } else if (targetType.isa<FloatType>()) { |
| auto floatAttr = srcOp.getValue().dyn_cast<FloatAttr>(); |
| if (!floatAttr) { |
| return srcOp.emitRemark() << "unsupported const type for dialect"; |
| } |
| switch (targetType.getIntOrFloatBitWidth()) { |
| case 32: |
| if (floatAttr.getValue().isZero()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF32ZeroOp>(srcOp); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF32Op>(srcOp, floatAttr); |
| } |
| break; |
| case 64: |
| if (floatAttr.getValue().isZero()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF64ZeroOp>(srcOp); |
| } else { |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstF64Op>(srcOp, floatAttr); |
| } |
| break; |
| default: |
| return srcOp.emitRemark() |
| << "unsupported const floating-point bit width for dialect"; |
| } |
| } else { |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| return success(); |
| } |
| }; |
| |
| class CmpIOpConversion : public OpConversionPattern<CmpIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CmpIOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| CmpIOp::Adaptor srcAdaptor(operands); |
| auto returnType = rewriter.getIntegerType(32); |
| switch (srcOp.getPredicate()) { |
| case CmpIPredicate::eq: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpEQI32Op>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ne: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpNEI32Op>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::slt: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::sle: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::sgt: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::sge: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32SOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ult: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ule: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::ugt: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| case CmpIPredicate::uge: |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| return success(); |
| default: |
| return failure(); |
| } |
| } |
| }; |
| |
| class CmpFOpConversion : public OpConversionPattern<CmpFOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CmpFOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| CmpFOp::Adaptor srcAdaptor(operands); |
| auto returnType = rewriter.getIntegerType(32); |
| switch (srcOp.getPredicate()) { |
| case CmpFPredicate::AlwaysFalse: // 0 |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp); |
| break; |
| case CmpFPredicate::AlwaysTrue: // 1 |
| rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>(srcOp, 1); |
| break; |
| case CmpFPredicate::UNO: // isnan(lhs) || isnan(rhs) |
| rewriter.replaceOpWithNewOp<IREE::VM::OrI32Op>( |
| srcOp, returnType, |
| rewriter.createOrFold<IREE::VM::CmpNaNF32Op>( |
| srcOp.getLoc(), returnType, srcAdaptor.lhs()), |
| rewriter.createOrFold<IREE::VM::CmpNaNF32Op>( |
| srcOp.getLoc(), returnType, srcAdaptor.rhs())); |
| break; |
| case CmpFPredicate::ORD: // !(isnan(lhs) || isnan(rhs)) |
| rewriter.replaceOpWithNewOp<IREE::VM::XorI32Op>( |
| srcOp, returnType, |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1), |
| rewriter.createOrFold<IREE::VM::AndI32Op>( |
| srcOp.getLoc(), returnType, |
| rewriter.createOrFold<IREE::VM::CmpNaNF32Op>( |
| srcOp.getLoc(), returnType, srcAdaptor.lhs()), |
| rewriter.createOrFold<IREE::VM::CmpNaNF32Op>( |
| srcOp.getLoc(), returnType, srcAdaptor.rhs()))); |
| break; |
| case CmpFPredicate::OEQ: // ordered and equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpEQF32OOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::OGT: // ordered and greater than |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTF32OOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::OGE: // ordered and greater than or equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEF32OOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::OLT: // ordered and less than |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTF32OOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::OLE: // ordered and less than or equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEF32OOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::ONE: // ordered and not equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpNEF32OOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::UEQ: // unordered or equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpEQF32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::UGT: // unordered or greater than |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTF32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::UGE: // unordered or greater than or equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEF32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::ULT: // unordered or less than |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTF32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::ULE: // unordered or less than or equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEF32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| case CmpFPredicate::UNE: // unordered or not equal |
| rewriter.replaceOpWithNewOp<IREE::VM::CmpNEF32UOp>( |
| srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs()); |
| break; |
| default: |
| return rewriter.notifyMatchFailure(srcOp, "unhandled CmpFPredicate"); |
| } |
| return success(); |
| } |
| }; |
| |
| template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy> |
| class UnaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SrcOpTy srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| typename SrcOpTy::Adaptor srcAdaptor(operands); |
| switch (srcAdaptor.operand().getType().getIntOrFloatBitWidth()) { |
| case 32: |
| rewriter.replaceOpWithNewOp<Dst32OpTy>( |
| srcOp, srcAdaptor.operand().getType(), srcAdaptor.operand()); |
| break; |
| case 64: |
| rewriter.replaceOpWithNewOp<Dst64OpTy>( |
| srcOp, srcAdaptor.operand().getType(), srcAdaptor.operand()); |
| break; |
| default: |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| return success(); |
| } |
| }; |
| |
| template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy> |
| class BinaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SrcOpTy srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| typename SrcOpTy::Adaptor srcAdaptor(operands); |
| switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) { |
| case 32: |
| rewriter.replaceOpWithNewOp<Dst32OpTy>( |
| srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(), |
| srcAdaptor.rhs()); |
| break; |
| case 64: |
| rewriter.replaceOpWithNewOp<Dst64OpTy>( |
| srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(), |
| srcAdaptor.rhs()); |
| break; |
| default: |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| return success(); |
| } |
| }; |
| |
| template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy> |
| class ShiftArithmeticOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SrcOpTy srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| typename SrcOpTy::Adaptor srcAdaptor(operands); |
| Value amount = srcAdaptor.rhs(); |
| if (amount.getType().getIntOrFloatBitWidth() > 32) { |
| // Shift amounts are always 32-bit in the VM. |
| amount = rewriter.createOrFold<TruncateIOp>( |
| srcOp.getLoc(), rewriter.getI32Type(), amount); |
| } |
| switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) { |
| case 32: |
| rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, srcOp.getType(), |
| srcAdaptor.lhs(), amount); |
| break; |
| case 64: |
| rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, srcOp.getType(), |
| srcAdaptor.lhs(), amount); |
| break; |
| default: |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| return success(); |
| } |
| }; |
| |
| template <typename StdOp> |
| class CastingOpConversion : public OpConversionPattern<StdOp> { |
| using OpConversionPattern<StdOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| StdOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp(srcOp, operands); |
| return success(); |
| } |
| }; |
| |
| class IndexCastOpConversion : public OpConversionPattern<IndexCastOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| IndexCastOp srcOp, ArrayRef<Value> rawOperands, |
| ConversionPatternRewriter &rewriter) const override { |
| IndexCastOpAdaptor operands(rawOperands); |
| auto srcType = operands.in().getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType == dstType) { |
| rewriter.replaceOp(srcOp, rawOperands); |
| } else if (srcType.getIntOrFloatBitWidth() < |
| dstType.getIntOrFloatBitWidth()) { |
| rewriter.replaceOpWithNewOp<ZeroExtendIOp>(srcOp, dstType, operands.in()); |
| } else { |
| rewriter.replaceOpWithNewOp<TruncateIOp>(srcOp, dstType, operands.in()); |
| } |
| return success(); |
| } |
| }; |
| |
| class ZeroExtendIOpConversion : public OpConversionPattern<ZeroExtendIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| ZeroExtendIOp srcOp, ArrayRef<Value> rawOperands, |
| ConversionPatternRewriter &rewriter) const override { |
| ZeroExtendIOpAdaptor operands(rawOperands); |
| auto srcType = srcOp.value().getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isInteger(1) && dstType.isInteger(32)) { |
| // This may not be needed but ensures that the input was treated as a |
| // single bit. |
| // NOTE: this may not be required - if we know that the i1 is never able |
| // to have more than bit 0 manipulated then this is wasted work. |
| rewriter.replaceOpWithNewOp<IREE::VM::AndI32Op>( |
| srcOp, dstType, operands.value(), |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1)); |
| } else if (srcType.isInteger(8) && dstType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ExtI8I32UOp>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(16) && dstType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ExtI16I32UOp>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(32) && dstType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ExtI32I64UOp>(srcOp, dstType, |
| operands.value()); |
| } else { |
| // TODO(benvanik): we should be building a sequence of extensions for |
| // things like i8 -> i64. |
| return rewriter.notifyMatchFailure(srcOp, "unsupported zero extension"); |
| } |
| return success(); |
| } |
| }; |
| |
| class SignExtendIOpConversion : public OpConversionPattern<SignExtendIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SignExtendIOp srcOp, ArrayRef<Value> rawOperands, |
| ConversionPatternRewriter &rewriter) const override { |
| SignExtendIOpAdaptor operands(rawOperands); |
| auto srcType = srcOp.value().getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isInteger(8) && dstType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ExtI8I32SOp>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(16) && dstType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ExtI16I32SOp>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(32) && dstType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::ExtI32I64SOp>(srcOp, dstType, |
| operands.value()); |
| } else { |
| // TODO(benvanik): we should be building a sequence of extensions for |
| // things like i8 -> i64. |
| return rewriter.notifyMatchFailure(srcOp, "unsupported sign extension"); |
| } |
| return success(); |
| } |
| }; |
| |
| class TruncateIOpConversion : public OpConversionPattern<TruncateIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| TruncateIOp srcOp, ArrayRef<Value> rawOperands, |
| ConversionPatternRewriter &rewriter) const override { |
| TruncateIOpAdaptor operands(rawOperands); |
| auto srcType = srcOp.value().getType(); |
| auto resultType = srcOp.getResult().getType(); |
| auto dstType = getTypeConverter()->convertType(resultType); |
| if (resultType.isInteger(1)) { |
| // i1 is represented as i32, so just mask off the bit and truncate as |
| // normal. Note that if we started as i64 we need to first get that into |
| // an i32 that we can work with. |
| auto value = operands.value(); |
| if (srcType.isInteger(64)) { |
| value = rewriter.createOrFold<IREE::VM::TruncI64I32Op>(srcOp.getLoc(), |
| dstType, value); |
| } |
| rewriter.replaceOpWithNewOp<IREE::VM::AndI32Op>( |
| srcOp, dstType, value, |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1)); |
| } else if (srcType.isInteger(32) && resultType.isInteger(8)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::TruncI32I8Op>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(32) && resultType.isInteger(16)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::TruncI32I16Op>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(64) && resultType.isInteger(8)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I8Op>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(64) && resultType.isInteger(16)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I16Op>(srcOp, dstType, |
| operands.value()); |
| } else if (srcType.isInteger(64) && resultType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I32Op>(srcOp, dstType, |
| operands.value()); |
| } else { |
| return rewriter.notifyMatchFailure(srcOp, "unsupported truncation"); |
| } |
| return success(); |
| } |
| }; |
| |
| class SIToFPOpConversion : public OpConversionPattern<SIToFPOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| SIToFPOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| SIToFPOpAdaptor srcAdaptor(operands); |
| auto srcType = operands[0].getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isSignlessInteger(32) || srcType.isSignedInteger(32)) { |
| if (dstType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, dstType, |
| operands[0]); |
| return success(); |
| } |
| } |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| }; |
| |
| class UIToFPOpConversion : public OpConversionPattern<UIToFPOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| UIToFPOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| UIToFPOpAdaptor srcAdaptor(operands); |
| auto srcType = operands[0].getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isUnsignedInteger(32)) { |
| if (dstType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, dstType, |
| operands[0]); |
| return success(); |
| } |
| } |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| }; |
| |
| class FPToSIOpConversion : public OpConversionPattern<FPToSIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| FPToSIOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| FPToSIOpAdaptor srcAdaptor(operands); |
| auto srcType = operands[0].getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isF32()) { |
| if (dstType.isSignlessInteger(32) || dstType.isSignedInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::CastF32SI32Op>(srcOp, dstType, |
| operands[0]); |
| return success(); |
| } |
| } |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| }; |
| |
| class FPToUIOpConversion : public OpConversionPattern<FPToUIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| FPToUIOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| FPToUIOpAdaptor srcAdaptor(operands); |
| auto srcType = operands[0].getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isF32()) { |
| if (srcType.isUnsignedInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::CastF32UI32Op>(srcOp, dstType, |
| operands[0]); |
| return success(); |
| } |
| } |
| return rewriter.notifyMatchFailure(srcOp, "unsupported type"); |
| } |
| }; |
| |
| class BitcastOpConversion : public OpConversionPattern<BitcastOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| BitcastOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto srcType = operands[0].getType(); |
| auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType()); |
| if (srcType.isF32() && dstType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::BitcastF32I32Op>(srcOp, dstType, |
| operands[0]); |
| } else if (srcType.isInteger(32) && dstType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::BitcastI32F32Op>(srcOp, dstType, |
| operands[0]); |
| } else if (srcType.isF64() && dstType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::BitcastF64I64Op>(srcOp, dstType, |
| operands[0]); |
| } else if (srcType.isInteger(64) && dstType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::BitcastI64F64Op>(srcOp, dstType, |
| operands[0]); |
| } else { |
| return rewriter.notifyMatchFailure(srcOp, "unsupported bitcast"); |
| } |
| return success(); |
| } |
| }; |
| |
| class SelectOpConversion : public OpConversionPattern<SelectOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| SelectOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| SelectOp::Adaptor srcAdaptor(operands); |
| auto valueType = srcAdaptor.true_value().getType(); |
| if (valueType.isInteger(32)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isInteger(64)) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectI64Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isF32()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectF32Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isF64()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectF64Op>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else if (valueType.isa<IREE::VM::RefType>()) { |
| rewriter.replaceOpWithNewOp<IREE::VM::SelectRefOp>( |
| srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(), |
| srcAdaptor.false_value()); |
| return success(); |
| } else { |
| return rewriter.notifyMatchFailure(srcOp, |
| "unsupported select element type"); |
| } |
| } |
| }; |
| |
| class AssertOpConversion : public OpConversionPattern<AssertOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| AssertOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| AssertOpAdaptor adaptor(operands); |
| Location loc = srcOp.getLoc(); |
| |
| // Start by splitting the block containing the assert into two. The part |
| // before will contain the condition, and the part after will contain |
| // the continuation point. |
| Block *condBlock = rewriter.getInsertionBlock(); |
| Block::iterator opPosition = rewriter.getInsertionPoint(); |
| Block *continuationBlock = rewriter.splitBlock(condBlock, opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = rewriter.createBlock(parentRegion, parentRegion->end()); |
| auto status = rewriter.create<IREE::VM::ConstI32Op>( |
| loc, rewriter.getIntegerAttr(rewriter.getIntegerType(32), |
| IREE_STATUS_FAILED_PRECONDITION)); |
| rewriter.create<IREE::VM::FailOp>(loc, status, srcOp.msgAttr()); |
| } |
| |
| rewriter.setInsertionPointToEnd(condBlock); |
| rewriter.replaceOpWithNewOp<CondBranchOp>(srcOp, adaptor.arg(), |
| continuationBlock, failureBlock); |
| return success(); |
| } |
| }; |
| |
| class BranchOpConversion : public OpConversionPattern<BranchOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| BranchOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(srcOp, srcOp.getDest(), |
| operands); |
| return success(); |
| } |
| }; |
| |
| class CondBranchOpConversion : public OpConversionPattern<CondBranchOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CondBranchOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| Block *trueDest = srcOp.getTrueDest(); |
| rewriter.replaceOpWithNewOp<IREE::VM::CondBranchOp>( |
| srcOp, operands[0], trueDest, |
| operands.slice(1, trueDest->getNumArguments()), srcOp.getFalseDest(), |
| operands.slice(1 + trueDest->getNumArguments())); |
| return success(); |
| } |
| }; |
| |
| class CallOpConversion : public OpConversionPattern<CallOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| CallOp srcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| CallOp::Adaptor srcAdaptor(operands); |
| // Convert function result types. The conversion framework will ensure |
| // that the callee has been equivalently converted. |
| SmallVector<Type, 4> resultTypes; |
| for (auto resultType : srcOp.getResultTypes()) { |
| resultType = getTypeConverter()->convertType(resultType); |
| if (!resultType) { |
| return failure(); |
| } |
| resultTypes.push_back(resultType); |
| } |
| rewriter.replaceOpWithNewOp<IREE::VM::CallOp>( |
| srcOp, srcOp.getCallee(), resultTypes, srcAdaptor.operands()); |
| |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void populateStandardToVMPatterns(MLIRContext *context, |
| TypeConverter &typeConverter, |
| OwningRewritePatternList &patterns) { |
| patterns.insert<AssertOpConversion, BranchOpConversion, CallOpConversion, |
| CmpIOpConversion, CmpFOpConversion, CondBranchOpConversion, |
| ModuleOpConversion, FuncOpConversion, ReturnOpConversion, |
| SelectOpConversion>(typeConverter, context); |
| |
| // TODO(#2878): figure out how to pass the type converter in a supported way. |
| // Right now if we pass the type converter as the first argument - triggering |
| // the ConversionPattern stuff - it'll do weird things. |
| patterns.insert<ConstantOpConversion>(context, typeConverter); |
| |
| patterns.insert<CastingOpConversion<UnrealizedConversionCastOp>, |
| IndexCastOpConversion, ZeroExtendIOpConversion, |
| SignExtendIOpConversion, TruncateIOpConversion>(typeConverter, |
| context); |
| |
| // Integer arithmetic ops. |
| patterns.insert< |
| BinaryArithmeticOpConversion<AddIOp, IREE::VM::AddI32Op, |
| IREE::VM::AddI64Op>, |
| BinaryArithmeticOpConversion<SignedDivIOp, IREE::VM::DivI32SOp, |
| IREE::VM::DivI64SOp>, |
| BinaryArithmeticOpConversion<UnsignedDivIOp, IREE::VM::DivI32UOp, |
| IREE::VM::DivI64UOp>, |
| BinaryArithmeticOpConversion<MulIOp, IREE::VM::MulI32Op, |
| IREE::VM::MulI64Op>, |
| BinaryArithmeticOpConversion<SignedRemIOp, IREE::VM::RemI32SOp, |
| IREE::VM::RemI64SOp>, |
| BinaryArithmeticOpConversion<UnsignedRemIOp, IREE::VM::RemI32UOp, |
| IREE::VM::RemI64UOp>, |
| BinaryArithmeticOpConversion<SubIOp, IREE::VM::SubI32Op, |
| IREE::VM::SubI64Op>, |
| BinaryArithmeticOpConversion<AndOp, IREE::VM::AndI32Op, |
| IREE::VM::AndI64Op>, |
| BinaryArithmeticOpConversion<OrOp, IREE::VM::OrI32Op, IREE::VM::OrI64Op>, |
| BinaryArithmeticOpConversion<XOrOp, IREE::VM::XorI32Op, |
| IREE::VM::XorI64Op>>(typeConverter, context); |
| |
| // Floating-point arithmetic ops. |
| patterns.insert<UnaryArithmeticOpConversion<AbsFOp, IREE::VM::AbsF32Op, |
| IREE::VM::AbsF64Op>, |
| BinaryArithmeticOpConversion<AddFOp, IREE::VM::AddF32Op, |
| IREE::VM::AddF64Op>, |
| UnaryArithmeticOpConversion<CeilFOp, IREE::VM::CeilF32Op, |
| IREE::VM::CeilF64Op>, |
| UnaryArithmeticOpConversion<FloorFOp, IREE::VM::FloorF32Op, |
| IREE::VM::FloorF64Op>, |
| BinaryArithmeticOpConversion<DivFOp, IREE::VM::DivF32Op, |
| IREE::VM::DivF64Op>, |
| BinaryArithmeticOpConversion<MulFOp, IREE::VM::MulF32Op, |
| IREE::VM::MulF64Op>, |
| UnaryArithmeticOpConversion<NegFOp, IREE::VM::NegF32Op, |
| IREE::VM::NegF64Op>, |
| BinaryArithmeticOpConversion<RemFOp, IREE::VM::RemF32Op, |
| IREE::VM::RemF64Op>, |
| BinaryArithmeticOpConversion<SubFOp, IREE::VM::SubF32Op, |
| IREE::VM::SubF64Op>>( |
| typeConverter, context); |
| |
| // Floating-point conversion ops. |
| patterns.insert<SIToFPOpConversion, UIToFPOpConversion, FPToSIOpConversion, |
| FPToUIOpConversion, BitcastOpConversion>(typeConverter, |
| context); |
| |
| // Shift ops. |
| patterns.insert< |
| ShiftArithmeticOpConversion<ShiftLeftOp, IREE::VM::ShlI32Op, |
| IREE::VM::ShlI64Op>, |
| ShiftArithmeticOpConversion<SignedShiftRightOp, IREE::VM::ShrI32SOp, |
| IREE::VM::ShrI64SOp>, |
| ShiftArithmeticOpConversion<UnsignedShiftRightOp, IREE::VM::ShrI32UOp, |
| IREE::VM::ShrI64UOp>>(typeConverter, context); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |