blob: 72358a09aabda00dc76c2196b06d60da58e25bb5 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Rewrites a nested builtin module into a vm.module, moving the body of the
// source module into the newly created op.
class ModuleOpConversion : public OpConversionPattern<ModuleOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ModuleOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only modules nested directly inside another module are handled; the
    // top-level module (or a module with any other parent) is left alone.
    Operation *parentOp = srcOp.getParentOp();
    if (!parentOp || !isa<ModuleOp>(parentOp)) {
      return failure();
    }

    // Fall back to a fixed name when the source module is anonymous.
    StringRef moduleName = "module";
    if (auto srcName = srcOp.getName()) {
      moduleName = *srcName;
    }

    auto vmModuleOp =
        rewriter.create<IREE::VM::ModuleOp>(srcOp.getLoc(), moduleName);
    Region &vmBody = vmModuleOp.getBodyRegion();
    assert(!vmBody.empty());

    // Splice the source blocks in front of the block(s) the builder created.
    Block *builderBlock = &vmBody.front();
    rewriter.inlineRegionBefore(srcOp.getBodyRegion(), builderBlock);

    // Erase everything from the builder-created block to the end of the
    // region so that only the inlined source blocks remain.
    auto staleBlocks =
        llvm::make_range(Region::iterator(builderBlock), vmBody.end());
    for (Block &stale : llvm::make_early_inc_range(staleBlocks)) {
      rewriter.eraseBlock(&stale);
    }

    rewriter.replaceOp(srcOp, {});
    return success();
  }
};
// Swaps a builtin module terminator for a vm.module_terminator once the
// terminator lives inside a vm.module.
class ModuleTerminatorOpConversion
    : public OpConversionPattern<ModuleTerminatorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ModuleTerminatorOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Terminators whose parent is not a vm.module (such as the terminator of
    // the top-level module) are not rewritten by this pattern.
    const bool insideVmModule = isa<IREE::VM::ModuleOp>(srcOp.getParentOp());
    if (insideVmModule) {
      rewriter.replaceOpWithNewOp<IREE::VM::ModuleTerminatorOp>(srcOp);
      return success();
    }
    return failure();
  }
};
// Names of the source function attributes that are copied onto the generated
// vm.func; any attribute not listed here is not carried over.
constexpr const char *kRetainedAttributes[] = {"iree.reflection",
                                               "sym_visibility"};
// Converts a std func into a vm.func with a VM-typed signature, keeps the
// allowlisted attributes, and emits a vm.export when the source function
// carries the "iree.module.export" attribute.
class FuncOpConversion : public OpConversionPattern<FuncOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      FuncOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    FunctionType oldSignature = srcOp.getType();

    // Map each input type through the VM type converter.
    TypeConverter::SignatureConversion argConversion(srcOp.getNumArguments());
    const unsigned numInputs = oldSignature.getNumInputs();
    for (unsigned argIndex = 0; argIndex != numInputs; ++argIndex) {
      Type inputType = oldSignature.getInput(argIndex);
      if (failed(typeConverter.convertSignatureArg(argIndex, inputType,
                                                   argConversion))) {
        return failure();
      }
    }

    // Map the result types through the same converter.
    SmallVector<Type, 1> newResultTypes;
    if (failed(typeConverter.convertTypes(oldSignature.getResults(),
                                          newResultTypes))) {
      return failure();
    }

    // Build the vm.func with the converted signature and move the source
    // body into it.
    auto newSignature = mlir::FunctionType::get(
        argConversion.getConvertedTypes(), newResultTypes, srcOp.getContext());
    auto newFuncOp = rewriter.create<IREE::VM::FuncOp>(
        srcOp.getLoc(), srcOp.getName(), newSignature);
    rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    // Only attributes named in kRetainedAttributes are copied across; all
    // other source attributes are not transferred.
    for (const char *retainedName : kRetainedAttributes) {
      StringRef attrName(retainedName);
      if (Attribute attr = srcOp.getAttr(attrName)) {
        newFuncOp.setAttr(attrName, attr);
      }
    }

    // Have the rewriter update the region's argument types to match the
    // converted signature.
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                           &argConversion))) {
      return failure();
    }

    // Also add an export for the "raw" form of this function, which operates
    // on low level VM types and does no verification. A later pass will
    // materialize high level API-friendly wrappers.
    Attribute exportAttr = srcOp.getAttr("iree.module.export");
    if (exportAttr) {
      // A string attribute supplies the export name; otherwise the attribute
      // is expected to be a unit attribute and the function name is used.
      StringRef exportName = newFuncOp.getName();
      auto explicitName = exportAttr.dyn_cast<StringAttr>();
      if (explicitName) {
        exportName = explicitName.getValue();
      } else {
        assert(exportAttr.isa<UnitAttr>());
      }
      rewriter.create<IREE::VM::ExportOp>(srcOp.getLoc(), newFuncOp,
                                          exportName);
    }

    rewriter.replaceOp(srcOp, llvm::None);
    return success();
  }

  // Mutable because matchAndRewrite is const but passes the converter to
  // non-const-taking APIs above.
  mutable VMTypeConverter typeConverter;
};
// Replaces a std return with a vm.return that forwards the already-converted
// operands.
class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::ReturnOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    ArrayRef<Value> returnedValues = operands;
    rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, returnedValues);
    return success();
  }
};
// Lowers an integer std constant to vm.const.i32 (or vm.const.i32.zero when
// the value is 0).
//
// Only IntegerAttr values are handled. When the attribute type is an
// int-or-float type its width must be 1 or 32 bits; other attribute types
// (NOTE(review): presumably `index`) skip the width check and are emitted as
// i32 as well.
class ConstantOpConversion : public OpConversionPattern<ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ConstantOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto integerAttr = srcOp.getValue().dyn_cast<IntegerAttr>();
    if (!integerAttr) {
      srcOp.emitRemark() << "unsupported const type for dialect";
      return failure();
    }

    Type attrType = integerAttr.getType();
    if (attrType.isIntOrFloat()) {
      unsigned numBits = attrType.getIntOrFloatBitWidth();
      if (numBits != 1 && numBits != 32) {
        srcOp.emitRemark() << "unsupported bit width for dialect constant";
        return failure();
      }
    }

    // IntegerAttr::getInt() sign-extends, which would turn an i1 `true`
    // (bit pattern 1) into -1. Booleans must materialize as 0/1 in the i32
    // register, so 1-bit values are zero-extended instead.
    int64_t intValue = integerAttr.getInt();
    if (attrType.isInteger(1)) {
      intValue = static_cast<int64_t>(integerAttr.getValue().getZExtValue());
    }

    if (intValue == 0) {
      rewriter.replaceOpWithNewOp<IREE::VM::ConstI32ZeroOp>(srcOp);
    } else {
      rewriter.replaceOpWithNewOp<IREE::VM::ConstI32Op>(srcOp, intValue);
    }
    return success();
  }
};
// Maps std cmpi onto the vm.cmp.*.i32 op matching its predicate. The new op
// always produces an i32 result.
class CmpIOpConversion : public OpConversionPattern<CmpIOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      CmpIOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    CmpIOp::Adaptor adaptor(operands);
    auto resultType = rewriter.getIntegerType(32);
    Value lhs = adaptor.lhs();
    Value rhs = adaptor.rhs();

    switch (srcOp.getPredicate()) {
      // Equality.
      case CmpIPredicate::eq:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpEQI32Op>(srcOp, resultType,
                                                          lhs, rhs);
        break;
      case CmpIPredicate::ne:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpNEI32Op>(srcOp, resultType,
                                                          lhs, rhs);
        break;

      // Signed ordering.
      case CmpIPredicate::slt:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32SOp>(srcOp, resultType,
                                                           lhs, rhs);
        break;
      case CmpIPredicate::sle:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32SOp>(srcOp, resultType,
                                                            lhs, rhs);
        break;
      case CmpIPredicate::sgt:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32SOp>(srcOp, resultType,
                                                           lhs, rhs);
        break;
      case CmpIPredicate::sge:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32SOp>(srcOp, resultType,
                                                            lhs, rhs);
        break;

      // Unsigned ordering.
      case CmpIPredicate::ult:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32UOp>(srcOp, resultType,
                                                           lhs, rhs);
        break;
      case CmpIPredicate::ule:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32UOp>(srcOp, resultType,
                                                            lhs, rhs);
        break;
      case CmpIPredicate::ugt:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32UOp>(srcOp, resultType,
                                                           lhs, rhs);
        break;
      case CmpIPredicate::uge:
        rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32UOp>(srcOp, resultType,
                                                            lhs, rhs);
        break;

      // Any predicate not listed above is left unconverted.
      default:
        return failure();
    }
    return success();
  }
};
// Generic 1:1 lowering of a two-operand arithmetic op: SrcOpTy is replaced by
// DstOpTy built from the converted lhs/rhs, with the result type taken from
// the converted lhs.
template <typename SrcOpTy, typename DstOpTy>
class BinaryArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor adaptor(operands);
    Value lhs = adaptor.lhs();
    Value rhs = adaptor.rhs();
    rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, lhs.getType(), lhs, rhs);
    return success();
  }
};
// Lowers a shift whose amount is a compile-time constant to DstOpTy, passing
// the amount as an 8-bit integer attribute.
//
// The pattern fails to match when the result is not a signless integer of
// exactly kBits bits, when the shift amount is not a constant integer, or
// when the amount is greater than kBits (an amount equal to kBits is
// accepted).
template <typename SrcOpTy, typename DstOpTy, unsigned kBits = 32>
class ShiftArithmeticOpConversion : public OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SrcOpTy srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor adaptor(operands);

    auto resultType = srcOp.getType();
    const bool isSupportedType = resultType.isSignlessInteger() &&
                                 resultType.getIntOrFloatBitWidth() == kBits;
    if (!isSupportedType) {
      return failure();
    }

    // The shift amount must fold to a constant integer.
    APInt shiftAmount;
    if (!matchPattern(adaptor.rhs(), m_ConstantInt(&shiftAmount))) {
      return failure();
    }

    uint64_t shiftBits = shiftAmount.getZExtValue();
    if (shiftBits > kBits) {
      return failure();
    }

    IntegerType amountType = IntegerType::get(8, srcOp.getContext());
    IntegerAttr amountAttr = IntegerAttr::get(amountType, shiftBits);
    rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), adaptor.lhs(),
                                         amountAttr);
    return success();
  }
};
// Removes a cast-like op by replacing its results directly with its
// already-converted operands; no VM op is created.
template <typename StdOp>
class CastingOpConversion : public OpConversionPattern<StdOp> {
  using OpConversionPattern<StdOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      StdOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    ArrayRef<Value> forwardedValues = operands;
    rewriter.replaceOp(srcOp, forwardedValues);
    return success();
  }
};
// Lowers std select to vm.select.i32 when the selected values are i32 (or
// still index-typed, see below). Any other value type is rejected so that it
// is never lowered to an i32 select.
class SelectI32OpConversion : public OpConversionPattern<SelectOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SelectOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    SelectOp::Adaptor srcAdaptor(operands);
    IntegerType requiredType = IntegerType::get(32, srcOp.getContext());
    // Note: This check can correctly just be a verification that
    // actualType == requiredType, but since the VM type conversion also
    // maps IndexType to this type, widening the check to additionally admit
    // IndexType reduces red-herrings when other conversions fail to properly
    // match/rewrite index related ops. (Otherwise, the dialect converter may
    // report the error as a failure to legalize the select op depending on
    // order of resolution).
    //
    // The previous condition was inverted: it rejected IndexType and accepted
    // every other non-i32 type, which contradicts the intent described above.
    auto actualType = srcAdaptor.true_value().getType();
    if (actualType != requiredType && !actualType.isa<IndexType>()) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>(
        srcOp, requiredType, srcAdaptor.condition(), srcAdaptor.true_value(),
        srcAdaptor.false_value());
    return success();
  }
};
// Replaces a std br with a vm.br that targets the same destination block and
// carries the already-converted operands.
class BranchOpConversion : public OpConversionPattern<BranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      BranchOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Block *destination = srcOp.getDest();
    rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(srcOp, destination,
                                                    operands);
    return success();
  }
};
// Replaces a std cond_br with a vm.cond_br. The flat operand list is split as
// [condition, true-branch operands..., false-branch operands...], using the
// true destination's block argument count to find the boundary.
class CondBranchOpConversion : public OpConversionPattern<CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      CondBranchOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Block *trueDest = srcOp.getTrueDest();
    const unsigned numTrueArgs = trueDest->getNumArguments();

    Value condition = operands[0];
    ArrayRef<Value> trueOperands = operands.slice(1, numTrueArgs);
    // Everything after the true-branch operands belongs to the false branch.
    ArrayRef<Value> falseOperands = operands.slice(1 + numTrueArgs);

    rewriter.replaceOpWithNewOp<IREE::VM::CondBranchOp>(
        srcOp, condition, trueDest, trueOperands, srcOp.getFalseDest(),
        falseOperands);
    return success();
  }
};
// Replaces a std call with a vm.call to the same callee, converting each
// result type through a VMTypeConverter. Fails if any result type has no
// conversion.
class CallOpConversion : public OpConversionPattern<CallOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      CallOp srcOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    CallOp::Adaptor adaptor(operands);

    // Only the result types are converted here; the conversion framework is
    // relied upon to have converted the callee's signature equivalently.
    VMTypeConverter typeConverter;
    SmallVector<Type, 4> convertedResultTypes;
    for (Type srcResultType : srcOp.getResultTypes()) {
      Type vmResultType = typeConverter.convertType(srcResultType);
      if (!vmResultType) {
        return failure();
      }
      convertedResultTypes.push_back(vmResultType);
    }

    rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
        srcOp, srcOp.getCallee(), convertedResultTypes, adaptor.operands());
    return success();
  }
};
} // namespace
// Appends every standard-to-VM conversion pattern defined in this file to
// `patterns`, all constructed with `context`.
void populateStandardToVMPatterns(MLIRContext *context,
                                  OwningRewritePatternList &patterns) {
  // Control flow, calls, comparisons and constants.
  patterns.insert<BranchOpConversion, CallOpConversion, CmpIOpConversion,
                  CondBranchOpConversion, ConstantOpConversion>(context);

  // Structural ops: modules, functions and returns.
  patterns.insert<ModuleOpConversion, ModuleTerminatorOpConversion,
                  FuncOpConversion, ReturnOpConversion>(context);

  // Casts that are dropped in favor of their operands, and select.
  patterns.insert<CastingOpConversion<IndexCastOp>,
                  CastingOpConversion<TruncateIOp>, SelectI32OpConversion>(
      context);

  // Binary arithmetic ops
  patterns.insert<
      BinaryArithmeticOpConversion<AddIOp, IREE::VM::AddI32Op>,
      BinaryArithmeticOpConversion<SignedDivIOp, IREE::VM::DivI32SOp>,
      BinaryArithmeticOpConversion<UnsignedDivIOp, IREE::VM::DivI32UOp>,
      BinaryArithmeticOpConversion<MulIOp, IREE::VM::MulI32Op>,
      BinaryArithmeticOpConversion<SignedRemIOp, IREE::VM::RemI32SOp>,
      BinaryArithmeticOpConversion<UnsignedRemIOp, IREE::VM::RemI32UOp>,
      BinaryArithmeticOpConversion<SubIOp, IREE::VM::SubI32Op>,
      BinaryArithmeticOpConversion<AndOp, IREE::VM::AndI32Op>,
      BinaryArithmeticOpConversion<OrOp, IREE::VM::OrI32Op>,
      BinaryArithmeticOpConversion<XOrOp, IREE::VM::XorI32Op>>(context);

  // Shift ops
  // TODO(laurenzo): The standard dialect is missing shr ops. Add once in
  // place.
  using ShlI32Conversion =
      ShiftArithmeticOpConversion<ShiftLeftOp, IREE::VM::ShlI32Op>;
  patterns.insert<ShlI32Conversion>(context);
}
} // namespace iree_compiler
} // namespace mlir