Adding `vm.select.ref` lowering to emitc.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 5eef577..9df179f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -1063,7 +1063,6 @@
if (typeName[0] == '!') {
typeName = typeName.substr(1);
}
- typeName = std::string("\"") + typeName + std::string("\"");
Value stringView =
emitc_builders::ireeMakeCstringView(builder, loc, typeName);
@@ -2947,6 +2946,107 @@
}
};
+class SelectRefOpConversion
+ : public EmitCConversionPattern<IREE::VM::SelectRefOp> {
+ using Adaptor = typename IREE::VM::SelectRefOp::Adaptor;
+ using EmitCConversionPattern<IREE::VM::SelectRefOp>::EmitCConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(IREE::VM::SelectRefOp selectOp, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto ctx = selectOp.getContext();
+ auto loc = selectOp.getLoc();
+
+ auto moduleOp =
+ selectOp.getOperation()->template getParentOfType<IREE::VM::ModuleOp>();
+ auto funcOp = selectOp.getOperation()
+ ->template getParentOfType<mlir::emitc::FuncOp>();
+ auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp);
+
+ const BlockArgument moduleArg = funcOp.getArgument(CCONV_ARGUMENT_MODULE);
+ auto resultTypePtr =
+ createVmTypeDefPtr(rewriter, loc, this->getModuleAnalysis(), moduleOp,
+ moduleArg, selectOp.getType());
+ if (!resultTypePtr.has_value()) {
+ return selectOp->emitError() << "generating iree_vm_type_def_t* failed";
+ }
+ auto resultTypeAsRef =
+ rewriter
+ .create<emitc::CallOpaqueOp>(
+ /*location=*/loc,
+ /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"),
+ /*callee=*/StringAttr::get(ctx, "iree_vm_type_def_as_ref"),
+ /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+ /*operands=*/ArrayRef<Value>{resultTypePtr.value()})
+ .getResult(0);
+
+ bool moveTrue =
+ funcAnalysis.isMove(selectOp.getTrueValue(), selectOp.getOperation());
+ bool moveFalse =
+ funcAnalysis.isMove(selectOp.getFalseValue(), selectOp.getOperation());
+
+ Value refTrue =
+ this->getModuleAnalysis().lookupRef(selectOp.getTrueValue());
+ Value refFalse =
+ this->getModuleAnalysis().lookupRef(selectOp.getFalseValue());
+ Value refResult = this->getModuleAnalysis().lookupRef(selectOp.getResult());
+
+ Type boolType = rewriter.getI1Type();
+ auto condition = rewriter.create<IREE::VM::CmpNZI32Op>(
+ loc, rewriter.getI32Type(), selectOp.getCondition());
+ auto conditionI1 = rewriter.create<emitc::CastOp>(
+ /*location=*/loc,
+ /*type=*/boolType,
+ /*operand=*/condition.getResult());
+
+ auto *continueBlock =
+ rewriter.splitBlock(selectOp->getBlock(), Block::iterator(selectOp));
+
+ Block *trueBlock = nullptr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ trueBlock = rewriter.createBlock(continueBlock);
+ returnIfError(
+ /*rewriter=*/rewriter,
+ /*location=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"),
+ /*args=*/
+ ArrayAttr::get(
+ ctx, {rewriter.getBoolAttr(moveTrue), rewriter.getIndexAttr(0),
+ rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}),
+ /*operands=*/
+ ArrayRef<Value>{refTrue, resultTypeAsRef, refResult},
+ this->getModuleAnalysis());
+ rewriter.create<IREE::VM::BranchOp>(loc, continueBlock);
+ }
+
+ Block *falseBlock = nullptr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ falseBlock = rewriter.createBlock(continueBlock);
+ returnIfError(
+ /*rewriter=*/rewriter,
+ /*location=*/loc,
+ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"),
+ /*args=*/
+ ArrayAttr::get(
+ ctx, {rewriter.getBoolAttr(moveFalse), rewriter.getIndexAttr(0),
+ rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}),
+ /*operands=*/
+ ArrayRef<Value>{refFalse, resultTypeAsRef, refResult},
+ this->getModuleAnalysis());
+ rewriter.create<IREE::VM::BranchOp>(loc, continueBlock);
+ }
+
+ rewriter.setInsertionPointAfterValue(conditionI1);
+ rewriter.create<mlir::cf::CondBranchOp>(loc, conditionI1.getResult(),
+ trueBlock, falseBlock);
+ rewriter.replaceOp(selectOp, refResult);
+
+ return success();
+ }
+};
+
template <typename OpTy>
class ConstOpConversion : public EmitCConversionPattern<OpTy> {
using Adaptor = typename OpTy::Adaptor;
@@ -3429,12 +3529,8 @@
releaseRefs(rewriter, loc, funcOp, getModuleAnalysis());
- std::string messageStr = std::string("\"") +
- op.getMessage().value_or("").str() +
- std::string("\"");
-
- Value message =
- emitc_builders::ireeMakeCstringView(rewriter, loc, messageStr);
+ Value message = emitc_builders::ireeMakeCstringView(
+ rewriter, loc, op.getMessage().value_or("").str());
auto messageSizeOp = emitc_builders::structMember(
rewriter, loc,
@@ -4430,6 +4526,7 @@
CallOpConversion<IREE::VM::CallOp>,
CallOpConversion<IREE::VM::CallVariadicOp>,
CompareRefNotZeroOpConversion,
+ SelectRefOpConversion,
CondBranchOpConversion,
BranchTableOpConversion,
ConstOpConversion<IREE::VM::ConstF32Op>,
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp
index e817f9e..3076f2d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp
@@ -299,6 +299,11 @@
Value ireeMakeCstringView(OpBuilder builder, Location location,
std::string str) {
+ std::string escapedStr;
+ llvm::raw_string_ostream os(escapedStr);
+ os.write_escaped(str);
+ auto quotedStr = std::string("\"") + escapedStr + std::string("\"");
+
auto ctx = builder.getContext();
return builder
.create<emitc::CallOpaqueOp>(
@@ -306,7 +311,7 @@
/*type=*/emitc::OpaqueType::get(ctx, "iree_string_view_t"),
/*callee=*/StringAttr::get(ctx, "iree_make_cstring_view"),
/*args=*/
- ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, str)}),
+ ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, quotedStr)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{})
.getResult(0);
diff --git a/runtime/src/iree/vm/test/assignment_ops.mlir b/runtime/src/iree/vm/test/assignment_ops.mlir
index 1388c1e..891165d 100644
--- a/runtime/src/iree/vm/test/assignment_ops.mlir
+++ b/runtime/src/iree/vm/test/assignment_ops.mlir
@@ -17,7 +17,7 @@
vm.return
}
- vm.export @test_select_ref attributes {emitc.exclude}
+ vm.export @test_select_ref
vm.func private @test_select_ref() {
%c0 = vm.const.i32 0
%list0 = vm.list.alloc %c0 : (i32) -> !vm.list<i8>