Support global ref ops and fix passing of refs on function boundaries in the C target (#6835)
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 906ad50..d5c54b4 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -57,6 +57,82 @@ return None; } +/// Create a call to memset to clear a struct +LogicalResult clearStruct(OpBuilder builder, Value structValue, + bool isPointer) { + auto ctx = structValue.getContext(); + auto loc = structValue.getLoc(); + + Type type = structValue.getType(); + + if (!type.isa<emitc::OpaqueType>()) { + return failure(); + } + + Optional<std::string> cType = getCType(type); + + if (!cType.hasValue()) { + return failure(); + } + + Value structPointerValue; + Value sizeValue; + + if (isPointer) { + std::string pointerType = cType.getValue(); + if (pointerType.back() != '*') { + return failure(); + } + + std::string nonPointerType = pointerType.substr(0, pointerType.size() - 1); + + auto size = builder.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/builder.getI32Type(), + /*callee=*/StringAttr::get(ctx, "sizeof"), + /*args=*/ + ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, nonPointerType)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{}); + + structPointerValue = structValue; + sizeValue = size.getResult(0); + } else { + std::string cPtrType = cType.getValue() + "*"; + + auto structPointer = builder.create<emitc::ApplyOp>( + /*location=*/loc, + /*result=*/emitc::OpaqueType::get(ctx, cPtrType), + /*applicableOperator=*/StringAttr::get(ctx, "&"), + /*operand=*/structValue); + + auto size = builder.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/builder.getI32Type(), + /*callee=*/StringAttr::get(ctx, "sizeof"), + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{structValue}); + + structPointerValue = structPointer.getResult(); + sizeValue = size.getResult(0); + } + + builder.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "memset"), + /*args=*/ + ArrayAttr::get(ctx, + {builder.getIndexAttr(0), builder.getUI32IntegerAttr(0), + builder.getIndexAttr(1)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef<Value>{structPointerValue, sizeValue}); + + return success(); +} + LogicalResult convertFuncOp(IREE::VM::FuncOp funcOp, VMAnalysisCache &vmAnalysisCache) { auto ctx = funcOp.getContext(); @@ -152,31 +228,9 @@ refOp.getOperation()->setAttr("ref_ordinal", builder.getIndexAttr(i + numRefArgs)); - auto refPtrOp = builder.create<emitc::ApplyOp>( - /*location=*/loc, - /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), - /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/refOp.getResult()); - - auto refSizeOp = builder.create<emitc::CallOp>( - /*location=*/loc, - /*type=*/builder.getI32Type(), - /*callee=*/StringAttr::get(ctx, "sizeof"), - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ArrayRef<Value>{refOp.getResult()}); - - builder.create<emitc::CallOp>( - /*location=*/loc, - /*type=*/TypeRange{}, - /*callee=*/StringAttr::get(ctx, "memset"), - /*args=*/ - ArrayAttr::get(ctx, - {builder.getIndexAttr(0), builder.getUI32IntegerAttr(0), - builder.getIndexAttr(1)}), - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ - ArrayRef<Value>{refPtrOp.getResult(), refSizeOp.getResult(0)}); + if (failed(clearStruct(builder, refOp.getResult(), /*isPointer=*/false))) { + return failure(); + } } vmAnalysisCache.insert( @@ -197,63 +251,141 @@ return std::string("call_") + callingConvention.getValue() + "_import"; } -template <typename SrcOpTy> Optional<emitc::ApplyOp> createVmTypeDefPtr(ConversionPatternRewriter &rewriter, - SrcOpTy srcOp, Type elementType) { - auto ctx = srcOp.getContext(); - auto loc = srcOp.getLoc(); + Operation *srcOp, + Type elementType) { + auto ctx = srcOp->getContext(); + auto loc = srcOp->getLoc(); - // TODO(simon-camp): Cleanup this up - StringRef elementTypeConstructor; - std::string elementTypeConstructorArg; - if (auto intType = elementType.template dyn_cast<IntegerType>()) { - unsigned int bitWidth = intType.getIntOrFloatBitWidth(); - elementTypeConstructor = "iree_vm_type_def_make_value_type"; - elementTypeConstructorArg = - std::string("IREE_VM_VALUE_TYPE_I") + std::to_string(bitWidth); - } else if (auto refType = - elementType.template dyn_cast<IREE::VM::RefType>()) { - auto objType = refType.getObjectType(); + // Map from type to enum values of type iree_vm_value_type_t and + // iree_vm_ref_type_t + mlir::DenseMap<Type, std::pair<std::string, std::string>> valueTypeMap = { + {IntegerType::get(ctx, 8), + {"IREE_VM_VALUE_TYPE_I8", "IREE_VM_REF_TYPE_NULL"}}, + {IntegerType::get(ctx, 16), + {"IREE_VM_VALUE_TYPE_I16", "IREE_VM_REF_TYPE_NULL"}}, + {IntegerType::get(ctx, 32), + {"IREE_VM_VALUE_TYPE_I32", "IREE_VM_REF_TYPE_NULL"}}, + {IntegerType::get(ctx, 64), + {"IREE_VM_VALUE_TYPE_I64", "IREE_VM_REF_TYPE_NULL"}}, + {Float32Type::get(ctx), + {"IREE_VM_VALUE_TYPE_F32", "IREE_VM_REF_TYPE_NULL"}}, + {Float64Type::get(ctx), + {"IREE_VM_VALUE_TYPE_F64", "IREE_VM_REF_TYPE_NULL"}}, + {IREE::VM::OpaqueType::get(ctx), + {"IREE_VM_VALUE_TYPE_NONE", "IREE_VM_REF_TYPE_NULL"}}, + }; - elementTypeConstructor = "iree_vm_type_def_make_ref_type"; - if (objType.template isa<IREE::VM::BufferType>()) { - elementTypeConstructorArg = "iree_vm_buffer_type_id()"; - } else if (objType.template isa<IREE::VM::ListType>()) { - elementTypeConstructorArg = "iree_vm_list_type_id()"; - } else { - srcOp.emitError() << "Unhandled ref object type " << objType; - return None; - } - } else if (auto opaqueType = - elementType.template dyn_cast<IREE::VM::OpaqueType>()) { - elementTypeConstructor = "iree_vm_type_def_make_variant_type"; - elementTypeConstructorArg = ""; - } else { - srcOp.emitError() << "Unhandled element type " << elementType; + auto elementTypeOp = rewriter.create<emitc::ConstantOp>( + /*location=*/loc, + /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), + /*value=*/emitc::OpaqueAttr::get(ctx, "")); + + if (failed(clearStruct(rewriter, elementTypeOp.getResult(), + /*isPointer=*/false))) { return None; } - auto elementTypeOp = rewriter.template create<emitc::CallOp>( - /*location=*/loc, - /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), - /*callee=*/StringAttr::get(ctx, elementTypeConstructor), - /*args=*/ - ArrayAttr::get(ctx, - {emitc::OpaqueAttr::get(ctx, elementTypeConstructorArg)}), - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ArrayRef<Value>{}); + auto ptr = valueTypeMap.find((elementType)); + if (ptr != valueTypeMap.end()) { + rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_MEMBER_ASSIGN"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "value_type"), + emitc::OpaqueAttr::get(ctx, ptr->second.first)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{elementTypeOp.getResult()}); - auto elementTypePtrOp = rewriter.template create<emitc::ApplyOp>( + rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_MEMBER_ASSIGN"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "ref_type"), + emitc::OpaqueAttr::get(ctx, ptr->second.second)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{elementTypeOp.getResult()}); + } else { + if (!elementType.isa<IREE::VM::RefType>()) { + return None; + } + Type objType = elementType.cast<IREE::VM::RefType>().getObjectType(); + + std::string typeName; + + if (objType.isa<IREE::VM::ListType>()) { + typeName = "!vm.list"; + } else { + llvm::raw_string_ostream sstream(typeName); + objType.print(sstream); + sstream.flush(); + } + + // Remove leading '!' and wrap in quotes + typeName = std::string("\"") + typeName.substr(1) + std::string("\""); + + auto typeNameCStringView = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_string_view_t"), + /*callee=*/StringAttr::get(ctx, "iree_make_cstring_view"), + /*args=*/ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, typeName)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{}); + + auto typeDescriptor = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/ + emitc::OpaqueType::get(ctx, "const iree_vm_ref_type_descriptor_t*"), + /*callee=*/StringAttr::get(ctx, "iree_vm_ref_lookup_registered_type"), + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{typeNameCStringView.getResult(0)}); + + // TODDO(simon-camp) typeDescriptor might be NULL + auto typeDescriptorType = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "type")}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{typeDescriptor.getResult(0)}); + + rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_MEMBER_ASSIGN"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "ref_type"), + rewriter.getIndexAttr(1)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef<Value>{elementTypeOp.getResult(), + typeDescriptorType.getResult(0)}); + } + + auto elementTypePtrOp = rewriter.create<emitc::ApplyOp>( /*location=*/loc, /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t*"), /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/elementTypeOp.getResult(0)); + /*operand=*/elementTypeOp.getResult()); return elementTypePtrOp; } -Optional<Value> findRef(mlir::FuncOp &parentFuncOp, +/// Find a local ref of type `iree_vm_ref_t*` matching the ordinal of the given +/// `IREE::VM::Ref` value. +Optional<Value> findRef(OpBuilder builder, Location location, + mlir::FuncOp &parentFuncOp, VMAnalysisCache &vmAnalysisCache, Value refResult) { + auto ctx = builder.getContext(); + assert(refResult.getType().isa<IREE::VM::RefType>()); auto ptr = vmAnalysisCache.find(parentFuncOp.getOperation()); @@ -264,6 +396,20 @@ int32_t ordinal = ptr->second.getRefRegisterOrdinal(refResult); + // Search block arguments + int refArgCounter = 0; + for (BlockArgument arg : parentFuncOp.getArguments()) { + assert(!arg.getType().isa<IREE::VM::RefType>()); + + if (arg.getType() == emitc::OpaqueType::get(ctx, "iree_vm_ref_t*")) { + if (ordinal == refArgCounter++) { + ptr->second.remapValue(refResult, arg); + return arg; + } + } + } + + // Search local refs for (auto constantOp : parentFuncOp.getOps<emitc::ConstantOp>()) { Operation *op = constantOp.getOperation(); if (!op->hasAttr("ref_ordinal")) continue; @@ -271,7 +417,16 @@ .cast<IntegerAttr>() .getValue() .getZExtValue() == ordinal) { - return constantOp.getResult(); + // Get address of constant + auto ptrOp = builder.create<emitc::ApplyOp>( + /*location=*/location, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), + /*applicableOperator=*/StringAttr::get(ctx, "&"), + /*operand=*/constantOp.getResult()); + + ptr->second.remapValue(refResult, ptrOp.getResult()); + + return ptrOp.getResult(); } } @@ -575,7 +730,7 @@ /*templateArgs=*/ArrayAttr{}, /*operands=*/ArrayRef<Value>{}); - auto statePtr = builder.template create<emitc::ApplyOp>( + auto statePtr = builder.create<emitc::ApplyOp>( /*location=*/loc, /*result=*/emitc::OpaqueType::get(ctx, moduleStateTypeName + "*"), /*applicableOperator=*/StringAttr::get(ctx, "&"), @@ -697,6 +852,53 @@ buffer.getResult(0)}); } + // Zero out refs + auto ordinal_counts = moduleOp.ordinal_counts(); + + if (!ordinal_counts.hasValue()) { + return moduleOp.emitError() + << "ordinal_counts attribute not found. The OrdinalAllocationPass " + "must be run before."; + } + + const int numGlobalRefs = ordinal_counts.getValue().global_refs(); + + auto refs = builder.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), + /*args=*/ + ArrayAttr::get(ctx, {builder.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "refs")}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{stateOp.getResult()}); + + auto refSizeOp = builder.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/builder.getI32Type(), + /*callee=*/StringAttr::get(ctx, "sizeof"), + /*args=*/ + ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, "iree_vm_ref_t")}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{}); + + for (int i = 0; i < numGlobalRefs; i++) { + auto refPtrOp = builder.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), + /*callee=*/StringAttr::get(ctx, "EMITC_ARRAY_ELEMENT_ADDRESS"), + /*args=*/ + ArrayAttr::get( + ctx, {builder.getIndexAttr(0), builder.getUI32IntegerAttr(i)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{refs.getResult(0)}); + + if (failed(clearStruct(builder, refPtrOp.getResult(0), + /*isPointer=*/true))) { + return failure(); + } + } + auto baseStateOp = builder.create<emitc::CallOp>( /*location=*/loc, /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_module_state_t*"), @@ -896,7 +1098,7 @@ /*templateArgs=*/ArrayAttr{}, /*operands=*/ArrayRef<Value>{}); - auto modulePtr = builder.template create<emitc::ApplyOp>( + auto modulePtr = builder.create<emitc::ApplyOp>( /*location=*/loc, /*result=*/emitc::OpaqueType::get(ctx, moduleTypeName + "*"), /*applicableOperator=*/StringAttr::get(ctx, "&"), @@ -945,7 +1147,7 @@ /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_module_t"), /*value=*/emitc::OpaqueAttr::get(ctx, "")); - auto vmModulePtr = builder.template create<emitc::ApplyOp>( + auto vmModulePtr = builder.create<emitc::ApplyOp>( /*location=*/loc, /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_module_t*"), /*applicableOperator=*/StringAttr::get(ctx, "&"), @@ -1048,15 +1250,13 @@ })); } -template <typename AccessOpTy, typename ResultOpTy> -ResultOpTy lookupSymbolRef(AccessOpTy accessOp, StringRef attrName) { +template <typename ResultOpTy> +ResultOpTy lookupSymbolRef(Operation *accessOp, StringRef attrName) { FlatSymbolRefAttr globalAttr = - accessOp.getOperation()->template getAttrOfType<FlatSymbolRefAttr>( - attrName); + accessOp->getAttrOfType<FlatSymbolRefAttr>(attrName); ResultOpTy globalOp = - accessOp.getOperation() - ->template getParentOfType<IREE::VM::ModuleOp>() - .template lookupSymbol<ResultOpTy>(globalAttr.getValue()); + accessOp->getParentOfType<IREE::VM::ModuleOp>().lookupSymbol<ResultOpTy>( + globalAttr.getValue()); return globalOp; } @@ -1123,12 +1323,21 @@ TypeConverter::SignatureConversion signatureConverter( funcOp.getType().getNumInputs()); TypeConverter typeConverter; - for (const auto &arg : llvm::enumerate(funcOp.getArgumentTypes())) { - Type convertedType = getTypeConverter()->convertType(arg.value()); + for (const auto &arg : llvm::enumerate(funcOp.getArguments())) { + Type convertedType = + getTypeConverter()->convertType(arg.value().getType()); signatureConverter.addInputs(arg.index(), convertedType); } - rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); + Block &entryBlock = funcOp.getBlocks().front(); + + Block *newEntryBlock = rewriter.applySignatureConversion( + &funcOp.getBody(), signatureConverter); + + auto ptr = vmAnalysisCache.find(funcOp.getOperation()); + if (ptr == vmAnalysisCache.end()) { + return funcOp.emitError() << "parent func op not found in cache."; + } // Creates a new function with the updated signature. rewriter.updateRootInPlace(funcOp, [&] { @@ -1156,9 +1365,9 @@ IREE::VM::CallOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { mlir::FuncOp funcOp = - lookupSymbolRef<IREE::VM::CallOp, mlir::FuncOp>(op, "callee"); + lookupSymbolRef<mlir::FuncOp>(op.getOperation(), "callee"); IREE::VM::ImportOp importOp = - lookupSymbolRef<IREE::VM::CallOp, IREE::VM::ImportOp>(op, "callee"); + lookupSymbolRef<IREE::VM::ImportOp>(op.getOperation(), "callee"); if (!funcOp && !importOp) return op.emitError() << "lookup of callee failed"; @@ -1230,7 +1439,7 @@ int importOrdinal = importOp.ordinal().getValue().getZExtValue(); - auto funcOp = op.getOperation()->template getParentOfType<mlir::FuncOp>(); + auto funcOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); BlockArgument stackArg = funcOp.getArgument(0); BlockArgument stateArg = funcOp.getArgument(2); @@ -1285,22 +1494,62 @@ auto ctx = op.getContext(); auto loc = op.getLoc(); - auto funcOp = op.getOperation()->template getParentOfType<mlir::FuncOp>(); + auto funcOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); - for (const Value &operand : operands) { - updatedOperands.push_back(operand); + auto ptr = vmAnalysisCache.find(funcOp.getOperation()); + if (ptr == vmAnalysisCache.end()) { + return op.emitError() << "parent func op not found in cache."; } - // Create a variable for every non-ref result and a pointer to it as output + for (Value operand : operands) { + assert(!operand.getType().isa<IREE::VM::RefType>()); + + if (operand.getType() == emitc::OpaqueType::get(ctx, "iree_vm_ref_t*")) { + auto refOp = rewriter.create<emitc::ConstantOp>( + /*location=*/loc, + /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), + /*value=*/emitc::OpaqueAttr::get(ctx, "")); + + auto refPtrOp = rewriter.create<emitc::ApplyOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), + /*applicableOperator=*/StringAttr::get(ctx, "&"), + /*operand=*/refOp.getResult()); + + if (failed(clearStruct(rewriter, refPtrOp.getResult(), + /*isPointer=*/true))) { + return failure(); + } + + bool move = ptr->second.isLastValueUse(operand, op.getOperation()); + + auto assignOp = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move"), + /*args=*/ + ArrayAttr::get( + ctx, {rewriter.getBoolAttr(move), rewriter.getIndexAttr(0), + rewriter.getIndexAttr(1)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{operand, refPtrOp.getResult()}); + + updatedOperands.push_back(refPtrOp.getResult()); + } else { + updatedOperands.push_back(operand); + } + } + + // Create a variable for every result and a pointer to it as output // parameter to the call. for (OpResult result : op.getResults()) { emitc::ConstantOp resultOp; if (result.getType().isa<IREE::VM::RefType>()) { - auto ref = findRef(funcOp, vmAnalysisCache, result); + auto ref = findRef(rewriter, loc, funcOp, vmAnalysisCache, result); if (!ref.hasValue()) { - return failure(); + return op.emitError() << "local ref not found"; } // Keep track of the replaced value in the analysis to keep the value @@ -1310,22 +1559,10 @@ return op.emitError() << "parent func op not found in cache."; } - Optional<std::string> cType = getCType(ref.getValue().getType()); - if (!cType.hasValue()) { - return op.emitError() << "unable to emit C type"; - } + ptr->second.remapValue(result, ref.getValue()); - std::string cPtrType = cType.getValue() + std::string("*"); - auto resultPtrOp = rewriter.create<emitc::ApplyOp>( - /*location=*/loc, - /*type=*/emitc::OpaqueType::get(ctx, cPtrType), - /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/ref.getValue()); - - ptr->second.remapValue(result, resultPtrOp.getResult()); - - resultOperands.push_back(resultPtrOp.getResult()); - updatedOperands.push_back(resultPtrOp.getResult()); + resultOperands.push_back(ref.getValue()); + updatedOperands.push_back(ref.getValue()); } else { resultOp = rewriter.create<emitc::ConstantOp>( /*location=*/loc, @@ -1519,42 +1756,41 @@ }; class ConstRefZeroOpConversion - : public OpRewritePattern<IREE::VM::ConstRefZeroOp> { + : public OpConversionPattern<IREE::VM::ConstRefZeroOp> { public: - using OpRewritePattern<IREE::VM::ConstRefZeroOp>::OpRewritePattern; + using OpConversionPattern<IREE::VM::ConstRefZeroOp>::OpConversionPattern; ConstRefZeroOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache) - : OpRewritePattern<IREE::VM::ConstRefZeroOp>(context), + : OpConversionPattern<IREE::VM::ConstRefZeroOp>(context), vmAnalysisCache(vmAnalysisCache) {} - LogicalResult matchAndRewrite(IREE::VM::ConstRefZeroOp constRefZeroOp, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite( + IREE::VM::ConstRefZeroOp constRefZeroOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { auto ctx = constRefZeroOp.getContext(); auto loc = constRefZeroOp.getLoc(); auto funcOp = constRefZeroOp.getOperation()->getParentOfType<mlir::FuncOp>(); - auto ref = findRef(funcOp, vmAnalysisCache, constRefZeroOp.getResult()); + auto ref = findRef(rewriter, loc, funcOp, vmAnalysisCache, + constRefZeroOp.getResult()); if (!ref.hasValue()) { - return failure(); + return constRefZeroOp.emitError() << "local ref not found"; } - auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - /*op=*/constRefZeroOp, - /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), - /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/ref.getValue()); - rewriter.create<emitc::CallOp>( /*location=*/loc, /*type=*/TypeRange{}, /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ArrayRef<Value>{refPtrOp.result()}); + /*operands=*/ArrayRef<Value>{ref.getValue()}); + + rewriter.replaceOp(constRefZeroOp, ref.getValue()); + return success(); } @@ -1577,15 +1813,14 @@ auto ctx = constRefRodataOp.getContext(); auto loc = constRefRodataOp.getLoc(); - auto rodataOp = - lookupSymbolRef<IREE::VM::ConstRefRodataOp, IREE::VM::RodataOp>( - constRefRodataOp, "rodata"); + auto rodataOp = lookupSymbolRef<IREE::VM::RodataOp>( + constRefRodataOp.getOperation(), "rodata"); if (!rodataOp) { return constRefRodataOp.emitError() << "Unable to find RodataOp"; } - auto funcOp = constRefRodataOp.getOperation() - ->template getParentOfType<mlir::FuncOp>(); + auto funcOp = + constRefRodataOp.getOperation()->getParentOfType<mlir::FuncOp>(); BlockArgument stateArg = funcOp.getArgument(2); auto rodataBuffersPtr = rewriter.create<emitc::CallOp>( @@ -1618,15 +1853,12 @@ /*templateArgs=*/ArrayAttr{}, /*operands=*/ArrayRef<Value>{}); - auto ref = findRef(funcOp, vmAnalysisCache, constRefRodataOp.getResult()); + auto ref = findRef(rewriter, loc, funcOp, vmAnalysisCache, + constRefRodataOp.getResult()); - if (!ref.hasValue()) return failure(); - - auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - /*op=*/constRefRodataOp, - /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), - /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/ref.getValue()); + if (!ref.hasValue()) { + return constRefRodataOp.emitError() << "local ref not found"; + } returnIfError( /*rewriter=*/rewriter, @@ -1636,7 +1868,9 @@ /*templateArgs=*/ArrayAttr{}, /*operands=*/ ArrayRef<Value>{byteBufferPtrOp.getResult(0), typeIdOp.getResult(0), - refPtrOp.getResult()}); + ref.getValue()}); + + rewriter.replaceOp(constRefRodataOp, ref.getValue()); return success(); } @@ -1654,6 +1888,15 @@ auto ctx = op.getContext(); auto loc = op.getLoc(); + if (llvm::any_of(operands, [&ctx](Value operand) { + Type type = operand.getType(); + assert(!type.isa<IREE::VM::RefType>()); + return type == emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"); + })) { + return op.emitError() + << "basic block arguments with ref type not supported"; + } + rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, op.dest(), op.destOperands()); @@ -1672,6 +1915,15 @@ auto ctx = op.getContext(); auto loc = op.getLoc(); + if (llvm::any_of(operands, [&ctx](Value operand) { + Type type = operand.getType(); + assert(!type.isa<IREE::VM::RefType>()); + return type == emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"); + })) { + return op.emitError() + << "basic block arguments with ref type not supported"; + } + Type boolType = rewriter.getI1Type(); auto condition = rewriter.create<IREE::VM::CmpNZI32Op>( @@ -1706,8 +1958,6 @@ auto funcOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); - releaseLocalRefs(rewriter, loc, funcOp); - // The result variables are the last N arguments of the function. unsigned int firstOutputArgumentIndex = funcOp.getNumArguments() - operands.size(); @@ -1716,23 +1966,28 @@ unsigned int argumentIndex = firstOutputArgumentIndex + operand.index(); BlockArgument resultArgument = funcOp.getArgument(argumentIndex); - auto isRef = [&ctx](Type type) { - return type == emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"); - }; - - StringRef assignMacro = isRef(operand.value().getType()) - ? "EMITC_ASSIGN_VALUE" - : "EMITC_DEREF_ASSIGN_VALUE"; - - rewriter.create<emitc::CallOp>( - /*location=*/loc, - /*type=*/TypeRange{}, - /*callee=*/StringAttr::get(ctx, assignMacro), - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ArrayRef<Value>{resultArgument, operand.value()}); + if (operand.value().getType() == + emitc::OpaqueType::get(ctx, "iree_vm_ref_t*")) { + rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "iree_vm_ref_move"), + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{operand.value(), resultArgument}); + } else { + rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/StringAttr::get(ctx, "EMITC_DEREF_ASSIGN_VALUE"), + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{resultArgument, operand.value()}); + } } + releaseLocalRefs(rewriter, loc, funcOp); + auto status = rewriter.create<emitc::CallOp>( /*location=*/loc, /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), @@ -1874,7 +2129,7 @@ auto loc = loadOp.getLoc(); GlobalOpTy globalOp = - lookupSymbolRef<LoadOpTy, GlobalOpTy>(loadOp, "global"); + lookupSymbolRef<GlobalOpTy>(loadOp.getOperation(), "global"); if (!globalOp) { return loadOp.emitError() << "Unable to find GlobalOp"; } @@ -1911,6 +2166,131 @@ StringRef funcName; }; +template <typename LoadStoreOpTy> +class GlobalLoadStoreRefOpConversion + : public OpConversionPattern<LoadStoreOpTy> { + public: + using OpConversionPattern<LoadStoreOpTy>::OpConversionPattern; + + GlobalLoadStoreRefOpConversion(MLIRContext *context, + VMAnalysisCache &vmAnalysisCache) + : OpConversionPattern<LoadStoreOpTy>(context), + vmAnalysisCache(vmAnalysisCache) {} + + private: + // TODO(simon-camp): Deduplicate code + LogicalResult matchAndRewrite( + LoadStoreOpTy op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (isa<IREE::VM::GlobalLoadRefOp>(op)) { + return rewriteOp(op.getOperation(), operands, rewriter, true); + } else if (isa<IREE::VM::GlobalStoreRefOp>(op)) { + return rewriteOp(op.getOperation(), operands, rewriter, false); + } + + return op.emitError() << "op must be one of `vm.global.load.ref` or " + "`vm.global.store.ref`"; + } + + LogicalResult rewriteOp(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter, + bool isLoad) const { + auto ctx = op->getContext(); + auto loc = op->getLoc(); + + IREE::VM::GlobalRefOp globalOp = + lookupSymbolRef<IREE::VM::GlobalRefOp>(op, "global"); + if (!globalOp) { + return op->emitError() << "Unable to find GlobalOp"; + } + + auto globalOrdinal = globalOp.ordinal().getValue().getZExtValue(); + + auto funcOp = op->getParentOfType<mlir::FuncOp>(); + + auto ptr = vmAnalysisCache.find(funcOp.getOperation()); + if (ptr == vmAnalysisCache.end()) { + return op->emitError() << "parent func op not found in cache."; + } + + Value localValue = isLoad ? op->getResult(0) : op->getOperand(0); + + bool move = ptr->second.isLastValueUse(localValue, op); + + auto localRef = findRef(rewriter, loc, funcOp, vmAnalysisCache, localValue); + + if (!localRef.hasValue()) { + return op->emitError() << "local ref not found"; + } + + BlockArgument stateArg = funcOp.getArgument(2); + auto refs = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "refs")}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{stateArg}); + + auto stateRef = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), + /*callee=*/StringAttr::get(ctx, "EMITC_ARRAY_ELEMENT_ADDRESS"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + rewriter.getUI32IntegerAttr(globalOrdinal)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef<Value>{refs.getResult(0)}); + + Type elementType = localValue.getType(); + + auto elementTypePtrOp = createVmTypeDefPtr(rewriter, op, elementType); + + if (!elementTypePtrOp.hasValue()) { + return op->emitError() << "generating iree_vm_type_def_t* failed"; + } + + auto typedefRefType = rewriter.create<emitc::CallOp>( + /*location=*/loc, + /*type=*/ + emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), + /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), + emitc::OpaqueAttr::get(ctx, "ref_type")}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef<Value>{elementTypePtrOp.getValue().getResult()}); + + Value srcRef = isLoad ? stateRef.getResult(0) : localRef.getValue(); + Value destRef = isLoad ? localRef.getValue() : stateRef.getResult(0); + + returnIfError( + /*rewriter=*/rewriter, + /*location=*/loc, + /*callee=*/StringAttr::get(ctx, "iree_vm_ref_retain_or_move_checked"), + /*args=*/ + ArrayAttr::get(ctx, + {rewriter.getBoolAttr(move), rewriter.getIndexAttr(0), + rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef<Value>{srcRef, typedefRefType.getResult(0), destRef}); + + if (isLoad) { + rewriter.replaceOp(op, localRef.getValue()); + } else { + rewriter.eraseOp(op); + } + + return success(); + } + + VMAnalysisCache &vmAnalysisCache; +}; + template <typename StoreOpTy, typename GlobalOpTy> class GlobalStoreOpConversion : public OpConversionPattern<StoreOpTy> { using OpConversionPattern<StoreOpTy>::OpConversionPattern; @@ -1927,7 +2307,7 @@ auto loc = storeOp.getLoc(); GlobalOpTy globalOp = - lookupSymbolRef<StoreOpTy, GlobalOpTy>(storeOp, "global"); + lookupSymbolRef<GlobalOpTy>(storeOp.getOperation(), "global"); if (!globalOp) { return storeOp.emitError() << "Unable to find GlobalOp"; } @@ -2082,10 +2462,10 @@ .getElementType(); Optional<emitc::ApplyOp> elementTypePtrOp = - createVmTypeDefPtr(rewriter, allocOp, elementType); + createVmTypeDefPtr(rewriter, allocOp.getOperation(), elementType); if (!elementTypePtrOp.hasValue()) { - return failure(); + return allocOp.emitError() << "generating iree_vm_type_def_t* failed"; } auto listOp = rewriter.create<emitc::ConstantOp>( @@ -2099,8 +2479,7 @@ /*applicableOperator=*/StringAttr::get(ctx, "&"), /*operand=*/listOp.getResult()); - auto funcOp = - allocOp.getOperation()->template getParentOfType<mlir::FuncOp>(); + auto funcOp = allocOp.getOperation()->getParentOfType<mlir::FuncOp>(); BlockArgument stateArg = funcOp.getArgument(2); auto allocatorOp = rewriter.create<emitc::CallOp>( @@ -2123,15 +2502,12 @@ ArrayRef<Value>{elementTypePtrOp.getValue().getResult(), operands[0], allocatorOp.getResult(0), listPtrOp.getResult()}); - auto ref = findRef(funcOp, vmAnalysisCache, allocOp.getResult()); + auto ref = + findRef(rewriter, loc, funcOp, vmAnalysisCache, allocOp.getResult()); - if (!ref.hasValue()) return failure(); - - auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - /*op=*/allocOp, - /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), - /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/ref.getValue()); + if (!ref.hasValue()) { + return allocOp.emitError() << "local ref not found"; + } auto refTypeOp = rewriter.create<emitc::CallOp>( /*location=*/loc, @@ -2149,7 +2525,9 @@ /*templateArgs=*/ArrayAttr{}, /*operands=*/ ArrayRef<Value>{listOp.getResult(), refTypeOp.getResult(0), - refPtrOp.getResult()}); + ref.getValue()}); + + rewriter.replaceOp(allocOp, ref.getValue()); return success(); } @@ -2274,15 +2652,12 @@ auto funcOp = getOp.getOperation()->getParentOfType<mlir::FuncOp>(); - auto ref = findRef(funcOp, vmAnalysisCache, getOp.getResult()); + auto ref = + findRef(rewriter, loc, funcOp, vmAnalysisCache, getOp.getResult()); - if (!ref.hasValue()) return failure(); - - auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - /*op=*/getOp, - /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), - /*applicableOperator=*/StringAttr::get(ctx, "&"), - /*operand=*/ref.getValue()); + if (!ref.hasValue()) { + return getOp.emitError() << "local ref not found"; + } returnIfError( /*rewriter=*/rewriter, @@ -2292,14 +2667,15 @@ /*templateArgs=*/ArrayAttr{}, /*operands=*/ ArrayRef<Value>{listDerefOp.getResult(0), getOp.index(), - refPtrOp.getResult()}); + ref.getValue()}); Type elementType = getOp.getResult().getType(); - auto elementTypePtrOp = createVmTypeDefPtr(rewriter, getOp, elementType); + auto elementTypePtrOp = + createVmTypeDefPtr(rewriter, getOp.getOperation(), elementType); if (!elementTypePtrOp.hasValue()) { - return failure(); + return getOp.emitError() << "generating iree_vm_type_def_t* failed"; } // Build the following expression: @@ -2318,7 +2694,7 @@ emitc::OpaqueAttr::get(ctx, "type")}), /*templateArgs=*/ArrayAttr{}, /*operands=*/ - ArrayRef<Value>{refPtrOp.getResult()}); + ArrayRef<Value>{ref.getValue()}); auto refTypeNull = rewriter.create<emitc::ConstantOp>( /*location=*/loc, @@ -2405,7 +2781,7 @@ /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ArrayRef<Value>{refPtrOp.getResult()}); + /*operands=*/ArrayRef<Value>{ref.getValue()}); rewriter.create<mlir::BranchOp>(loc, continuationBlock); } @@ -2414,6 +2790,8 @@ auto branchOp = rewriter.create<CondBranchOp>( loc, invalidType.getResult(0), failureBlock, continuationBlock); + rewriter.replaceOp(getOp, ref.getValue()); + return success(); } @@ -2549,10 +2927,12 @@ } // namespace void populateVMToEmitCPatterns(MLIRContext *context, + ConversionTarget &conversionTarget, IREE::VM::EmitCTypeConverter &typeConverter, OwningRewritePatternList &patterns, VMAnalysisCache &vmAnalysisCache) { - populateUtilConversionPatterns(context, typeConverter, patterns); + populateUtilConversionPatterns(context, conversionTarget, typeConverter, + patterns); // CFG patterns.insert<BranchOpConversion>(context); @@ -2570,6 +2950,11 @@ IREE::VM::GlobalI32Op>>( context, "vm_global_store_i32"); + patterns.insert<GlobalLoadStoreRefOpConversion<IREE::VM::GlobalLoadRefOp>>( + context, vmAnalysisCache); + patterns.insert<GlobalLoadStoreRefOpConversion<IREE::VM::GlobalStoreRefOp>>( + context, vmAnalysisCache); + // Constants patterns.insert<ConstOpConversion<IREE::VM::ConstI32Op>>(context); patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(context); @@ -2895,18 +3280,23 @@ vmAnalysisCache.insert(std::make_pair( op, VMAnalysis{RegisterAllocation(op), ValueLiveness(op)})); - if (failed(convertFuncOp(funcOp, vmAnalysisCache))) + if (failed(convertFuncOp(funcOp, vmAnalysisCache))) { return signalPassFailure(); + } funcsToRemove.push_back(funcOp); } - for (auto &funcOp : funcsToRemove) funcOp.erase(); + for (auto &funcOp : funcsToRemove) { + funcOp.erase(); + } // Generate func ops that implement the C API. - if (failed(createAPIFunctions(module))) return signalPassFailure(); + if (failed(createAPIFunctions(module))) { + return signalPassFailure(); + } OwningRewritePatternList patterns(&getContext()); - populateVMToEmitCPatterns(&getContext(), typeConverter, patterns, + populateVMToEmitCPatterns(&getContext(), target, typeConverter, patterns, vmAnalysisCache); target.addLegalDialect<emitc::EmitCDialect, mlir::BuiltinDialect, @@ -2916,26 +3306,35 @@ return typeConverter.isSignatureLegal(op.getType()); }); - target.addDynamicallyLegalOp<IREE::Util::DoNotOptimizeOp>( - [&](IREE::Util::DoNotOptimizeOp op) { - return typeConverter.isLegal(op.getResultTypes()); - }); - // Structural ops target.addLegalOp<IREE::VM::ModuleOp>(); target.addLegalOp<IREE::VM::ModuleTerminatorOp>(); + // These ops are needed to build arrays for the module descriptor. There is + // no way to generate this directly with the EmitC dialect at the moment. target.addLegalOp<IREE::VM::ExportOp>(); target.addLegalOp<IREE::VM::ImportOp>(); // Global ops + // The global ops are dead after the conversion and will get removed. target.addLegalOp<IREE::VM::GlobalI32Op>(); target.addLegalOp<IREE::VM::GlobalI64Op>(); target.addLegalOp<IREE::VM::GlobalF32Op>(); + target.addLegalOp<IREE::VM::GlobalRefOp>(); + + // This op is needed in the printer to emit an array holding the data. target.addLegalOp<IREE::VM::RodataOp>(); if (failed(applyFullConversion(module, target, std::move(patterns)))) { return signalPassFailure(); } + + // Global ops are dead now + module.walk([](Operation *op) { + if (isa<IREE::VM::GlobalI32Op, IREE::VM::GlobalI64Op, + IREE::VM::GlobalF32Op, IREE::VM::GlobalRefOp>(op)) { + op->erase(); + } + }); } };
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h index 477ce93..188eb02 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h
@@ -33,14 +33,19 @@ } int getRefRegisterOrdinal(Value ref) { - return registerAllocation.mapToRegister(originalValue(ref)).ordinal(); + auto originalRef = originalValue(ref); + assert(originalRef.getType().isa<IREE::VM::RefType>()); + return registerAllocation.mapToRegister(originalRef).ordinal(); } bool isLastValueUse(Value ref, Operation *op) { - return valueLiveness.isLastValueUse(originalValue(ref), op); + auto originalRef = originalValue(ref); + assert(originalRef.getType().isa<IREE::VM::RefType>()); + return valueLiveness.isLastValueUse(originalRef, op); } void remapValue(Value original, Value replacement) { + assert(original.getType().isa<IREE::VM::RefType>()); mapping[replacement] = original; return; }
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir index 4ff0777..43d11c7 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_add_i32 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_f32.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_f32.mlir index bc9e4f6..675d239 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_f32.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_f32.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_add_f32 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir index 4b8b0aa..f02e4c2 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/arithmetic_ops_i64.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_add_i64 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir index 049b3ec..f363f85 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_select_i32 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_f32.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_f32.mlir index 3fc3189..dc9bb80 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_f32.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_f32.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_select_f32 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir index 7ef872f..6ae4688 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/assignment_ops_i64.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_select_i64 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops.mlir index ae98829..d811647 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s vm.module @module { // CHECK-LABEL: @module_cmp_eq_i32
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_f32.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_f32.mlir index dba5173..9d33687 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_f32.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_f32.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s vm.module @module { // CHECK-LABEL: @module_cmp_eq_f32o
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_i64.mlir index f60b8ae..dfe302b 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/comparison_ops_i64.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s vm.module @module { // CHECK-LABEL: @module_cmp_eq_i64
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir index a62f614..1dd24c6 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_f32.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_f32.mlir index eb5e30c..6618a63 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_f32.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_f32.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s vm.module @my_module { // CHECK-LABEL: @my_module_const_f32_zero
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir index ca70fcc..3a00e67 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/const_ops_i64.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops.mlir index 7d6cbab..913e5ec 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_trunc vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_f32.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_f32.mlir index dae8de4..a24e783 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_f32.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_f32.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_cast vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_i64.mlir index 6371d31..7488e42 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/conversion_ops_i64.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_trunc_i64 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir index ebde983..32f3ab8 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_shl_i32 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir index 4a72b13..e85bf77 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/shift_ops_i64.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation),vm.module(iree-convert-vm-to-emitc)' %s | IreeFileCheck %s // CHECK-LABEL: @my_module_shl_i64 vm.module @my_module {
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index 7dee7bb..fbf11c9 100644 --- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -42,16 +42,6 @@ << std::string(77, '=') << "\n"; } -static void printSeparatingComment(llvm::raw_ostream &output) { - output << "//" << std::string(77, '=') - << "\n" - "// The code below setups functions and lookup tables to " - "implement the vm\n" - "// interface\n" - "//" - << std::string(77, '=') << "\n"; -} - static LogicalResult printRodataBuffers(IREE::VM::ModuleOp &moduleOp, mlir::emitc::CppEmitter &emitter) { llvm::raw_ostream &output = emitter.ostream();
diff --git a/iree/vm/ops_emitc.h b/iree/vm/ops_emitc.h index 9a2848c..f9fe306 100644 --- a/iree/vm/ops_emitc.h +++ b/iree/vm/ops_emitc.h
@@ -10,9 +10,6 @@ // This file contains utility macros used for things that EmitC can't handle // directly. -// Assign a value to a variable -#define EMITC_ASSIGN_VALUE(var, value) (var) = (value) - // Assign a value through a pointer variable #define EMITC_DEREF_ASSIGN_VALUE(ptr, value) *(ptr) = (value)
diff --git a/iree/vm/test/call_ops.mlir b/iree/vm/test/call_ops.mlir index 8b5add7..9ad8d1b 100644 --- a/iree/vm/test/call_ops.mlir +++ b/iree/vm/test/call_ops.mlir
@@ -1,5 +1,7 @@ vm.module @call_ops { + vm.rodata private @buffer dense<[1, 2, 3]> : tensor<3xi8> + vm.export @fail_call_v_v vm.func @fail_call_v_v() { vm.call @_v_v_fail() : () -> () @@ -20,6 +22,28 @@ vm.return } + // Check that reused ref argument slots are handled properly + vm.export @test_call_r_v_reuse_reg + vm.func @test_call_r_v_reuse_reg() { + %ref = vm.const.ref.zero : !vm.buffer + %unused = vm.const.ref.zero : !vm.buffer + vm.call @_r_v_reuse_reg(%ref, %unused) : (!vm.buffer, !vm.buffer) -> () + vm.return + } + + // Check passing refs as arguments doesn't alter values on the call site + vm.export @test_call_r_v_preserve_ref + vm.func @test_call_r_v_preserve_ref() { + %ref = vm.const.ref.zero : !vm.buffer + %unused = vm.const.ref.rodata @buffer : !vm.buffer + %unusued_dno_1 = util.do_not_optimize(%unused) : !vm.buffer + vm.check.nz %unused : !vm.buffer + vm.call @_r_v_preserve_reg(%ref, %unused) : (!vm.buffer, !vm.buffer) -> () + %unusued_dno_2 = util.do_not_optimize(%unused) : !vm.buffer + vm.check.nz %unusued_dno_2 : !vm.buffer + vm.return + } + vm.export @test_call_v_i vm.func @test_call_v_i() { %c1 = vm.const.i32 1 : i32 @@ -66,12 +90,27 @@ vm.return } + vm.func @_r_v_reuse_reg(%arg : !vm.ref<?>, %unused : !vm.ref<?>) attributes {noinline} { + %ref = vm.const.ref.zero : !vm.ref<?> + %ref_dno = util.do_not_optimize(%ref) : !vm.ref<?> + vm.check.eq %arg, %ref_dno, "Expected %arg to be NULL" : !vm.ref<?> + vm.return + } + + vm.func @_r_v_preserve_reg(%arg1 : !vm.ref<?>, %arg2 : !vm.ref<?>) attributes {noinline} { + %ref = vm.const.ref.zero : !vm.ref<?> + %ref_dno = util.do_not_optimize(%ref) : !vm.ref<?> + vm.check.eq %arg1, %ref_dno, "Expected %arg1 to be NULL" : !vm.ref<?> + vm.check.nz %arg2, "Expected %arg2 to be not NULL" : !vm.ref<?> + vm.return + } + vm.func @_v_i() -> i32 attributes {noinline} { %c1 = vm.const.i32 1 : i32 vm.return %c1 : i32 } - vm.func private @_v_r() -> !vm.ref<?> attributes {noinline} { + vm.func @_v_r() -> !vm.ref<?> attributes {noinline} { %ref = vm.const.ref.zero : !vm.ref<?> vm.return %ref : !vm.ref<?> }
diff --git a/iree/vm/test/global_ops.mlir b/iree/vm/test/global_ops.mlir index b2deb45..9c77c61 100644 --- a/iree/vm/test/global_ops.mlir +++ b/iree/vm/test/global_ops.mlir
@@ -6,8 +6,12 @@ vm.global.i32 private @c42 = 42 : i32 vm.global.i32 private mutable @c107_mut = 107 : i32 + vm.global.ref mutable @g0 : !vm.buffer // TODO(simon-camp): Add test for initializer + vm.rodata private @buffer dense<[1, 2, 3]> : tensor<3xi8> + + // TODO(simon-camp) This test gets constant folded vm.export @test_global_load_i32 vm.func @test_global_load_i32() { %actual = vm.global.load.i32 @c42 : i32 @@ -16,6 +20,15 @@ vm.return } + vm.export @test_global_load_ref + vm.func @test_global_load_ref() { + %actual = vm.global.load.ref @g0 : !vm.buffer + %expected = vm.const.ref.zero : !vm.buffer + %expecteddno = util.do_not_optimize(%expected) : !vm.buffer + vm.check.eq %actual, %expecteddno : !vm.buffer + vm.return + } + vm.export @test_global_store_i32 vm.func @test_global_store_i32() { %c17 = vm.const.i32 17 : i32 @@ -25,4 +38,13 @@ vm.return } + vm.export @test_global_store_ref + vm.func @test_global_store_ref() { + %ref_buffer = vm.const.ref.rodata @buffer : !vm.buffer + vm.global.store.ref %ref_buffer, @g0 : !vm.buffer + %actual = vm.global.load.ref @g0 : !vm.buffer + vm.check.eq %actual, %ref_buffer, "@g0 != buffer" : !vm.buffer + vm.return + } + }