Fix vm to emitc conversion for ref types (#7565)
This was broken by upstream changes to the conversion framework and hence some vm tests were disabled in #7488 for the C backend.
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
index 32dde54..056b90d 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
@@ -14,6 +14,7 @@
"ConvertVMToEmitC.h"
"DropExcludedExports.h"
"EmitCTypeConverter.h"
+ "VMAnalysis.h"
SRCS
"ConvertVMToEmitC.cpp"
"DropExcludedExports.cpp"
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index cd3bb2b..50403b0 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VM/Utils/CallingConvention.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1337,8 +1338,10 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
public:
- GenericOpConversion(MLIRContext *context, StringRef funcName)
- : OpConversionPattern<SrcOpTy>(context), funcName(funcName) {}
+ GenericOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ StringRef funcName)
+ : OpConversionPattern<SrcOpTy>(typeConverter, context),
+ funcName(funcName) {}
private:
LogicalResult matchAndRewrite(
@@ -1580,6 +1583,12 @@
bool move = ptr->second.isLastValueUse(operand, op.getOperation());
+ Optional<Value> operandRef = findRef(funcOp, vmAnalysisCache, operand);
+
+ if (!operandRef.hasValue()) {
+ return op.emitError() << "local ref not found";
+ }
+
rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
@@ -1589,7 +1598,8 @@
ctx, {rewriter.getBoolAttr(move), rewriter.getIndexAttr(0),
rewriter.getIndexAttr(1)}),
/*templateArgs=*/ArrayAttr{},
- /*operands=*/ArrayRef<Value>{operand, refPtrOp.getResult()});
+ /*operands=*/
+ ArrayRef<Value>{operandRef.getValue(), refPtrOp.getResult()});
updatedOperands.push_back(refPtrOp.getResult());
} else {
@@ -1682,9 +1692,9 @@
public:
using OpConversionPattern<CmpOpTy>::OpConversionPattern;
- CompareRefOpConversion(MLIRContext *context, StringRef funcName,
- VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<CmpOpTy>(context),
+ CompareRefOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ StringRef funcName, VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<CmpOpTy>(typeConverter, context),
funcName(funcName),
vmAnalysisCache(vmAnalysisCache) {}
@@ -1762,9 +1772,10 @@
using OpConversionPattern<IREE::VM::CmpNZRefOp>::OpConversionPattern;
public:
- CompareRefNotZeroOpConversion(MLIRContext *context,
+ CompareRefNotZeroOpConversion(TypeConverter &typeConverter,
+ MLIRContext *context,
VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::CmpNZRefOp>(context),
+ : OpConversionPattern<IREE::VM::CmpNZRefOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -1815,12 +1826,14 @@
};
template <typename ConstOpTy>
-class ConstOpConversion : public OpRewritePattern<ConstOpTy> {
+class ConstOpConversion : public OpConversionPattern<ConstOpTy> {
public:
- using OpRewritePattern<ConstOpTy>::OpRewritePattern;
+ using Adaptor = typename ConstOpTy::Adaptor;
+ using OpConversionPattern<ConstOpTy>::OpConversionPattern;
- LogicalResult matchAndRewrite(ConstOpTy constOp,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(
+ ConstOpTy constOp, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(constOp, constOp.getType(),
constOp.value());
return success();
@@ -1828,12 +1841,14 @@
};
template <typename ConstZeroOpTy>
-class ConstZeroOpConversion : public OpRewritePattern<ConstZeroOpTy> {
+class ConstZeroOpConversion : public OpConversionPattern<ConstZeroOpTy> {
public:
- using OpRewritePattern<ConstZeroOpTy>::OpRewritePattern;
+ using Adaptor = typename ConstZeroOpTy::Adaptor;
+ using OpConversionPattern<ConstZeroOpTy>::OpConversionPattern;
- LogicalResult matchAndRewrite(ConstZeroOpTy constZeroOp,
- PatternRewriter &rewriter) const final {
+ LogicalResult matchAndRewrite(
+ ConstZeroOpTy constZeroOp, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
auto type = constZeroOp.getType();
Attribute value;
@@ -1855,9 +1870,9 @@
public:
using OpConversionPattern<IREE::VM::ConstRefZeroOp>::OpConversionPattern;
- ConstRefZeroOpConversion(MLIRContext *context,
+ ConstRefZeroOpConversion(TypeConverter &typeConverter, MLIRContext *context,
VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::ConstRefZeroOp>(context),
+ : OpConversionPattern<IREE::VM::ConstRefZeroOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
LogicalResult matchAndRewrite(
@@ -1897,9 +1912,9 @@
public:
using OpConversionPattern<IREE::VM::ConstRefRodataOp>::OpConversionPattern;
- ConstRefRodataOpConversion(MLIRContext *context,
+ ConstRefRodataOpConversion(TypeConverter &typeConverter, MLIRContext *context,
VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::ConstRefRodataOp>(context),
+ : OpConversionPattern<IREE::VM::ConstRefRodataOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
LogicalResult matchAndRewrite(
@@ -1978,8 +1993,9 @@
using OpConversionPattern<IREE::VM::BranchOp>::OpConversionPattern;
public:
- BranchOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::BranchOp>(context),
+ BranchOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::BranchOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -2058,8 +2074,9 @@
using OpConversionPattern<IREE::VM::CondBranchOp>::OpConversionPattern;
public:
- CondBranchOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::CondBranchOp>(context),
+ CondBranchOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::CondBranchOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -2120,16 +2137,26 @@
OpBuilder::InsertionGuard guard(rewriter);
trueDestDispatch = rewriter.createBlock(trueDest);
- for (Value operand : op.getTrueOperands()) {
+ for (auto pair :
+ llvm::zip(op.getTrueOperands(), trueDest->getArguments())) {
+ Value operand = std::get<0>(pair);
+ BlockArgument blockArg = std::get<1>(pair);
+
if (isNotRefOperand(operand)) {
continue;
}
assert(operand.getType().isa<IREE::VM::RefType>());
+ assert(blockArg.getType().isa<IREE::VM::RefType>());
- Optional<Value> destRef = findRef(funcOp, vmAnalysisCache, operand);
+ Optional<Value> operandRef = findRef(funcOp, vmAnalysisCache, operand);
+ Optional<Value> blockArgRef =
+ findRef(funcOp, vmAnalysisCache, blockArg);
- if (!destRef.hasValue()) {
+ if (!operandRef.hasValue()) {
+ return op.emitError() << "local ref not found";
+ }
+ if (!blockArgRef.hasValue()) {
return op.emitError() << "local ref not found";
}
@@ -2140,7 +2167,8 @@
StringAttr::get(ctx, "iree_vm_ref_retain"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
- /*operands=*/ArrayRef<Value>{operand, destRef.getValue()});
+ /*operands=*/
+ ArrayRef<Value>{operandRef.getValue(), blockArgRef.getValue()});
}
rewriter.create<mlir::BranchOp>(loc, op.trueDest(), op.getTrueOperands());
}
@@ -2150,16 +2178,26 @@
OpBuilder::InsertionGuard guard(rewriter);
falseDestDispatch = rewriter.createBlock(falseDest);
- for (Value operand : op.getFalseOperands()) {
+ for (auto pair :
+ llvm::zip(op.getFalseOperands(), falseDest->getArguments())) {
+ Value operand = std::get<0>(pair);
+ BlockArgument blockArg = std::get<1>(pair);
+
if (isNotRefOperand(operand)) {
continue;
}
assert(operand.getType().isa<IREE::VM::RefType>());
+ assert(blockArg.getType().isa<IREE::VM::RefType>());
- Optional<Value> destRef = findRef(funcOp, vmAnalysisCache, operand);
+ Optional<Value> operandRef = findRef(funcOp, vmAnalysisCache, operand);
+ Optional<Value> blockArgRef =
+ findRef(funcOp, vmAnalysisCache, blockArg);
- if (!destRef.hasValue()) {
+ if (!operandRef.hasValue()) {
+ return op.emitError() << "local ref not found";
+ }
+ if (!blockArgRef.hasValue()) {
return op.emitError() << "local ref not found";
}
@@ -2170,7 +2208,8 @@
StringAttr::get(ctx, "iree_vm_ref_retain"),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
- /*operands=*/ArrayRef<Value>{operand, destRef.getValue()});
+ /*operands=*/
+ ArrayRef<Value>{operandRef.getValue(), blockArgRef.getValue()});
}
rewriter.create<mlir::BranchOp>(loc, op.falseDest(),
op.getFalseOperands());
@@ -2189,8 +2228,9 @@
using OpConversionPattern<IREE::VM::ReturnOp>::OpConversionPattern;
public:
- ReturnOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::ReturnOp>(context),
+ ReturnOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ReturnOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -2252,8 +2292,9 @@
using OpConversionPattern<IREE::VM::FailOp>::OpConversionPattern;
public:
- FailOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::FailOp>(context),
+ FailOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::FailOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -2372,8 +2413,10 @@
using OpConversionPattern<LoadOpTy>::OpConversionPattern;
public:
- GlobalLoadOpConversion(MLIRContext *context, StringRef funcName)
- : OpConversionPattern<LoadOpTy>(context), funcName(funcName) {}
+ GlobalLoadOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ StringRef funcName)
+ : OpConversionPattern<LoadOpTy>(typeConverter, context),
+ funcName(funcName) {}
private:
LogicalResult matchAndRewrite(
@@ -2427,9 +2470,10 @@
public:
using OpConversionPattern<LoadStoreOpTy>::OpConversionPattern;
- GlobalLoadStoreRefOpConversion(MLIRContext *context,
+ GlobalLoadStoreRefOpConversion(TypeConverter &typeConverter,
+ MLIRContext *context,
VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<LoadStoreOpTy>(context),
+ : OpConversionPattern<LoadStoreOpTy>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -2552,8 +2596,10 @@
using OpConversionPattern<StoreOpTy>::OpConversionPattern;
public:
- GlobalStoreOpConversion(MLIRContext *context, StringRef funcName)
- : OpConversionPattern<StoreOpTy>(context), funcName(funcName) {}
+ GlobalStoreOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ StringRef funcName)
+ : OpConversionPattern<StoreOpTy>(typeConverter, context),
+ funcName(funcName) {}
private:
LogicalResult matchAndRewrite(
@@ -2610,10 +2656,10 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
public:
- ListOpConversion(MLIRContext *context, StringRef funcName,
- size_t listArgumentIndex, bool failable,
+ ListOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ StringRef funcName, size_t listArgumentIndex, bool failable,
VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<SrcOpTy>(context),
+ : OpConversionPattern<SrcOpTy>(typeConverter, context),
funcName(funcName),
listArgumentIndex(listArgumentIndex),
failable(failable),
@@ -2805,8 +2851,9 @@
using OpConversionPattern<GetOpTy>::OpConversionPattern;
public:
- ListGetOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<GetOpTy>(context),
+ ListGetOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<GetOpTy>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -2898,8 +2945,9 @@
public:
using OpConversionPattern<IREE::VM::ListGetRefOp>::OpConversionPattern;
- ListGetRefOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::ListGetRefOp>(context),
+ ListGetRefOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ListGetRefOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -3078,8 +3126,9 @@
using OpConversionPattern<SetOpTy>::OpConversionPattern;
public:
- ListSetOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<SetOpTy>(context),
+ ListSetOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<SetOpTy>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -3155,8 +3204,9 @@
using OpConversionPattern<IREE::VM::ListSetRefOp>::OpConversionPattern;
public:
- ListSetRefOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache)
- : OpConversionPattern<IREE::VM::ListSetRefOp>(context),
+ ListSetRefOpConversion(TypeConverter &typeConverter, MLIRContext *context,
+ VMAnalysisCache &vmAnalysisCache)
+ : OpConversionPattern<IREE::VM::ListSetRefOp>(typeConverter, context),
vmAnalysisCache(vmAnalysisCache) {}
private:
@@ -3223,292 +3273,306 @@
patterns);
// CFG
- patterns.insert<BranchOpConversion>(context, vmAnalysisCache);
+ patterns.insert<BranchOpConversion>(typeConverter, context, vmAnalysisCache);
patterns.insert<CallOpConversion>(typeConverter, context, vmAnalysisCache);
- patterns.insert<CondBranchOpConversion>(context, vmAnalysisCache);
- patterns.insert<FailOpConversion>(context, vmAnalysisCache);
+ patterns.insert<CondBranchOpConversion>(typeConverter, context,
+ vmAnalysisCache);
+ patterns.insert<FailOpConversion>(typeConverter, context, vmAnalysisCache);
patterns.insert<FuncOpConversion>(typeConverter, context, vmAnalysisCache);
- patterns.insert<ReturnOpConversion>(context, vmAnalysisCache);
+ patterns.insert<ReturnOpConversion>(typeConverter, context, vmAnalysisCache);
// Globals
patterns.insert<
GlobalLoadOpConversion<IREE::VM::GlobalLoadI32Op, IREE::VM::GlobalI32Op>>(
- context, "vm_global_load_i32");
+ typeConverter, context, "vm_global_load_i32");
patterns.insert<GlobalStoreOpConversion<IREE::VM::GlobalStoreI32Op,
IREE::VM::GlobalI32Op>>(
- context, "vm_global_store_i32");
+ typeConverter, context, "vm_global_store_i32");
patterns.insert<GlobalLoadStoreRefOpConversion<IREE::VM::GlobalLoadRefOp>>(
- context, vmAnalysisCache);
+ typeConverter, context, vmAnalysisCache);
patterns.insert<GlobalLoadStoreRefOpConversion<IREE::VM::GlobalStoreRefOp>>(
- context, vmAnalysisCache);
+ typeConverter, context, vmAnalysisCache);
// Constants
- patterns.insert<ConstOpConversion<IREE::VM::ConstI32Op>>(context);
- patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(context);
- patterns.insert<ConstRefZeroOpConversion>(context, vmAnalysisCache);
- patterns.insert<ConstRefRodataOpConversion>(context, vmAnalysisCache);
+ patterns.insert<ConstOpConversion<IREE::VM::ConstI32Op>>(typeConverter,
+ context);
+ patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(
+ typeConverter, context);
+ patterns.insert<ConstRefZeroOpConversion>(typeConverter, context,
+ vmAnalysisCache);
+ patterns.insert<ConstRefRodataOpConversion>(typeConverter, context,
+ vmAnalysisCache);
// List ops
patterns.insert<ListAllocOpConversion>(typeConverter, context,
vmAnalysisCache);
patterns.insert<ListOpConversion<IREE::VM::ListReserveOp>>(
- context, "iree_vm_list_reserve", 0, true, vmAnalysisCache);
+ typeConverter, context, "iree_vm_list_reserve", 0, true, vmAnalysisCache);
patterns.insert<ListOpConversion<IREE::VM::ListResizeOp>>(
- context, "iree_vm_list_resize", 0, true, vmAnalysisCache);
+ typeConverter, context, "iree_vm_list_resize", 0, true, vmAnalysisCache);
patterns.insert<ListOpConversion<IREE::VM::ListSizeOp>>(
- context, "iree_vm_list_size", 0, false, vmAnalysisCache);
- patterns.insert<ListGetOpConversion<IREE::VM::ListGetI32Op>>(context,
- vmAnalysisCache);
- patterns.insert<ListGetRefOpConversion>(context, vmAnalysisCache);
- patterns.insert<ListSetOpConversion<IREE::VM::ListSetI32Op>>(context,
- vmAnalysisCache);
- patterns.insert<ListSetRefOpConversion>(context, vmAnalysisCache);
+ typeConverter, context, "iree_vm_list_size", 0, false, vmAnalysisCache);
+ patterns.insert<ListGetOpConversion<IREE::VM::ListGetI32Op>>(
+ typeConverter, context, vmAnalysisCache);
+ patterns.insert<ListGetRefOpConversion>(typeConverter, context,
+ vmAnalysisCache);
+ patterns.insert<ListSetOpConversion<IREE::VM::ListSetI32Op>>(
+ typeConverter, context, vmAnalysisCache);
+ patterns.insert<ListSetRefOpConversion>(typeConverter, context,
+ vmAnalysisCache);
// Conditional assignment ops
- patterns.insert<GenericOpConversion<IREE::VM::SelectI32Op>>(context,
- "vm_select_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::SelectI32Op>>(
+ typeConverter, context, "vm_select_i32");
// Native integer arithmetic ops
- patterns.insert<GenericOpConversion<IREE::VM::AddI32Op>>(context,
- "vm_add_i32");
- patterns.insert<GenericOpConversion<IREE::VM::SubI32Op>>(context,
- "vm_sub_i32");
- patterns.insert<GenericOpConversion<IREE::VM::MulI32Op>>(context,
- "vm_mul_i32");
- patterns.insert<GenericOpConversion<IREE::VM::DivI32SOp>>(context,
- "vm_div_i32s");
- patterns.insert<GenericOpConversion<IREE::VM::DivI32UOp>>(context,
- "vm_div_i32u");
- patterns.insert<GenericOpConversion<IREE::VM::RemI32SOp>>(context,
- "vm_rem_i32s");
- patterns.insert<GenericOpConversion<IREE::VM::RemI32UOp>>(context,
- "vm_rem_i32u");
- patterns.insert<GenericOpConversion<IREE::VM::FMAI32Op>>(context,
- "vm_fma_i32");
- patterns.insert<GenericOpConversion<IREE::VM::NotI32Op>>(context,
- "vm_not_i32");
- patterns.insert<GenericOpConversion<IREE::VM::AndI32Op>>(context,
- "vm_and_i32");
- patterns.insert<GenericOpConversion<IREE::VM::OrI32Op>>(context, "vm_or_i32");
- patterns.insert<GenericOpConversion<IREE::VM::XorI32Op>>(context,
- "vm_xor_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::AddI32Op>>(
+ typeConverter, context, "vm_add_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::SubI32Op>>(
+ typeConverter, context, "vm_sub_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::MulI32Op>>(
+ typeConverter, context, "vm_mul_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::DivI32SOp>>(
+ typeConverter, context, "vm_div_i32s");
+ patterns.insert<GenericOpConversion<IREE::VM::DivI32UOp>>(
+ typeConverter, context, "vm_div_i32u");
+ patterns.insert<GenericOpConversion<IREE::VM::RemI32SOp>>(
+ typeConverter, context, "vm_rem_i32s");
+ patterns.insert<GenericOpConversion<IREE::VM::RemI32UOp>>(
+ typeConverter, context, "vm_rem_i32u");
+ patterns.insert<GenericOpConversion<IREE::VM::FMAI32Op>>(
+ typeConverter, context, "vm_fma_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::NotI32Op>>(
+ typeConverter, context, "vm_not_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::AndI32Op>>(
+ typeConverter, context, "vm_and_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::OrI32Op>>(typeConverter,
+ context, "vm_or_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::XorI32Op>>(
+ typeConverter, context, "vm_xor_i32");
// Casting and type conversion/emulation ops
patterns.insert<GenericOpConversion<IREE::VM::TruncI32I8Op>>(
- context, "vm_trunc_i32i8");
+ typeConverter, context, "vm_trunc_i32i8");
patterns.insert<GenericOpConversion<IREE::VM::TruncI32I16Op>>(
- context, "vm_trunc_i32i16");
- patterns.insert<GenericOpConversion<IREE::VM::ExtI8I32SOp>>(context,
- "vm_ext_i8i32s");
- patterns.insert<GenericOpConversion<IREE::VM::ExtI8I32UOp>>(context,
- "vm_ext_i8i32u");
+ typeConverter, context, "vm_trunc_i32i16");
+ patterns.insert<GenericOpConversion<IREE::VM::ExtI8I32SOp>>(
+ typeConverter, context, "vm_ext_i8i32s");
+ patterns.insert<GenericOpConversion<IREE::VM::ExtI8I32UOp>>(
+ typeConverter, context, "vm_ext_i8i32u");
patterns.insert<GenericOpConversion<IREE::VM::ExtI16I32SOp>>(
- context, "vm_ext_i16i32s");
+ typeConverter, context, "vm_ext_i16i32s");
patterns.insert<GenericOpConversion<IREE::VM::ExtI16I32UOp>>(
- context, "vm_ext_i16i32u");
+ typeConverter, context, "vm_ext_i16i32u");
// Native bitwise shift and rotate ops
- patterns.insert<GenericOpConversion<IREE::VM::ShlI32Op>>(context,
- "vm_shl_i32");
- patterns.insert<GenericOpConversion<IREE::VM::ShrI32SOp>>(context,
- "vm_shr_i32s");
- patterns.insert<GenericOpConversion<IREE::VM::ShrI32UOp>>(context,
- "vm_shr_i32u");
+ patterns.insert<GenericOpConversion<IREE::VM::ShlI32Op>>(
+ typeConverter, context, "vm_shl_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::ShrI32SOp>>(
+ typeConverter, context, "vm_shr_i32s");
+ patterns.insert<GenericOpConversion<IREE::VM::ShrI32UOp>>(
+ typeConverter, context, "vm_shr_i32u");
// Comparison ops
- patterns.insert<GenericOpConversion<IREE::VM::CmpEQI32Op>>(context,
- "vm_cmp_eq_i32");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNEI32Op>>(context,
- "vm_cmp_ne_i32");
- patterns.insert<GenericOpConversion<IREE::VM::CmpLTI32SOp>>(context,
- "vm_cmp_lt_i32s");
- patterns.insert<GenericOpConversion<IREE::VM::CmpLTI32UOp>>(context,
- "vm_cmp_lt_i32u");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNZI32Op>>(context,
- "vm_cmp_nz_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpEQI32Op>>(
+ typeConverter, context, "vm_cmp_eq_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNEI32Op>>(
+ typeConverter, context, "vm_cmp_ne_i32");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpLTI32SOp>>(
+ typeConverter, context, "vm_cmp_lt_i32s");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpLTI32UOp>>(
+ typeConverter, context, "vm_cmp_lt_i32u");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNZI32Op>>(
+ typeConverter, context, "vm_cmp_nz_i32");
patterns.insert<CompareRefOpConversion<IREE::VM::CmpEQRefOp>>(
- context, "vm_cmp_eq_ref", vmAnalysisCache);
+ typeConverter, context, "vm_cmp_eq_ref", vmAnalysisCache);
patterns.insert<CompareRefOpConversion<IREE::VM::CmpNERefOp>>(
- context, "vm_cmp_ne_ref", vmAnalysisCache);
- patterns.insert<CompareRefNotZeroOpConversion>(context, vmAnalysisCache);
+ typeConverter, context, "vm_cmp_ne_ref", vmAnalysisCache);
+ patterns.insert<CompareRefNotZeroOpConversion>(typeConverter, context,
+ vmAnalysisCache);
// ExtF32: Globals
patterns.insert<
GlobalLoadOpConversion<IREE::VM::GlobalLoadF32Op, IREE::VM::GlobalF32Op>>(
- context, "vm_global_load_f32");
+ typeConverter, context, "vm_global_load_f32");
patterns.insert<GlobalStoreOpConversion<IREE::VM::GlobalStoreF32Op,
IREE::VM::GlobalF32Op>>(
- context, "vm_global_store_f32");
+ typeConverter, context, "vm_global_store_f32");
// ExtF32: Native floating-point constants
- patterns.insert<ConstOpConversion<IREE::VM::ConstF32Op>>(context);
- patterns.insert<ConstZeroOpConversion<IREE::VM::ConstF32ZeroOp>>(context);
+ patterns.insert<ConstOpConversion<IREE::VM::ConstF32Op>>(typeConverter,
+ context);
+ patterns.insert<ConstZeroOpConversion<IREE::VM::ConstF32ZeroOp>>(
+ typeConverter, context);
// ExtF32: Conditional assignment
- patterns.insert<GenericOpConversion<IREE::VM::SelectF32Op>>(context,
- "vm_select_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::SelectF32Op>>(
+ typeConverter, context, "vm_select_f32");
// ExtF32: Native floating-point arithmetic
- patterns.insert<GenericOpConversion<IREE::VM::AddF32Op>>(context,
- "vm_add_f32");
- patterns.insert<GenericOpConversion<IREE::VM::SubF32Op>>(context,
- "vm_sub_f32");
- patterns.insert<GenericOpConversion<IREE::VM::MulF32Op>>(context,
- "vm_mul_f32");
- patterns.insert<GenericOpConversion<IREE::VM::DivF32Op>>(context,
- "vm_div_f32");
- patterns.insert<GenericOpConversion<IREE::VM::RemF32Op>>(context,
- "vm_rem_f32");
- patterns.insert<GenericOpConversion<IREE::VM::FMAF32Op>>(context,
- "vm_fma_f32");
- patterns.insert<GenericOpConversion<IREE::VM::AbsF32Op>>(context,
- "vm_abs_f32");
- patterns.insert<GenericOpConversion<IREE::VM::NegF32Op>>(context,
- "vm_neg_f32");
- patterns.insert<GenericOpConversion<IREE::VM::CeilF32Op>>(context,
- "vm_ceil_f32");
- patterns.insert<GenericOpConversion<IREE::VM::FloorF32Op>>(context,
- "vm_floor_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::AddF32Op>>(
+ typeConverter, context, "vm_add_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::SubF32Op>>(
+ typeConverter, context, "vm_sub_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::MulF32Op>>(
+ typeConverter, context, "vm_mul_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::DivF32Op>>(
+ typeConverter, context, "vm_div_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::RemF32Op>>(
+ typeConverter, context, "vm_rem_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::FMAF32Op>>(
+ typeConverter, context, "vm_fma_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::AbsF32Op>>(
+ typeConverter, context, "vm_abs_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::NegF32Op>>(
+ typeConverter, context, "vm_neg_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::CeilF32Op>>(
+ typeConverter, context, "vm_ceil_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::FloorF32Op>>(
+ typeConverter, context, "vm_floor_f32");
- patterns.insert<GenericOpConversion<IREE::VM::AtanF32Op>>(context,
- "vm_atan_f32");
- patterns.insert<GenericOpConversion<IREE::VM::Atan2F32Op>>(context,
- "vm_atan2_f32");
- patterns.insert<GenericOpConversion<IREE::VM::CosF32Op>>(context,
- "vm_cos_f32");
- patterns.insert<GenericOpConversion<IREE::VM::SinF32Op>>(context,
- "vm_sin_f32");
- patterns.insert<GenericOpConversion<IREE::VM::ExpF32Op>>(context,
- "vm_exp_f32");
- patterns.insert<GenericOpConversion<IREE::VM::Exp2F32Op>>(context,
- "vm_exp2_f32");
- patterns.insert<GenericOpConversion<IREE::VM::ExpM1F32Op>>(context,
- "vm_expm1_f32");
- patterns.insert<GenericOpConversion<IREE::VM::LogF32Op>>(context,
- "vm_log_f32");
- patterns.insert<GenericOpConversion<IREE::VM::Log10F32Op>>(context,
- "vm_log10_f32");
- patterns.insert<GenericOpConversion<IREE::VM::Log1pF32Op>>(context,
- "vm_log1p_f32");
- patterns.insert<GenericOpConversion<IREE::VM::Log2F32Op>>(context,
- "vm_log2_f32");
- patterns.insert<GenericOpConversion<IREE::VM::PowF32Op>>(context,
- "vm_pow_f32");
- patterns.insert<GenericOpConversion<IREE::VM::RsqrtF32Op>>(context,
- "vm_rsqrt_f32");
- patterns.insert<GenericOpConversion<IREE::VM::SqrtF32Op>>(context,
- "vm_sqrt_f32");
- patterns.insert<GenericOpConversion<IREE::VM::TanhF32Op>>(context,
- "vm_tanh_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::AtanF32Op>>(
+ typeConverter, context, "vm_atan_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::Atan2F32Op>>(
+ typeConverter, context, "vm_atan2_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::CosF32Op>>(
+ typeConverter, context, "vm_cos_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::SinF32Op>>(
+ typeConverter, context, "vm_sin_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::ExpF32Op>>(
+ typeConverter, context, "vm_exp_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::Exp2F32Op>>(
+ typeConverter, context, "vm_exp2_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::ExpM1F32Op>>(
+ typeConverter, context, "vm_expm1_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::LogF32Op>>(
+ typeConverter, context, "vm_log_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::Log10F32Op>>(
+ typeConverter, context, "vm_log10_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::Log1pF32Op>>(
+ typeConverter, context, "vm_log1p_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::Log2F32Op>>(
+ typeConverter, context, "vm_log2_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::PowF32Op>>(
+ typeConverter, context, "vm_pow_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::RsqrtF32Op>>(
+ typeConverter, context, "vm_rsqrt_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::SqrtF32Op>>(
+ typeConverter, context, "vm_sqrt_f32");
+ patterns.insert<GenericOpConversion<IREE::VM::TanhF32Op>>(
+ typeConverter, context, "vm_tanh_f32");
// ExtF32: Casting and type conversion/emulation
patterns.insert<GenericOpConversion<IREE::VM::CastSI32F32Op>>(
- context, "vm_cast_si32f32");
+ typeConverter, context, "vm_cast_si32f32");
patterns.insert<GenericOpConversion<IREE::VM::CastUI32F32Op>>(
- context, "vm_cast_ui32f32");
+ typeConverter, context, "vm_cast_ui32f32");
patterns.insert<GenericOpConversion<IREE::VM::CastF32SI32Op>>(
- context, "vm_cast_f32si32");
+ typeConverter, context, "vm_cast_f32si32");
patterns.insert<GenericOpConversion<IREE::VM::CastF32UI32Op>>(
- context, "vm_cast_f32ui32");
+ typeConverter, context, "vm_cast_f32ui32");
patterns.insert<GenericOpConversion<IREE::VM::BitcastI32F32Op>>(
- context, "vm_bitcast_i32f32");
+ typeConverter, context, "vm_bitcast_i32f32");
patterns.insert<GenericOpConversion<IREE::VM::BitcastF32I32Op>>(
- context, "vm_bitcast_f32i32");
+ typeConverter, context, "vm_bitcast_f32i32");
// ExtF32: Comparison ops
- patterns.insert<GenericOpConversion<IREE::VM::CmpEQF32OOp>>(context,
- "vm_cmp_eq_f32o");
- patterns.insert<GenericOpConversion<IREE::VM::CmpEQF32UOp>>(context,
- "vm_cmp_eq_f32u");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNEF32OOp>>(context,
- "vm_cmp_ne_f32o");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNEF32UOp>>(context,
- "vm_cmp_ne_f32u");
- patterns.insert<GenericOpConversion<IREE::VM::CmpLTF32OOp>>(context,
- "vm_cmp_lt_f32o");
- patterns.insert<GenericOpConversion<IREE::VM::CmpLTF32UOp>>(context,
- "vm_cmp_lt_f32u");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpEQF32OOp>>(
+ typeConverter, context, "vm_cmp_eq_f32o");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpEQF32UOp>>(
+ typeConverter, context, "vm_cmp_eq_f32u");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNEF32OOp>>(
+ typeConverter, context, "vm_cmp_ne_f32o");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNEF32UOp>>(
+ typeConverter, context, "vm_cmp_ne_f32u");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpLTF32OOp>>(
+ typeConverter, context, "vm_cmp_lt_f32o");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpLTF32UOp>>(
+ typeConverter, context, "vm_cmp_lt_f32u");
patterns.insert<GenericOpConversion<IREE::VM::CmpLTEF32OOp>>(
- context, "vm_cmp_lte_f32o");
+ typeConverter, context, "vm_cmp_lte_f32o");
patterns.insert<GenericOpConversion<IREE::VM::CmpLTEF32UOp>>(
- context, "vm_cmp_lte_f32u");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNaNF32Op>>(context,
- "vm_cmp_nan_f32");
+ typeConverter, context, "vm_cmp_lte_f32u");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNaNF32Op>>(
+ typeConverter, context, "vm_cmp_nan_f32");
// ExtI64: Globals
patterns.insert<
GlobalLoadOpConversion<IREE::VM::GlobalLoadI64Op, IREE::VM::GlobalI64Op>>(
- context, "vm_global_load_i64");
+ typeConverter, context, "vm_global_load_i64");
patterns.insert<GlobalStoreOpConversion<IREE::VM::GlobalStoreI64Op,
IREE::VM::GlobalI64Op>>(
- context, "vm_global_store_i64");
+ typeConverter, context, "vm_global_store_i64");
// ExtI64: Constants
- patterns.insert<ConstOpConversion<IREE::VM::ConstI64Op>>(context);
- patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>>(context);
+ patterns.insert<ConstOpConversion<IREE::VM::ConstI64Op>>(typeConverter,
+ context);
+ patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>>(
+ typeConverter, context);
// ExtI64: List ops
- patterns.insert<ListGetOpConversion<IREE::VM::ListGetI64Op>>(context,
- vmAnalysisCache);
- patterns.insert<ListSetOpConversion<IREE::VM::ListSetI64Op>>(context,
- vmAnalysisCache);
+ patterns.insert<ListGetOpConversion<IREE::VM::ListGetI64Op>>(
+ typeConverter, context, vmAnalysisCache);
+ patterns.insert<ListSetOpConversion<IREE::VM::ListSetI64Op>>(
+ typeConverter, context, vmAnalysisCache);
// ExtI64: Conditional assignment ops
- patterns.insert<GenericOpConversion<IREE::VM::SelectI64Op>>(context,
- "vm_select_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::SelectI64Op>>(
+ typeConverter, context, "vm_select_i64");
// ExtI64: Native integer arithmetic ops
- patterns.insert<GenericOpConversion<IREE::VM::AddI64Op>>(context,
- "vm_add_i64");
- patterns.insert<GenericOpConversion<IREE::VM::SubI64Op>>(context,
- "vm_sub_i64");
- patterns.insert<GenericOpConversion<IREE::VM::MulI64Op>>(context,
- "vm_mul_i64");
- patterns.insert<GenericOpConversion<IREE::VM::DivI64SOp>>(context,
- "vm_div_i64s");
- patterns.insert<GenericOpConversion<IREE::VM::DivI64UOp>>(context,
- "vm_div_i64u");
- patterns.insert<GenericOpConversion<IREE::VM::RemI64SOp>>(context,
- "vm_rem_i64s");
- patterns.insert<GenericOpConversion<IREE::VM::RemI64UOp>>(context,
- "vm_rem_i64u");
- patterns.insert<GenericOpConversion<IREE::VM::FMAI64Op>>(context,
- "vm_fma_i64");
- patterns.insert<GenericOpConversion<IREE::VM::NotI64Op>>(context,
- "vm_not_i64");
- patterns.insert<GenericOpConversion<IREE::VM::AndI64Op>>(context,
- "vm_and_i64");
- patterns.insert<GenericOpConversion<IREE::VM::OrI64Op>>(context, "vm_or_i64");
- patterns.insert<GenericOpConversion<IREE::VM::XorI64Op>>(context,
- "vm_xor_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::AddI64Op>>(
+ typeConverter, context, "vm_add_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::SubI64Op>>(
+ typeConverter, context, "vm_sub_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::MulI64Op>>(
+ typeConverter, context, "vm_mul_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::DivI64SOp>>(
+ typeConverter, context, "vm_div_i64s");
+ patterns.insert<GenericOpConversion<IREE::VM::DivI64UOp>>(
+ typeConverter, context, "vm_div_i64u");
+ patterns.insert<GenericOpConversion<IREE::VM::RemI64SOp>>(
+ typeConverter, context, "vm_rem_i64s");
+ patterns.insert<GenericOpConversion<IREE::VM::RemI64UOp>>(
+ typeConverter, context, "vm_rem_i64u");
+ patterns.insert<GenericOpConversion<IREE::VM::FMAI64Op>>(
+ typeConverter, context, "vm_fma_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::NotI64Op>>(
+ typeConverter, context, "vm_not_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::AndI64Op>>(
+ typeConverter, context, "vm_and_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::OrI64Op>>(typeConverter,
+ context, "vm_or_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::XorI64Op>>(
+ typeConverter, context, "vm_xor_i64");
// ExtI64: Casting and type conversion/emulation ops
patterns.insert<GenericOpConversion<IREE::VM::TruncI64I32Op>>(
- context, "vm_trunc_i64i32");
+ typeConverter, context, "vm_trunc_i64i32");
patterns.insert<GenericOpConversion<IREE::VM::ExtI32I64SOp>>(
- context, "vm_ext_i32i64s");
+ typeConverter, context, "vm_ext_i32i64s");
patterns.insert<GenericOpConversion<IREE::VM::ExtI32I64UOp>>(
- context, "vm_ext_i32i64u");
+ typeConverter, context, "vm_ext_i32i64u");
// ExtI64: Native bitwise shift and rotate ops
- patterns.insert<GenericOpConversion<IREE::VM::ShlI64Op>>(context,
- "vm_shl_i64");
- patterns.insert<GenericOpConversion<IREE::VM::ShrI64SOp>>(context,
- "vm_shr_i64s");
- patterns.insert<GenericOpConversion<IREE::VM::ShrI64UOp>>(context,
- "vm_shr_i64u");
+ patterns.insert<GenericOpConversion<IREE::VM::ShlI64Op>>(
+ typeConverter, context, "vm_shl_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::ShrI64SOp>>(
+ typeConverter, context, "vm_shr_i64s");
+ patterns.insert<GenericOpConversion<IREE::VM::ShrI64UOp>>(
+ typeConverter, context, "vm_shr_i64u");
// ExtI64: Comparison ops
- patterns.insert<GenericOpConversion<IREE::VM::CmpEQI64Op>>(context,
- "vm_cmp_eq_i64");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNEI64Op>>(context,
- "vm_cmp_ne_i64");
- patterns.insert<GenericOpConversion<IREE::VM::CmpLTI64SOp>>(context,
- "vm_cmp_lt_i64s");
- patterns.insert<GenericOpConversion<IREE::VM::CmpLTI64UOp>>(context,
- "vm_cmp_lt_i64u");
- patterns.insert<GenericOpConversion<IREE::VM::CmpNZI64Op>>(context,
- "vm_cmp_nz_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpEQI64Op>>(
+ typeConverter, context, "vm_cmp_eq_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNEI64Op>>(
+ typeConverter, context, "vm_cmp_ne_i64");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpLTI64SOp>>(
+ typeConverter, context, "vm_cmp_lt_i64s");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpLTI64UOp>>(
+ typeConverter, context, "vm_cmp_lt_i64u");
+ patterns.insert<GenericOpConversion<IREE::VM::CmpNZI64Op>>(
+ typeConverter, context, "vm_cmp_nz_i64");
}
namespace IREE {
@@ -3628,11 +3692,24 @@
return signalPassFailure();
}
- // Global ops are dead now
- module.walk([](Operation *op) {
+ SetVector<Operation *> &materializations =
+ typeConverter.sourceMaterializations;
+
+ module.walk([&materializations](Operation *op) {
+ // Global ops are dead now
if (isa<IREE::VM::GlobalI32Op, IREE::VM::GlobalI64Op,
IREE::VM::GlobalF32Op, IREE::VM::GlobalRefOp>(op)) {
op->erase();
+ return;
+ }
+ // Remove dead basic block arguments
+ if (materializations.contains(op)) {
+ assert(isa<emitc::ConstantOp>(op));
+ assert(op->use_empty());
+
+ materializations.remove(op);
+ op->erase();
+ return;
}
});
}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
index b4caae5..8d07c31 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
@@ -7,63 +7,14 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_CONVERTVMTOEMITC_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_CONVERTVMTOEMITC_H_
-#include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h"
-#include "iree/compiler/Dialect/VM/Analysis/ValueLiveness.h"
#include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h"
+#include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
-#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
-struct VMAnalysis {
- public:
- VMAnalysis(RegisterAllocation &®isterAllocation,
- ValueLiveness &&valueLiveness)
- : registerAllocation(std::move(registerAllocation)),
- valueLiveness(std::move(valueLiveness)) {}
-
- VMAnalysis(VMAnalysis &&) = default;
- VMAnalysis &operator=(VMAnalysis &&) = default;
- VMAnalysis(const VMAnalysis &) = delete;
- VMAnalysis &operator=(const VMAnalysis &) = delete;
-
- int getNumRefRegisters() {
- return registerAllocation.getMaxRefRegisterOrdinal() + 1;
- }
-
- uint16_t getRefRegisterOrdinal(Value ref) {
- assert(ref.getType().isa<IREE::VM::RefType>());
- return registerAllocation.mapToRegister(ref).ordinal();
- }
-
- bool isLastValueUse(Value ref, Operation *op) {
- assert(ref.getType().isa<IREE::VM::RefType>());
- return valueLiveness.isLastValueUse(ref, op);
- }
-
- void cacheLocalRef(int64_t ordinal, emitc::ApplyOp &applyOp) {
- assert(!refs.count(ordinal));
- refs[ordinal] = applyOp.getOperation();
- }
-
- emitc::ApplyOp lookupLocalRef(int64_t ordinal) {
- assert(refs.count(ordinal));
- Operation *op = refs[ordinal];
- return cast<emitc::ApplyOp>(op);
- }
-
- DenseMap<int64_t, Operation *> &localRefs() { return refs; }
-
- private:
- RegisterAllocation registerAllocation;
- ValueLiveness valueLiveness;
- DenseMap<int64_t, Operation *> refs;
-};
-
-using VMAnalysisCache = DenseMap<Operation *, VMAnalysis>;
-
void populateVMToEmitCPatterns(MLIRContext *context,
IREE::VM::EmitCTypeConverter &typeConverter,
OwningRewritePatternList &patterns,
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
index 63f2b9f..90f201b 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
@@ -27,7 +27,34 @@
addConversion([](IREE::VM::RefType type) {
return emitc::OpaqueType::get(type.getContext(), "iree_vm_ref_t*");
});
+
+ // We need a source materialization for refs because after running
+ // `applyFullConversion` there would be references to the original
+ // IREE::VM::Ref values in unused basic block arguments. As these are unused
+ // anyway we create dummy ops which get deleted after the conversion has
+ // finished.
+ addSourceMaterialization([this](OpBuilder &builder, IREE::VM::RefType type,
+ ValueRange inputs, Location loc) -> Value {
+ assert(inputs.size() == 1);
+ Value input = inputs[0];
+ assert(input.getType().isa<emitc::OpaqueType>());
+
+ Type objectType = IREE::VM::OpaqueType::get(builder.getContext());
+ Type refType = IREE::VM::RefType::get(objectType);
+
+ auto ctx = builder.getContext();
+ auto op = builder.create<emitc::ConstantOp>(
+ /*location=*/loc,
+ /*resultType=*/refType,
+ /*value=*/emitc::OpaqueAttr::get(ctx, ""));
+
+ sourceMaterializations.insert(op.getOperation());
+
+ return op.getResult();
+ });
}
+
+ SetVector<Operation *> sourceMaterializations;
};
} // namespace VM
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
new file mode 100644
index 0000000..d127882
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
@@ -0,0 +1,68 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_VMANALYSIS_H_
+#define IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_VMANALYSIS_H_
+
+#include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h"
+#include "iree/compiler/Dialect/VM/Analysis/ValueLiveness.h"
+#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+struct VMAnalysis {
+ public:
+ VMAnalysis(RegisterAllocation &®isterAllocation,
+ ValueLiveness &&valueLiveness)
+ : registerAllocation(std::move(registerAllocation)),
+ valueLiveness(std::move(valueLiveness)) {}
+
+ VMAnalysis(VMAnalysis &&) = default;
+ VMAnalysis &operator=(VMAnalysis &&) = default;
+ VMAnalysis(const VMAnalysis &) = delete;
+ VMAnalysis &operator=(const VMAnalysis &) = delete;
+
+ int getNumRefRegisters() {
+ return registerAllocation.getMaxRefRegisterOrdinal() + 1;
+ }
+
+ uint16_t getRefRegisterOrdinal(Value ref) {
+ assert(ref.getType().isa<IREE::VM::RefType>());
+ return registerAllocation.mapToRegister(ref).ordinal();
+ }
+
+ bool isLastValueUse(Value ref, Operation *op) {
+ assert(ref.getType().isa<IREE::VM::RefType>());
+ return valueLiveness.isLastValueUse(ref, op);
+ }
+
+ void cacheLocalRef(int64_t ordinal, emitc::ApplyOp &applyOp) {
+ assert(!refs.count(ordinal));
+ refs[ordinal] = applyOp.getOperation();
+ }
+
+ emitc::ApplyOp lookupLocalRef(int64_t ordinal) {
+ assert(refs.count(ordinal));
+ Operation *op = refs[ordinal];
+ return cast<emitc::ApplyOp>(op);
+ }
+
+ DenseMap<int64_t, Operation *> &localRefs() { return refs; }
+
+ private:
+ RegisterAllocation registerAllocation;
+ ValueLiveness valueLiveness;
+ DenseMap<int64_t, Operation *> refs;
+};
+
+using VMAnalysisCache = DenseMap<Operation *, VMAnalysis>;
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_VMANALYSIS_H_
diff --git a/iree/vm/test/call_ops.mlir b/iree/vm/test/call_ops.mlir
index c0ae804..8d0d445 100644
--- a/iree/vm/test/call_ops.mlir
+++ b/iree/vm/test/call_ops.mlir
@@ -15,18 +15,16 @@
vm.return
}
- // TODO(#7487): Enable the test for emitc.
- vm.export @test_call_r_v attributes {emitc.exclude}
- vm.func private @test_call_r_v() {
+ vm.export @test_call_r_v
+ vm.func @test_call_r_v() {
%ref = vm.const.ref.zero : !vm.ref<?>
vm.call @_r_v(%ref) : (!vm.ref<?>) -> ()
vm.return
}
// Check that reused ref argument slots are handled properly
- // TODO(#7487): Enable the test for emitc.
- vm.export @test_call_r_v_reuse_reg attributes {emitc.exclude}
- vm.func private @test_call_r_v_reuse_reg() {
+ vm.export @test_call_r_v_reuse_reg
+ vm.func @test_call_r_v_reuse_reg() {
%ref = vm.const.ref.zero : !vm.buffer
%unused = vm.const.ref.zero : !vm.buffer
vm.call @_r_v_reuse_reg(%ref, %unused) : (!vm.buffer, !vm.buffer) -> ()
@@ -39,7 +37,7 @@
// of the tests during the lattter. This means we would need to add a pattern
// that inserts calls to `iree_vm_ref_retain` for operand/result pairs of the
// do_not_optimize op.
- // TODO(#7487): Enable the test for emitc.
+ // TODO(simon-camp): Enable the test for emitc.
vm.export @test_call_r_v_preserve_ref attributes {emitc.exclude}
vm.func private @test_call_r_v_preserve_ref() {
%ref = vm.const.ref.zero : !vm.buffer
diff --git a/iree/vm/test/control_flow_ops.mlir b/iree/vm/test/control_flow_ops.mlir
index 2b6d146..22663fe 100644
--- a/iree/vm/test/control_flow_ops.mlir
+++ b/iree/vm/test/control_flow_ops.mlir
@@ -71,9 +71,8 @@
vm.fail %code, "unreachable!"
}
- // TODO(#7487): Enable the test for emitc.
- vm.export @test_cond_br_ref_arg attributes {emitc.exclude}
- vm.func private @test_cond_br_ref_arg() {
+ vm.export @test_cond_br_ref_arg
+ vm.func @test_cond_br_ref_arg() {
%c1 = vm.const.i32 1 : i32
%c1dno = util.do_not_optimize(%c1) : i32
%ref = vm.const.ref.zero : !vm.ref<?>