blob: e5bb31f18ea55b8f93e0cd6cb0bfbb81bc6322ac [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Converts a nested (non top-level) builtin module into a vm.module by moving
// the source module's body into the newly created vm.module region.
class ModuleOpConversion : public OpConversionPattern<ModuleOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ModuleOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only nested modules can be rewritten by this mechanism; the top-level
    // module is left alone.
    if (VMConversionTarget::isTopLevelModule(srcOp)) return failure();

    Location loc = srcOp.getLoc();
    auto srcName = srcOp.getName();
    StringRef moduleName = srcName ? *srcName : StringRef("module");

    auto vmModuleOp = rewriter.create<IREE::VM::ModuleOp>(loc, moduleName);
    Region &vmBody = vmModuleOp.getBodyRegion();
    assert(!vmBody.empty());

    // Move the source blocks in front of the block(s) the vm.module builder
    // created, then erase those builder-created blocks so that only the
    // inlined body remains.
    Block *builderBlock = &vmBody.front();
    rewriter.inlineRegionBefore(srcOp.getBodyRegion(), builderBlock);
    auto staleBlocks =
        llvm::make_range(Region::iterator(builderBlock), vmBody.end());
    for (Block &staleBlock : llvm::make_early_inc_range(staleBlocks)) {
      rewriter.eraseBlock(&staleBlock);
    }
    rewriter.replaceOp(srcOp, {});

    // Terminate the moved body with a vm.module_terminator.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(&vmBody.front());
    rewriter.create<IREE::VM::ModuleTerminatorOp>(loc);
    return success();
  }
};
// Names of the function attributes that FuncOpConversion copies from the
// source func onto the generated vm.func. Attributes not listed here are not
// copied.
constexpr const char *const kRetainedAttributes[] = {
    "iree.reflection",
    "sym_visibility",
    "noinline",
};
// Converts a func into a vm.func with converted argument and result types.
// Public functions additionally get a vm.export of the same name; the vm.func
// itself is always marked private.
class FuncOpConversion : public OpConversionPattern<FuncOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      FuncOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    FunctionType srcFuncType = srcOp.getType();
    TypeConverter::SignatureConversion signatureConversion(
        srcOp.getNumArguments());

    // Convert function arguments.
    for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) {
      if (failed(getTypeConverter()->convertSignatureArg(
              i, srcFuncType.getInput(i), signatureConversion))) {
        return rewriter.notifyMatchFailure(srcOp, "argument failed to convert");
      }
    }

    // Convert function results.
    SmallVector<Type, 1> convertedResultTypes;
    if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(),
                                                convertedResultTypes))) {
      return rewriter.notifyMatchFailure(srcOp, "results failed to convert");
    }

    // Create the new function with converted argument and result types and
    // move the source body into it.
    auto newFuncType = mlir::FunctionType::get(
        srcOp.getContext(), signatureConversion.getConvertedTypes(),
        convertedResultTypes);
    auto newFuncOp = rewriter.create<IREE::VM::FuncOp>(
        srcOp.getLoc(), srcOp.getName(), newFuncType);
    rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    // Of the source function's attributes, only those named in the
    // kRetainedAttributes allowlist are copied; all others are not carried
    // over to the vm.func.
    for (const char *retainedName : kRetainedAttributes) {
      StringRef attrName(retainedName);
      if (Attribute attr = srcOp->getAttr(attrName)) {
        newFuncOp->setAttr(attrName, attr);
      }
    }

    // Tell the rewriter to convert the region signature.
    if (failed(rewriter.convertRegionTypes(
            &newFuncOp.getBody(), *getTypeConverter(), &signatureConversion))) {
      return failure();
    }

    // Also add an export for the "raw" form of this function, which operates
    // on low level VM types and does no verification. A later pass will
    // materialize high level API-friendly wrappers.
    if (srcOp.isPublic()) {
      StringRef exportName = newFuncOp.getName();
      rewriter.create<IREE::VM::ExportOp>(srcOp.getLoc(), newFuncOp,
                                          exportName);
    }

    // VM functions are private by default and exported via the dedicated
    // vm.export ops.
    newFuncOp.setPrivate();
    rewriter.replaceOp(srcOp, llvm::None);
    return success();
  }
};
// Lowers return to vm.return, forwarding the already-converted operands.
class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::ReturnOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    ValueRange returnedValues = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, returnedValues);
    return success();
  }
};
// Lowers scalar arith.constant ops to VM constants. The VM type is chosen by
// the type converter. Zero values use the dedicated vm.const.*.zero ops.
struct ConstantOpConversion : public OpConversionPattern<arith::ConstantOp> {
  TypeConverter &typeConverter;
  ConstantOpConversion(MLIRContext *context, TypeConverter &typeConverter)
      : OpConversionPattern(context), typeConverter(typeConverter) {}

  LogicalResult matchAndRewrite(
      arith::ConstantOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto targetType = typeConverter.convertType(srcOp.getType());
    if (!targetType) {
      return srcOp.emitError() << "could not convert type: " << srcOp.getType()
                               << " (check -iree-vm-target-* options)";
    }
    if (targetType.isa<IntegerType>()) {
      auto integerAttr = srcOp.getValue().dyn_cast<IntegerAttr>();
      if (!integerAttr) {
        return srcOp.emitRemark() << "unsupported const type for dialect";
      }
      switch (targetType.getIntOrFloatBitWidth()) {
        case 1:
        case 32:
          if (integerAttr.getInt()) {
            // IntegerAttr::getInt() sign-extends, so an i1 `true` reads as -1.
            // Materialize it as 1 instead.
            rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>(
                srcOp,
                integerAttr.getType().isInteger(1) ? 1 : integerAttr.getInt());
          } else {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp);
          }
          break;
        case 64:
          if (integerAttr.getInt()) {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstI64Op>(
                srcOp, integerAttr.getInt());
          } else {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstI64ZeroOp>(srcOp);
          }
          break;
        default:
          return srcOp.emitRemark()
                 << "unsupported const integer bit width for dialect";
      }
    } else if (targetType.isa<FloatType>()) {
      auto floatAttr = srcOp.getValue().dyn_cast<FloatAttr>();
      if (!floatAttr) {
        return srcOp.emitRemark() << "unsupported const type for dialect";
      }
      // APFloat::isZero() is true for both +0.0 and -0.0. Restrict the *.zero
      // ops to +0.0 so a -0.0 constant keeps its sign bit (e.g. 1.0 / -0.0 is
      // -inf). This assumes vm.const.f*.zero materializes +0.0.
      const bool isPositiveZero = floatAttr.getValue().isPosZero();
      switch (targetType.getIntOrFloatBitWidth()) {
        case 32:
          if (isPositiveZero) {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstF32ZeroOp>(srcOp);
          } else {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstF32Op>(srcOp, floatAttr);
          }
          break;
        case 64:
          if (isPositiveZero) {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstF64ZeroOp>(srcOp);
          } else {
            rewriter.replaceOpWithNewOp<IREE::VM::ConstF64Op>(srcOp, floatAttr);
          }
          break;
        default:
          return srcOp.emitRemark()
                 << "unsupported const floating-point bit width for dialect";
      }
    } else {
      return rewriter.notifyMatchFailure(srcOp, "unsupported type");
    }
    return success();
  }
};
// Lowers arith.cmpi whose converted operands are i32 to the matching
// vm.cmp.*.i32 op. Every comparison produces an i32 result.
struct CmpI32OpConversion : public OpConversionPattern<arith::CmpIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::CmpIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // i64 comparisons are handled by CmpI64OpConversion.
    if (!adaptor.getLhs().getType().isInteger(32)) return failure();
    switch (srcOp.getPredicate()) {
      case arith::CmpIPredicate::eq:
        return replaceWith<IREE::VM::CmpEQI32Op>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ne:
        return replaceWith<IREE::VM::CmpNEI32Op>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::slt:
        return replaceWith<IREE::VM::CmpLTI32SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::sle:
        return replaceWith<IREE::VM::CmpLTEI32SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::sgt:
        return replaceWith<IREE::VM::CmpGTI32SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::sge:
        return replaceWith<IREE::VM::CmpGTEI32SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ult:
        return replaceWith<IREE::VM::CmpLTI32UOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ule:
        return replaceWith<IREE::VM::CmpLTEI32UOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ugt:
        return replaceWith<IREE::VM::CmpGTI32UOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::uge:
        return replaceWith<IREE::VM::CmpGTEI32UOp>(srcOp, adaptor, rewriter);
      default:
        return failure();
    }
  }

 private:
  // Replaces |srcOp| with a |VMCmpOpTy| over the converted operands, typed i32.
  template <typename VMCmpOpTy>
  static LogicalResult replaceWith(arith::CmpIOp srcOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<VMCmpOpTy>(srcOp, rewriter.getIntegerType(32),
                                           adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};
// Lowers arith.cmpi whose converted operands are i64 to the matching
// vm.cmp.*.i64 op. The comparison result is still produced as an i32.
struct CmpI64OpConversion : public OpConversionPattern<arith::CmpIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::CmpIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // i32 comparisons are handled by CmpI32OpConversion.
    if (!adaptor.getLhs().getType().isInteger(64)) return failure();
    switch (srcOp.getPredicate()) {
      case arith::CmpIPredicate::eq:
        return replaceWith<IREE::VM::CmpEQI64Op>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ne:
        return replaceWith<IREE::VM::CmpNEI64Op>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::slt:
        return replaceWith<IREE::VM::CmpLTI64SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::sle:
        return replaceWith<IREE::VM::CmpLTEI64SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::sgt:
        return replaceWith<IREE::VM::CmpGTI64SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::sge:
        return replaceWith<IREE::VM::CmpGTEI64SOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ult:
        return replaceWith<IREE::VM::CmpLTI64UOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ule:
        return replaceWith<IREE::VM::CmpLTEI64UOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::ugt:
        return replaceWith<IREE::VM::CmpGTI64UOp>(srcOp, adaptor, rewriter);
      case arith::CmpIPredicate::uge:
        return replaceWith<IREE::VM::CmpGTEI64UOp>(srcOp, adaptor, rewriter);
      default:
        return failure();
    }
  }

 private:
  // Replaces |srcOp| with a |VMCmpOpTy| over the converted operands, typed i32.
  template <typename VMCmpOpTy>
  static LogicalResult replaceWith(arith::CmpIOp srcOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<VMCmpOpTy>(srcOp, rewriter.getIntegerType(32),
                                           adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};
// Lowers arith.cmpf on f32 operands to VM ops. The result is an i32 (0 or 1).
// Predicates without a direct VM op (constant true/false, ORD, UNO) are built
// from constants and vm.cmp.nan.f32 checks.
struct CmpF32OpConversion : public OpConversionPattern<arith::CmpFOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::CmpFOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (!adaptor.getLhs().getType().isF32()) return failure();
    Location loc = srcOp.getLoc();
    auto returnType = rewriter.getIntegerType(32);
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    // Produces 1 if |value| is NaN and 0 otherwise.
    auto isNaN = [&](Value value) -> Value {
      return rewriter.createOrFold<IREE::VM::CmpNaNF32Op>(loc, returnType,
                                                          value);
    };
    switch (srcOp.getPredicate()) {
      case arith::CmpFPredicate::AlwaysFalse:  // 0
        rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp);
        break;
      case arith::CmpFPredicate::AlwaysTrue:  // 1
        rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>(srcOp, 1);
        break;
      case arith::CmpFPredicate::UNO: {  // isnan(lhs) || isnan(rhs)
        // Named locals fix the order in which the NaN checks are created.
        Value lhsIsNaN = isNaN(lhs);
        Value rhsIsNaN = isNaN(rhs);
        rewriter.replaceOpWithNewOp<IREE::VM::OrI32Op>(srcOp, returnType,
                                                       lhsIsNaN, rhsIsNaN);
        break;
      }
      case arith::CmpFPredicate::ORD: {  // !(isnan(lhs) || isnan(rhs))
        // Ordered means neither operand is NaN, so the NaN checks must be
        // OR-ed (not AND-ed) before inverting. With AND the result would be
        // true when exactly one operand is NaN.
        Value lhsIsNaN = isNaN(lhs);
        Value rhsIsNaN = isNaN(rhs);
        Value eitherIsNaN = rewriter.createOrFold<IREE::VM::OrI32Op>(
            loc, returnType, lhsIsNaN, rhsIsNaN);
        Value one = rewriter.createOrFold<IREE::VM::ConstI32Op>(loc, 1);
        rewriter.replaceOpWithNewOp<IREE::VM::XorI32Op>(srcOp, returnType, one,
                                                        eitherIsNaN);
        break;
      }
      case arith::CmpFPredicate::OEQ:  // ordered and equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpEQF32OOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::OGT:  // ordered and greater than
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTF32OOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::OGE:  // ordered and greater than or equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEF32OOp>(srcOp, returnType,
                                                            lhs, rhs);
        break;
      case arith::CmpFPredicate::OLT:  // ordered and less than
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTF32OOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::OLE:  // ordered and less than or equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEF32OOp>(srcOp, returnType,
                                                            lhs, rhs);
        break;
      case arith::CmpFPredicate::ONE:  // ordered and not equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpNEF32OOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::UEQ:  // unordered or equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpEQF32UOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::UGT:  // unordered or greater than
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTF32UOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::UGE:  // unordered or greater than or equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEF32UOp>(srcOp, returnType,
                                                            lhs, rhs);
        break;
      case arith::CmpFPredicate::ULT:  // unordered or less than
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTF32UOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      case arith::CmpFPredicate::ULE:  // unordered or less than or equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEF32UOp>(srcOp, returnType,
                                                            lhs, rhs);
        break;
      case arith::CmpFPredicate::UNE:  // unordered or not equal
        rewriter.replaceOpWithNewOp<IREE::VM::CmpNEF32UOp>(srcOp, returnType,
                                                           lhs, rhs);
        break;
      default:
        return rewriter.notifyMatchFailure(srcOp,
                                           "unhandled arith::CmpFPredicate");
    }
    return success();
  }
};
// Lowers a single-operand op to Dst32OpTy or Dst64OpTy, chosen by the bit
// width of the converted operand. The result has the operand's type.
template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy>
class UnaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value operand = adaptor.getOperand();
    Type operandType = operand.getType();
    const unsigned bitWidth = operandType.getIntOrFloatBitWidth();
    if (bitWidth == 32) {
      rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, operandType, operand);
      return success();
    }
    if (bitWidth == 64) {
      rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, operandType, operand);
      return success();
    }
    return rewriter.notifyMatchFailure(srcOp, "unsupported type");
  }
};
// Lowers a two-operand op to Dst32OpTy or Dst64OpTy, chosen by the bit width
// of the converted lhs. The result has the lhs's converted type.
template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy>
class BinaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    Type valueType = lhs.getType();
    const unsigned bitWidth = valueType.getIntOrFloatBitWidth();
    if (bitWidth == 32) {
      rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, valueType, lhs, rhs);
      return success();
    }
    if (bitWidth == 64) {
      rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, valueType, lhs, rhs);
      return success();
    }
    return rewriter.notifyMatchFailure(srcOp, "unsupported type");
  }
};
// Lowers a shift op to Dst32OpTy or Dst64OpTy, chosen by the bit width of the
// converted value being shifted. Shift amounts wider than 32 bits are first
// truncated to i32.
template <typename SrcOpTy, typename Dst32OpTy, typename Dst64OpTy>
class ShiftArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.getLhs();
    // Use the *converted* type for the result, matching
    // BinaryArithmeticOpConversion. srcOp.getType() may be a type the VM does
    // not use directly (e.g. index or i8) and would give an ill-typed VM op.
    Type resultType = lhs.getType();

    Value amount = adaptor.getRhs();
    if (amount.getType().getIntOrFloatBitWidth() > 32) {
      // Shift amounts are always 32-bit in the VM.
      amount = rewriter.createOrFold<arith::TruncIOp>(
          srcOp.getLoc(), rewriter.getI32Type(), amount);
    }

    switch (resultType.getIntOrFloatBitWidth()) {
      case 32:
        rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, resultType, lhs, amount);
        break;
      case 64:
        rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, resultType, lhs, amount);
        break;
      default:
        return rewriter.notifyMatchFailure(srcOp, "unsupported type");
    }
    return success();
  }
};
// Removes a cast by replacing its results with its converted operands, so the
// cast becomes a no-op after type conversion.
template <typename StdOp>
class CastingOpConversion : public OpConversionPattern<StdOp> {
  using OpConversionPattern<StdOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      StdOp srcOp, typename StdOp::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    ValueRange forwardedValues = adaptor.getOperands();
    rewriter.replaceOp(srcOp, forwardedValues);
    return success();
  }
};
// Lowers arith.index_cast once index has been converted to a fixed-width
// integer type. Depending on the converted widths, the cast becomes a no-op,
// a sign extension, or a truncation.
class IndexCastOpConversion : public OpConversionPattern<arith::IndexCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::IndexCastOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    if (!dstType) {
      // Querying the bit width of a null type would crash.
      return rewriter.notifyMatchFailure(srcOp, "unsupported result type");
    }
    if (srcType == dstType) {
      rewriter.replaceOp(srcOp, input);
    } else if (srcType.getIntOrFloatBitWidth() <
               dstType.getIntOrFloatBitWidth()) {
      // arith.index_cast sign-extends when widening. Zero extension would
      // corrupt negative values.
      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(srcOp, dstType, input);
    } else {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(srcOp, dstType, input);
    }
    return success();
  }
};
// Lowers arith.extui to the VM zero-extension ops. Only single-step extensions
// are supported: i1/i8/i16 -> i32 and i32 -> i64 (destination after type
// conversion).
class ZeroExtendIOpConversion : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::ExtUIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto srcType = srcOp.getIn().getType();
    auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    if (!dstType) {
      // The Type predicates below must not be called on a null type.
      return rewriter.notifyMatchFailure(srcOp, "unsupported result type");
    }
    if (srcType.isInteger(1) && dstType.isInteger(32)) {
      // This may not be needed but ensures that the input was treated as a
      // single bit.
      // NOTE: this may not be required - if we know that the i1 is never able
      // to have more than bit 0 manipulated then this is wasted work.
      rewriter.replaceOpWithNewOp<IREE::VM::AndI32Op>(
          srcOp, dstType, adaptor.getIn(),
          rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1));
    } else if (srcType.isInteger(8) && dstType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ExtI8I32UOp>(srcOp, dstType,
                                                         adaptor.getIn());
    } else if (srcType.isInteger(16) && dstType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ExtI16I32UOp>(srcOp, dstType,
                                                          adaptor.getIn());
    } else if (srcType.isInteger(32) && dstType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ExtI32I64UOp>(srcOp, dstType,
                                                          adaptor.getIn());
    } else {
      // TODO(benvanik): we should be building a sequence of extensions for
      // things like i8 -> i64.
      return rewriter.notifyMatchFailure(srcOp, "unsupported zero extension");
    }
    return success();
  }
};
// Lowers arith.extsi to VM ops. Only single-step extensions are supported:
// i1/i8/i16 -> i32 and i32 -> i64 (destination after type conversion).
class SignExtendIOpConversion : public OpConversionPattern<arith::ExtSIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::ExtSIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto srcType = srcOp.getIn().getType();
    auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    if (!dstType) {
      // The Type predicates below must not be called on a null type.
      return rewriter.notifyMatchFailure(srcOp, "unsupported result type");
    }
    if (srcType.isInteger(1) && dstType.isInteger(32)) {
      // An i1 is carried in an i32. Mask to bit 0 (as the extui/trunci
      // lowerings do), then negate: 0 - (in & 1) yields 0 or -1 (all ones),
      // which is the sign-extended value.
      Location loc = srcOp.getLoc();
      Value bit = rewriter.createOrFold<IREE::VM::AndI32Op>(
          loc, dstType, adaptor.getIn(),
          rewriter.createOrFold<IREE::VM::ConstI32Op>(loc, 1));
      Value zero = rewriter.createOrFold<IREE::VM::ConstI32ZeroOp>(loc);
      rewriter.replaceOpWithNewOp<IREE::VM::SubI32Op>(srcOp, dstType, zero,
                                                      bit);
    } else if (srcType.isInteger(8) && dstType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ExtI8I32SOp>(srcOp, dstType,
                                                         adaptor.getIn());
    } else if (srcType.isInteger(16) && dstType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ExtI16I32SOp>(srcOp, dstType,
                                                          adaptor.getIn());
    } else if (srcType.isInteger(32) && dstType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::ExtI32I64SOp>(srcOp, dstType,
                                                          adaptor.getIn());
    } else {
      // TODO(benvanik): we should be building a sequence of extensions for
      // things like i8 -> i64.
      return rewriter.notifyMatchFailure(srcOp, "unsupported sign extension");
    }
    return success();
  }
};
// Lowers arith.trunci to the VM truncation ops. Truncation to i1 is a mask of
// bit 0, because i1 is carried in an i32.
class TruncateIOpConversion : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::TruncIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto srcType = srcOp.getIn().getType();
    auto resultType = srcOp.getResult().getType();
    auto dstType = getTypeConverter()->convertType(resultType);
    if (!dstType) {
      // Do not build VM ops with a null result type.
      return rewriter.notifyMatchFailure(srcOp, "unsupported result type");
    }
    if (resultType.isInteger(1)) {
      // i1 is represented as i32, so just mask off the bit and truncate as
      // normal. Note that if we started as i64 we need to first get that into
      // an i32 that we can work with.
      auto value = adaptor.getIn();
      if (srcType.isInteger(64)) {
        value = rewriter.createOrFold<IREE::VM::TruncI64I32Op>(srcOp.getLoc(),
                                                               dstType, value);
      }
      rewriter.replaceOpWithNewOp<IREE::VM::AndI32Op>(
          srcOp, dstType, value,
          rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1));
    } else if (srcType.isInteger(32) && resultType.isInteger(8)) {
      rewriter.replaceOpWithNewOp<IREE::VM::TruncI32I8Op>(srcOp, dstType,
                                                          adaptor.getIn());
    } else if (srcType.isInteger(32) && resultType.isInteger(16)) {
      rewriter.replaceOpWithNewOp<IREE::VM::TruncI32I16Op>(srcOp, dstType,
                                                           adaptor.getIn());
    } else if (srcType.isInteger(64) && resultType.isInteger(8)) {
      rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I8Op>(srcOp, dstType,
                                                          adaptor.getIn());
    } else if (srcType.isInteger(64) && resultType.isInteger(16)) {
      rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I16Op>(srcOp, dstType,
                                                           adaptor.getIn());
    } else if (srcType.isInteger(64) && resultType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I32Op>(srcOp, dstType,
                                                           adaptor.getIn());
    } else {
      return rewriter.notifyMatchFailure(srcOp, "unsupported truncation");
    }
    return success();
  }
};
// Lowers arith.sitofp from i32 to f32 via vm.cast.si32.f32. Other width
// combinations are not supported.
class SIToFPOpConversion : public OpConversionPattern<arith::SIToFPOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::SIToFPOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    // The result type may be unconvertible (e.g. f32 unsupported by the target).
    // Bail out before calling predicates on a null type.
    if (!dstType) return rewriter.notifyMatchFailure(srcOp, "unsupported type");
    if ((srcType.isSignlessInteger(32) || srcType.isSignedInteger(32)) &&
        dstType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, dstType,
                                                           input);
      return success();
    }
    return rewriter.notifyMatchFailure(srcOp, "unsupported type");
  }
};
// Lowers arith.uitofp from i32 to f32 via vm.cast.ui32.f32. Other width
// combinations are not supported.
class UIToFPOpConversion : public OpConversionPattern<arith::UIToFPOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::UIToFPOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    // Bail out before calling predicates on a null (unconvertible) type.
    if (!dstType) return rewriter.notifyMatchFailure(srcOp, "unsupported type");
    // arith integer operands are signless; the unsigned interpretation comes
    // from the op itself. Checking only isUnsignedInteger(32) meant this
    // pattern never matched, so signless i32 is accepted as well.
    if ((srcType.isSignlessInteger(32) || srcType.isUnsignedInteger(32)) &&
        dstType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, dstType,
                                                           input);
      return success();
    }
    return rewriter.notifyMatchFailure(srcOp, "unsupported type");
  }
};
// Lowers arith.fptosi from f32 to i32 via vm.cast.f32.si32. Other width
// combinations are not supported.
class FPToSIOpConversion : public OpConversionPattern<arith::FPToSIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::FPToSIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    // Bail out before calling predicates on a null (unconvertible) type.
    if (!dstType) return rewriter.notifyMatchFailure(srcOp, "unsupported type");
    if (srcType.isF32() &&
        (dstType.isSignlessInteger(32) || dstType.isSignedInteger(32))) {
      rewriter.replaceOpWithNewOp<IREE::VM::CastF32SI32Op>(srcOp, dstType,
                                                           input);
      return success();
    }
    return rewriter.notifyMatchFailure(srcOp, "unsupported type");
  }
};
// Lowers arith.fptoui from f32 to i32 via vm.cast.f32.ui32. Other width
// combinations are not supported.
class FPToUIOpConversion : public OpConversionPattern<arith::FPToUIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::FPToUIOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    // Bail out before calling predicates on a null (unconvertible) type.
    if (!dstType) return rewriter.notifyMatchFailure(srcOp, "unsupported type");
    // The integer check must be on the *result* type. Previously it tested
    // srcType (already known to be f32), so the pattern could never match.
    // arith results are signless, so signless i32 is accepted.
    if (srcType.isF32() &&
        (dstType.isSignlessInteger(32) || dstType.isUnsignedInteger(32))) {
      rewriter.replaceOpWithNewOp<IREE::VM::CastF32UI32Op>(srcOp, dstType,
                                                           input);
      return success();
    }
    return rewriter.notifyMatchFailure(srcOp, "unsupported type");
  }
};
// Lowers arith.bitcast between same-width int and float types (i32 <-> f32,
// i64 <-> f64) to the VM bitcast ops.
class BitcastOpConversion : public OpConversionPattern<arith::BitcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::BitcastOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getIn();
    Type srcType = input.getType();
    Type dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
    // The result type may be unconvertible (e.g. f64 unsupported by the target).
    // Bail out before calling predicates on a null type.
    if (!dstType) {
      return rewriter.notifyMatchFailure(srcOp, "unsupported bitcast");
    }
    if (srcType.isF32() && dstType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::BitcastF32I32Op>(srcOp, dstType,
                                                             input);
    } else if (srcType.isInteger(32) && dstType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::BitcastI32F32Op>(srcOp, dstType,
                                                             input);
    } else if (srcType.isF64() && dstType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::BitcastF64I64Op>(srcOp, dstType,
                                                             input);
    } else if (srcType.isInteger(64) && dstType.isF64()) {
      rewriter.replaceOpWithNewOp<IREE::VM::BitcastI64F64Op>(srcOp, dstType,
                                                             input);
    } else {
      return rewriter.notifyMatchFailure(srcOp, "unsupported bitcast");
    }
    return success();
  }
};
// Lowers select to the vm.select.* op that matches the converted value type
// (i32, i64, f32, f64, or a VM ref).
class SelectOpConversion : public OpConversionPattern<SelectOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SelectOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value condition = adaptor.getCondition();
    Value trueValue = adaptor.getTrueValue();
    Value falseValue = adaptor.getFalseValue();
    Type valueType = trueValue.getType();

    if (valueType.isInteger(32)) {
      rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>(
          srcOp, valueType, condition, trueValue, falseValue);
    } else if (valueType.isInteger(64)) {
      rewriter.replaceOpWithNewOp<IREE::VM::SelectI64Op>(
          srcOp, valueType, condition, trueValue, falseValue);
    } else if (valueType.isF32()) {
      rewriter.replaceOpWithNewOp<IREE::VM::SelectF32Op>(
          srcOp, valueType, condition, trueValue, falseValue);
    } else if (valueType.isF64()) {
      rewriter.replaceOpWithNewOp<IREE::VM::SelectF64Op>(
          srcOp, valueType, condition, trueValue, falseValue);
    } else if (valueType.isa<IREE::VM::RefType>()) {
      rewriter.replaceOpWithNewOp<IREE::VM::SelectRefOp>(
          srcOp, valueType, condition, trueValue, falseValue);
    } else {
      return rewriter.notifyMatchFailure(srcOp,
                                         "unsupported select element type");
    }
    return success();
  }
};
// Lowers assert to vm.cond_fail. The assert condition is inverted (xor 1)
// before being passed to vm.cond_fail, presumably because that op fails when
// its condition is true. The failure carries a FailedPrecondition status and
// the assert message.
class AssertOpConversion : public OpConversionPattern<AssertOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      AssertOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = srcOp.getLoc();
    const int32_t failedPrecondition =
        static_cast<int32_t>(IREE::Util::StatusCode::FailedPrecondition);
    auto statusOp = rewriter.create<IREE::VM::ConstI32Op>(
        loc, rewriter.getI32IntegerAttr(failedPrecondition));

    // TODO(benvanik): invert cond_fail instead.
    Value condition = adaptor.getArg();
    Value one = rewriter.createOrFold<IREE::VM::ConstI32Op>(loc, 1);
    Value notCondition = rewriter.createOrFold<IREE::VM::XorI32Op>(
        loc, condition.getType(), condition, one);

    rewriter.replaceOpWithNewOp<IREE::VM::CondFailOp>(
        srcOp, notCondition, statusOp, adaptor.getMsg());
    return success();
  }
};
// Lowers an unconditional branch to vm.br, keeping the same successor and
// passing the converted block arguments.
class BranchOpConversion : public OpConversionPattern<BranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      BranchOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Block *successor = srcOp.getDest();
    ValueRange successorArgs = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(srcOp, successor,
                                                    successorArgs);
    return success();
  }
};
// Lowers a conditional branch to vm.cond_br. Both successors are unchanged;
// the condition and the per-successor arguments come from the converted
// operands.
class CondBranchOpConversion : public OpConversionPattern<CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      CondBranchOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Block *onTrue = srcOp.getTrueDest();
    Block *onFalse = srcOp.getFalseDest();
    rewriter.replaceOpWithNewOp<IREE::VM::CondBranchOp>(
        srcOp, adaptor.getCondition(), onTrue, adaptor.getTrueDestOperands(),
        onFalse, adaptor.getFalseDestOperands());
    return success();
  }
};
// Lowers call to vm.call with the same callee. Each result type must convert
// to exactly one VM type; otherwise the pattern fails.
class CallOpConversion : public OpConversionPattern<CallOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      CallOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The conversion framework is relied on to convert the callee the same
    // way, so only the call's own result types are converted here.
    SmallVector<Type, 4> vmResultTypes;
    for (Type srcResultType : srcOp.getResultTypes()) {
      Type vmResultType = getTypeConverter()->convertType(srcResultType);
      if (!vmResultType) return failure();
      vmResultTypes.push_back(vmResultType);
    }
    rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
        srcOp, srcOp.getCallee(), vmResultTypes, adaptor.operands());
    return success();
  }
};
} // namespace
// Adds all StandardToVM conversion patterns to |patterns|. Most patterns take
// |typeConverter| through the ConversionPattern constructor.
// ConstantOpConversion instead stores it itself (see the TODO below). The
// insertion order below matches the historical registration order.
void populateStandardToVMPatterns(MLIRContext *context,
                                  TypeConverter &typeConverter,
                                  OwningRewritePatternList &patterns) {
  // Structural, control-flow, and comparison ops.
  patterns.insert<AssertOpConversion, BranchOpConversion, CallOpConversion>(
      typeConverter, context);
  patterns.insert<CmpI32OpConversion, CmpI64OpConversion, CmpF32OpConversion>(
      typeConverter, context);
  patterns.insert<CondBranchOpConversion, ModuleOpConversion,
                  FuncOpConversion, ReturnOpConversion, SelectOpConversion>(
      typeConverter, context);

  // TODO(#2878): figure out how to pass the type converter in a supported way.
  // Right now if we pass the type converter as the first argument - triggering
  // the ConversionPattern stuff - it'll do weird things.
  patterns.insert<ConstantOpConversion>(context, typeConverter);

  // Casts and integer width changes.
  patterns.insert<CastingOpConversion<UnrealizedConversionCastOp>,
                  IndexCastOpConversion>(typeConverter, context);
  patterns.insert<ZeroExtendIOpConversion, SignExtendIOpConversion,
                  TruncateIOpConversion>(typeConverter, context);

  // Integer arithmetic ops.
  patterns.insert<
      BinaryArithmeticOpConversion<arith::AddIOp, IREE::VM::AddI32Op,
                                   IREE::VM::AddI64Op>,
      BinaryArithmeticOpConversion<arith::DivSIOp, IREE::VM::DivI32SOp,
                                   IREE::VM::DivI64SOp>,
      BinaryArithmeticOpConversion<arith::DivUIOp, IREE::VM::DivI32UOp,
                                   IREE::VM::DivI64UOp>,
      BinaryArithmeticOpConversion<arith::MulIOp, IREE::VM::MulI32Op,
                                   IREE::VM::MulI64Op>,
      BinaryArithmeticOpConversion<arith::RemSIOp, IREE::VM::RemI32SOp,
                                   IREE::VM::RemI64SOp>,
      BinaryArithmeticOpConversion<arith::RemUIOp, IREE::VM::RemI32UOp,
                                   IREE::VM::RemI64UOp>,
      BinaryArithmeticOpConversion<arith::SubIOp, IREE::VM::SubI32Op,
                                   IREE::VM::SubI64Op>>(typeConverter,
                                                        context);
  // Integer bitwise ops.
  patterns.insert<
      BinaryArithmeticOpConversion<arith::AndIOp, IREE::VM::AndI32Op,
                                   IREE::VM::AndI64Op>,
      BinaryArithmeticOpConversion<arith::OrIOp, IREE::VM::OrI32Op,
                                   IREE::VM::OrI64Op>,
      BinaryArithmeticOpConversion<arith::XOrIOp, IREE::VM::XorI32Op,
                                   IREE::VM::XorI64Op>>(typeConverter,
                                                        context);

  // Floating-point arithmetic ops.
  patterns.insert<
      UnaryArithmeticOpConversion<math::AbsOp, IREE::VM::AbsF32Op,
                                  IREE::VM::AbsF64Op>,
      BinaryArithmeticOpConversion<arith::AddFOp, IREE::VM::AddF32Op,
                                   IREE::VM::AddF64Op>,
      UnaryArithmeticOpConversion<math::CeilOp, IREE::VM::CeilF32Op,
                                  IREE::VM::CeilF64Op>,
      UnaryArithmeticOpConversion<math::FloorOp, IREE::VM::FloorF32Op,
                                  IREE::VM::FloorF64Op>,
      BinaryArithmeticOpConversion<arith::DivFOp, IREE::VM::DivF32Op,
                                   IREE::VM::DivF64Op>>(typeConverter,
                                                        context);
  patterns.insert<
      BinaryArithmeticOpConversion<arith::MulFOp, IREE::VM::MulF32Op,
                                   IREE::VM::MulF64Op>,
      UnaryArithmeticOpConversion<arith::NegFOp, IREE::VM::NegF32Op,
                                  IREE::VM::NegF64Op>,
      BinaryArithmeticOpConversion<arith::RemFOp, IREE::VM::RemF32Op,
                                   IREE::VM::RemF64Op>,
      BinaryArithmeticOpConversion<arith::SubFOp, IREE::VM::SubF32Op,
                                   IREE::VM::SubF64Op>>(typeConverter,
                                                        context);

  // Floating-point conversion ops.
  patterns.insert<SIToFPOpConversion, UIToFPOpConversion>(typeConverter,
                                                          context);
  patterns.insert<FPToSIOpConversion, FPToUIOpConversion, BitcastOpConversion>(
      typeConverter, context);

  // Shift ops.
  patterns.insert<
      ShiftArithmeticOpConversion<arith::ShLIOp, IREE::VM::ShlI32Op,
                                  IREE::VM::ShlI64Op>,
      ShiftArithmeticOpConversion<arith::ShRSIOp, IREE::VM::ShrI32SOp,
                                  IREE::VM::ShrI64SOp>,
      ShiftArithmeticOpConversion<arith::ShRUIOp, IREE::VM::ShrI32UOp,
                                  IREE::VM::ShrI64UOp>>(typeConverter,
                                                        context);
}
} // namespace iree_compiler
} // namespace mlir