| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h" |
| |
| #include <optional> |
| |
| #include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.h" |
| #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| #include "mlir/Dialect/EmitC/IR/EmitC.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinDialect.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir::iree_compiler { |
| |
| namespace { |
| |
| enum { |
| SHIM_ARGUMENT_STACK = 0, |
| SHIM_ARGUMENT_FLAGS, |
| SHIM_ARGUMENT_ARGS_STORAGE, |
| SHIM_ARGUMENT_RETS_STORAGE, |
| SHIM_ARGUMENT_MODULE, |
| SHIM_ARGUMENT_MODULE_STATE, |
| }; |
| |
| enum { |
| CCONV_ARGUMENT_STACK = 0, |
| CCONV_ARGUMENT_MODULE, |
| CCONV_ARGUMENT_MODULE_STATE, |
| }; |
| |
| LogicalResult convertFuncOp(IREE::VM::FuncOp funcOp, |
| const IREE::VM::EmitCTypeConverter &typeConverter) { |
| auto ctx = funcOp.getContext(); |
| auto loc = funcOp.getLoc(); |
| |
| OpBuilder builder(funcOp); |
| |
| auto moduleOp = funcOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| FunctionType funcType = funcOp.getFunctionType(); |
| std::string name = moduleOp.getName().str() + "_" + funcOp.getName().str(); |
| std::string moduleTypeName = |
| std::string("struct ") + moduleOp.getName().str() + "_t"; |
| std::string moduleStateTypeName = |
| std::string("struct ") + moduleOp.getName().str() + "_state_t"; |
| |
| Type stackType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_stack_t")); |
| Type moduleType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, moduleTypeName)); |
| Type moduleStateType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, moduleStateTypeName)); |
| |
| SmallVector<Type, 3> inputTypes = {stackType, moduleType, moduleStateType}; |
| SmallVector<Type, 1> outputTypes; |
| |
| for (auto &inputType : funcType.getInputs()) { |
| inputTypes.push_back(inputType); |
| } |
| |
| for (auto &resultType : funcType.getResults()) { |
| // We pass refs as iree_vm_ref_t* regardless of whether it is an in or out |
| // parameter |
| Type type = typeConverter.convertTypeAsPointer(resultType); |
| |
| inputTypes.push_back(type); |
| outputTypes.push_back(type); |
| } |
| |
| auto newFuncType = mlir::FunctionType::get( |
| ctx, {inputTypes}, {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto newFuncOp = mlir::emitc::FuncOp::create(builder, loc, name, newFuncType); |
| newFuncOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| newFuncOp.setPrivate(); |
| |
| // This call shold be equivalent to rewriter.inlineRegionBefore() |
| newFuncOp.getFunctionBody().getBlocks().splice( |
| newFuncOp.end(), funcOp.getFunctionBody().getBlocks()); |
| |
| Block &entryBlock = newFuncOp.getBlocks().front(); |
| |
| if (!entryBlock.hasNoPredecessors()) { |
| return funcOp.emitError() |
| << "branches to the entry block are not supported for now."; |
| } |
| |
| entryBlock.insertArgument(static_cast<unsigned>(0), stackType, loc); |
| entryBlock.insertArgument(static_cast<unsigned>(1), moduleType, loc); |
| entryBlock.insertArgument(static_cast<unsigned>(2), moduleStateType, loc); |
| |
| SmallVector<Location> locs(outputTypes.size(), loc); |
| entryBlock.addArguments(outputTypes, locs); |
| |
| typeConverter.analysis.move(newFuncOp, funcOp); |
| |
| auto &funcAnalysis = typeConverter.analysis.lookupFunction(newFuncOp); |
| |
| // Add variable ops for local refs |
| const int numRefArgs = funcAnalysis.getNumRefArguments(); |
| const int numLocalRefs = funcAnalysis.getNumLocalRefs(); |
| |
| builder.setInsertionPointToStart(&entryBlock); |
| |
| for (int i = 0; i < numLocalRefs; i++) { |
| auto [ref, refPtr] = emitc_builders::allocZeroInitializedVar( |
| builder, loc, emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); |
| |
| // Cache local refs so that we can release them before a return operation. |
| // Here we rely on the fact that the register allocation maps arguments in |
| // the first slots. |
| funcAnalysis.cacheLocalRef(i + numRefArgs, refPtr); |
| } |
| |
| if (failed( |
| funcOp.replaceAllSymbolUses(builder.getStringAttr(name), moduleOp))) |
| return funcOp.emitError() << "unable to update symbol name in module"; |
| |
| return success(); |
| } |
| |
| std::optional<std::string> buildFunctionName(IREE::VM::ModuleOp &moduleOp, |
| IREE::VM::ImportOp &importOp) { |
| auto callingConvention = makeImportCallingConventionString(importOp); |
| if (!callingConvention.has_value()) { |
| return std::nullopt; |
| } |
| return moduleOp.getName().str() + "_call_" + callingConvention.value() + |
| "_import_shim"; |
| } |
| |
| std::optional<std::string> |
| buildVariadicFunctionName(IREE::VM::ModuleOp &moduleOp, |
| IREE::VM::ImportOp &importOp, |
| DenseIntElementsAttr segmentSizes) { |
| auto callingConvention = makeImportCallingConventionString(importOp); |
| if (!callingConvention.has_value()) { |
| return std::nullopt; |
| } |
| std::string result(moduleOp.getName()); |
| result += "_call_"; |
| result += callingConvention.value(); |
| for (int i = 0; i < importOp.getNumArguments(); i++) { |
| if (importOp.isFuncArgumentVariadic(i)) { |
| APInt size = *(segmentSizes.begin() + i); |
| result += "_"; |
| result += std::to_string(size.getSExtValue()); |
| } |
| } |
| result += "_import_shim"; |
| return result; |
| } |
| |
| std::optional<Value> |
| createVmTypeDefPtr(ConversionPatternRewriter &rewriter, Location loc, |
| const IREE::VM::ModuleAnalysis &moduleAnalysis, |
| IREE::VM::ModuleOp moduleOp, BlockArgument moduleArg, |
| Type elementType) { |
| auto ctx = rewriter.getContext(); |
| |
| // Map from type to enum values of type iree_vm_value_type_t and |
| // iree_vm_ref_type_t |
| mlir::DenseMap<Type, std::pair<std::string, std::string>> valueTypeMap = { |
| {IntegerType::get(ctx, 8), |
| {"IREE_VM_VALUE_TYPE_I8", "IREE_VM_REF_TYPE_NULL"}}, |
| {IntegerType::get(ctx, 16), |
| {"IREE_VM_VALUE_TYPE_I16", "IREE_VM_REF_TYPE_NULL"}}, |
| {IntegerType::get(ctx, 32), |
| {"IREE_VM_VALUE_TYPE_I32", "IREE_VM_REF_TYPE_NULL"}}, |
| {IntegerType::get(ctx, 64), |
| {"IREE_VM_VALUE_TYPE_I64", "IREE_VM_REF_TYPE_NULL"}}, |
| {Float32Type::get(ctx), |
| {"IREE_VM_VALUE_TYPE_F32", "IREE_VM_REF_TYPE_NULL"}}, |
| {Float64Type::get(ctx), |
| {"IREE_VM_VALUE_TYPE_F64", "IREE_VM_REF_TYPE_NULL"}}, |
| }; |
| Value elementTypeValue; |
| |
| auto ptr = valueTypeMap.find((elementType)); |
| if (ptr != valueTypeMap.end()) { |
| elementTypeValue = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), |
| /*callee=*/"iree_vm_make_value_type_def", |
| /*operands=*/ArrayRef<Value>{}, |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {emitc::OpaqueAttr::get(ctx, ptr->second.first)})) |
| .getResult(0); |
| } else if (auto elemRefType = dyn_cast<IREE::VM::RefType>(elementType)) { |
| Type objType = elemRefType.getObjectType(); |
| |
| Type typeRefType = emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"); |
| Type typeRefArrayType = emitc::PointerType::get(typeRefType); |
| std::optional<size_t> typeIndex = moduleAnalysis.lookupType(objType); |
| if (!typeIndex.has_value()) { |
| moduleOp.emitError("type index lookup failed"); |
| return std::nullopt; |
| } |
| |
| TypedValue<emitc::LValueType> moduleArgLValue = |
| emitc_builders::asLValue(rewriter, loc, moduleArg); |
| |
| TypedValue<emitc::PointerType> typeArray = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| rewriter, loc, typeRefArrayType, "types", moduleArgLValue)); |
| Value refType = emitc_builders::arrayElement(rewriter, loc, |
| typeIndex.value(), typeArray); |
| |
| elementTypeValue = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), |
| /*callee=*/"iree_vm_make_ref_type_def", |
| /*operands=*/ArrayRef<Value>{refType}) |
| .getResult(0); |
| } else if (isa<IREE::VM::OpaqueType>(elementType)) { |
| elementTypeValue = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), |
| /*callee=*/"iree_vm_make_undefined_type_def", |
| /*operands=*/ArrayRef<Value>{}) |
| .getResult(0); |
| } |
| |
| return elementTypeValue; |
| } |
| |
| /// Move multiple refs from one set of variables to another set. As these two |
| /// sets may alias we move the source variables into temporaries first. |
| /// The generated code works as follows: |
| /// `isMove` == true: |
| /// move(src_i, tmp_i); for all i |
| /// move(tmp_i, dest_i); for all i |
| /// `isMove` == false: |
| /// retain(src_i, tmp_i); for all i |
| /// assign(tmp_i, dest_i); for all i |
| LogicalResult retainOrMoveRefs(OpBuilder &builder, Location location, |
| IRMapping mapping, bool isMove) { |
| auto ctx = builder.getContext(); |
| |
| IRMapping tmpMapping; |
| for (auto &[srcRef, destRef] : mapping.getValueMap()) { |
| assert(srcRef.getType() == emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, "iree_vm_ref_t"))); |
| |
| auto [tmpRef, tmpPtr] = emitc_builders::allocZeroInitializedVar( |
| builder, location, emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); |
| |
| StringRef callee = isMove ? "iree_vm_ref_move" : "iree_vm_ref_retain"; |
| emitc::CallOpaqueOp::create(builder, |
| /*location=*/location, |
| /*type=*/TypeRange{}, |
| /*callee=*/callee, |
| /*operands=*/ArrayRef<Value>{srcRef, tmpPtr}); |
| |
| tmpMapping.map(srcRef, tmpPtr); |
| } |
| |
| for (const auto &[srcRef, destRef] : mapping.getValueMap()) { |
| Value tmpRef = tmpMapping.lookup(srcRef); |
| |
| StringRef callee = isMove ? "iree_vm_ref_move" : "iree_vm_ref_assign"; |
| |
| emitc::CallOpaqueOp::create(builder, |
| /*location=*/location, |
| /*type=*/TypeRange{}, |
| /*callee=*/callee, |
| /*operands=*/ArrayRef<Value>{tmpRef, destRef}); |
| } |
| |
| return success(); |
| } |
| |
| /// Releases refs which are local to the function as well as ref arguments. |
| void releaseRefs(OpBuilder &builder, Location location, |
| mlir::emitc::FuncOp funcOp, |
| IREE::VM::ModuleAnalysis &moduleAnalysis) { |
| auto ctx = builder.getContext(); |
| |
| auto &funcAnalysis = moduleAnalysis.lookupFunction(funcOp); |
| |
| if (funcAnalysis.hasLocalRefs()) { |
| |
| for (auto &[key, localRef] : funcAnalysis.localRefs()) { |
| emitc_builders::ireeVmRefRelease(builder, location, localRef); |
| } |
| } |
| |
| // We only release the original arguments not the results which were appended |
| // as further operands. |
| size_t refArgumentsReleased = 0; |
| for (auto arg : funcOp.getArguments()) { |
| if (arg.getType() == |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_ref_t"))) { |
| if (funcAnalysis.getNumRefArguments() <= refArgumentsReleased++) { |
| break; |
| } |
| emitc_builders::ireeVmRefRelease(builder, location, arg); |
| } |
| } |
| } |
| |
| /// Generate an emitc.call_opaque op with one result and split the current block |
| /// into a continuation and failure block based on the truthiness of the result |
| /// value, i.e. a truthy value branches to the continuation block when |
| /// `negateCondition` is false. |
| emitc::CallOpaqueOp failableCall( |
| OpBuilder &builder, Location location, Type type, StringRef callee, |
| ArrayAttr args, ArrayRef<Value> operands, |
| const std::function<void(emitc::CallOpaqueOp &)> &failureBlockBuilder, |
| bool negateCondition = false) { |
| auto callOp = emitc::CallOpaqueOp::create(builder, |
| /*location=*/location, |
| /*type=*/type, |
| /*callee=*/callee, |
| /*operands=*/operands, |
| /*args=*/args); |
| |
| Type boolType = builder.getIntegerType(1); |
| |
| auto conditionI1 = emitc::CastOp::create(builder, |
| /*location=*/location, |
| /*type=*/boolType, |
| /*operand=*/callOp.getResult(0)); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| failureBlockBuilder(callOp); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| mlir::cf::CondBranchOp::create( |
| builder, location, conditionI1.getResult(), |
| negateCondition ? failureBlock : continuationBlock, |
| negateCondition ? continuationBlock : failureBlock); |
| |
| builder.setInsertionPointToStart(continuationBlock); |
| |
| return callOp; |
| } |
| |
| emitc::CallOpaqueOp returnIfError(OpBuilder &builder, Location location, |
| StringRef callee, ArrayAttr args, |
| ArrayRef<Value> operands, |
| IREE::VM::ModuleAnalysis &moduleAnalysis) { |
| auto blockBuilder = [&builder, &location, |
| &moduleAnalysis](emitc::CallOpaqueOp &callOp) { |
| Block *block = builder.getBlock(); |
| mlir::emitc::FuncOp funcOp = |
| cast<mlir::emitc::FuncOp>(block->getParentOp()); |
| |
| releaseRefs(builder, location, funcOp, moduleAnalysis); |
| |
| mlir::emitc::ReturnOp::create(builder, location, callOp.getResult(0)); |
| }; |
| |
| auto ctx = builder.getContext(); |
| Type type = emitc::OpaqueType::get(ctx, "iree_status_t"); |
| return failableCall(builder, location, type, callee, args, operands, |
| blockBuilder, /*negateCondition=*/true); |
| } |
| |
| emitc::CallOpaqueOp |
| failContainerNull(OpBuilder &builder, Location location, Type type, |
| StringRef callee, ArrayAttr args, ArrayRef<Value> operands, |
| IREE::VM::ModuleAnalysis &moduleAnalysis) { |
| auto blockBuilder = [&builder, &location, |
| &moduleAnalysis](emitc::CallOpaqueOp &callOp) { |
| auto ctx = builder.getContext(); |
| |
| Block *block = builder.getBlock(); |
| mlir::emitc::FuncOp funcOp = |
| cast<mlir::emitc::FuncOp>(block->getParentOp()); |
| |
| releaseRefs(builder, location, funcOp, moduleAnalysis); |
| |
| auto statusOp = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/location, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/"iree_make_status", |
| /*operands=*/ArrayRef<Value>{}, |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get( |
| ctx, "IREE_STATUS_INVALID_ARGUMENT")})); |
| |
| mlir::emitc::ReturnOp::create(builder, location, statusOp.getResult(0)); |
| }; |
| |
| return failableCall(builder, location, type, callee, args, operands, |
| blockBuilder); |
| } |
| |
| /// Generate a emitc.call op with one result and split the current block into a |
| /// continuation and failure block based on the truthiness of the result |
| /// value, i.e. a truthy value branches to the continuation block when |
| /// `negateCondition` is false. |
| mlir::emitc::CallOp failableCall( |
| OpBuilder &builder, Location location, mlir::emitc::FuncOp &callee, |
| ArrayRef<Value> operands, |
| const std::function<void(mlir::emitc::CallOp &)> &failureBlockBuilder, |
| bool negateCondition = false) { |
| auto callOp = mlir::emitc::CallOp::create(builder, |
| /*location=*/location, |
| /*callee=*/callee, |
| /*operands=*/operands); |
| |
| Type boolType = builder.getIntegerType(1); |
| |
| auto conditionI1 = emitc::CastOp::create(builder, |
| /*location=*/location, |
| /*type=*/boolType, |
| /*operand=*/callOp.getResult(0)); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| failureBlockBuilder(callOp); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| mlir::cf::CondBranchOp::create( |
| builder, location, conditionI1.getResult(), |
| negateCondition ? failureBlock : continuationBlock, |
| negateCondition ? continuationBlock : failureBlock); |
| |
| builder.setInsertionPointToStart(continuationBlock); |
| |
| return callOp; |
| } |
| |
| mlir::emitc::CallOp returnIfError(OpBuilder &builder, Location location, |
| mlir::emitc::FuncOp &callee, |
| ArrayRef<Value> operands, |
| IREE::VM::ModuleAnalysis &moduleAnalysis) { |
| auto blockBuilder = [&builder, &location, |
| &moduleAnalysis](mlir::emitc::CallOp &callOp) { |
| Block *block = builder.getBlock(); |
| mlir::emitc::FuncOp funcOp = |
| cast<mlir::emitc::FuncOp>(block->getParentOp()); |
| |
| releaseRefs(builder, location, funcOp, moduleAnalysis); |
| |
| mlir::emitc::ReturnOp::create(builder, location, callOp.getResult(0)); |
| }; |
| |
| return failableCall(builder, location, callee, operands, blockBuilder, |
| /*negateCondition=*/true); |
| } |
| |
| LogicalResult createAPIFunctions(IREE::VM::ModuleOp moduleOp, |
| IREE::VM::ModuleAnalysis &moduleAnalysis) { |
| auto ctx = moduleOp.getContext(); |
| auto loc = moduleOp.getLoc(); |
| |
| OpBuilder builder(moduleOp); |
| builder.setInsertionPoint(moduleOp.getBlock().getTerminator()); |
| |
| std::string moduleName{moduleOp.getName()}; |
| |
| // void destroy(void*) |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| const int moduleArgIndex = 0; |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, {emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void"))}, |
| {}); |
| |
| auto funcOp = mlir::emitc::FuncOp::create( |
| builder, loc, moduleName + "_destroy", funcType); |
| funcOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| funcOp.setPrivate(); |
| |
| moduleAnalysis.addDummy(funcOp, /*emitAtEnd=*/false); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| const BlockArgument moduleArg = funcOp.getArgument(moduleArgIndex); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleTypeName = std::string("struct ") + moduleName + "_t"; |
| |
| auto castedModuleOp = emitc::CastOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, moduleTypeName)), |
| /*operand=*/moduleArg); |
| |
| auto castedModuleOpLValue = |
| emitc_builders::asLValue(builder, loc, castedModuleOp.getResult()); |
| |
| auto allocatorOp = emitc_builders::structPtrMember( |
| builder, loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*memberName=*/"allocator", |
| /*operand=*/castedModuleOpLValue); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_allocator_free", |
| /*operands=*/ |
| ArrayRef<Value>{allocatorOp, castedModuleOp.getResult()}); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, nullptr); |
| } |
| |
| // iree_status_t alloc_state(void*, iree_allocator_t, |
| // iree_vm_module_state_t**) |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| const int allocatorArgIndex = 1; |
| const int moduleStateArgIndex = 2; |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| {emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")), |
| emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| emitc::PointerType::get(emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t")))}, |
| {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto funcOp = mlir::emitc::FuncOp::create( |
| builder, loc, moduleName + "_alloc_state", funcType); |
| funcOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| funcOp.setPrivate(); |
| |
| moduleAnalysis.addDummy(funcOp, /*emitAtEnd=*/false); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| const BlockArgument allocatorArg = funcOp.getArgument(allocatorArgIndex); |
| const BlockArgument moduleStateArg = |
| funcOp.getArgument(moduleStateArgIndex); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleStateTypeName = |
| std::string("struct ") + moduleName + "_state_t"; |
| |
| auto state = emitc_builders::allocateVariable( |
| builder, loc, |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, moduleStateTypeName)), |
| {"NULL"}); |
| |
| Value stateSize = emitc_builders::sizeOf( |
| builder, loc, emitc::OpaqueAttr::get(ctx, moduleStateTypeName)); |
| |
| Value statePtr = emitc_builders::addressOf(builder, loc, state); |
| |
| auto voidPtr = |
| emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "void"))), |
| /*operand=*/statePtr); |
| |
| returnIfError(builder, loc, "iree_allocator_malloc", {}, |
| {allocatorArg, stateSize, voidPtr.getResult()}, |
| moduleAnalysis); |
| |
| auto stateRValue = emitc_builders::asRValue(builder, loc, state); |
| emitc_builders::memset(builder, loc, stateRValue, 0, stateSize); |
| |
| emitc_builders::structPtrMemberAssign(builder, loc, |
| /*memberName=*/"allocator", |
| /*operand=*/state, |
| /*value=*/allocatorArg); |
| |
| // Initialize buffers |
| for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) { |
| auto ordinal = rodataOp.getOrdinal()->getZExtValue(); |
| |
| std::string bufferName = moduleName + "_" + rodataOp.getName().str(); |
| |
| Type type = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "const uint8_t")); |
| |
| auto rodataPointer = |
| emitc_builders::allocateVariable(builder, loc, type, {bufferName}); |
| auto rodataPointerRValue = |
| emitc_builders::asRValue(builder, loc, rodataPointer); |
| |
| auto bufferVoid = emitc::CastOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")), |
| /*operand=*/rodataPointerRValue); |
| |
| Value bufferSize = emitc_builders::sizeOf( |
| builder, loc, emitc::OpaqueAttr::get(ctx, bufferName)); |
| |
| auto byteSpan = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_byte_span_t"), |
| /*callee=*/"iree_make_byte_span", |
| /*operands=*/ArrayRef<Value>{bufferVoid.getResult(), bufferSize}); |
| |
| auto allocator = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*callee=*/"iree_allocator_null", |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto buffers = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| builder, loc, /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_buffer_t")), |
| /*memberName=*/"rodata_buffers", /*operand=*/state)); |
| |
| auto buffer = emitc_builders::arrayElementAddress( |
| builder, loc, /*index=*/ordinal, /*operand=*/buffers); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_vm_buffer_initialize", |
| /*operands=*/ |
| ArrayRef<Value>{byteSpan.getResult(0), allocator.getResult(0), |
| buffer}, |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get( |
| ctx, "IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE"), |
| builder.getIndexAttr(0), builder.getIndexAttr(1), |
| builder.getIndexAttr(2)})); |
| } |
| |
| auto baseStateOp = |
| emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, "iree_vm_module_state_t")), |
| /*operand=*/stateRValue); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"EMITC_DEREF_ASSIGN_VALUE", |
| /*operands=*/ArrayRef<Value>{moduleStateArg, baseStateOp.getResult()}); |
| |
| auto status = emitc_builders::ireeOkStatus(builder, loc); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, status); |
| } |
| |
| // void free_state(void*, iree_vm_module_state_t*) |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| const int moduleStateArgIndex = 1; |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| {emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")), |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t"))}, |
| {}); |
| |
| auto funcOp = mlir::emitc::FuncOp::create( |
| builder, loc, moduleName + "_free_state", funcType); |
| funcOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| funcOp.setPrivate(); |
| |
| moduleAnalysis.addDummy(funcOp, /*emitAtEnd=*/false); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| const BlockArgument moduleStateArg = |
| funcOp.getArgument(moduleStateArgIndex); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleStateTypeName = |
| std::string("struct ") + moduleName + "_state_t"; |
| |
| auto stateOp = |
| emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, moduleStateTypeName)), |
| /*operand=*/moduleStateArg); |
| |
| auto stateOpLValue = |
| emitc_builders::asLValue(builder, loc, stateOp.getResult()); |
| |
| // Release refs from state struct. |
| auto ordinalCounts = moduleOp.getOrdinalCountsAttr(); |
| if (!ordinalCounts) { |
| return moduleOp.emitError() |
| << "ordinal_counts attribute not found. The OrdinalAllocationPass " |
| "must be run before."; |
| } |
| const int numGlobalRefs = ordinalCounts.getGlobalRefs(); |
| |
| if (numGlobalRefs > 0) { |
| auto refs = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| builder, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t")), |
| /*memberName=*/"refs", /*operand=*/stateOpLValue)); |
| |
| for (int i = 0; i < numGlobalRefs; i++) { |
| auto refPtrOp = emitc_builders::arrayElementAddress( |
| builder, loc, /*index=*/i, /*operand=*/refs); |
| |
| emitc_builders::ireeVmRefRelease(builder, loc, refPtrOp); |
| } |
| } |
| |
| auto allocatorOp = emitc_builders::structPtrMember( |
| builder, loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*memberName=*/"allocator", |
| /*operand=*/stateOpLValue); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_allocator_free", |
| /*operands=*/ArrayRef<Value>{allocatorOp, stateOp.getResult()}); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, nullptr); |
| } |
| |
| // iree_status_t fork_state( |
| // void* self, |
| // iree_vm_module_state_t* parent_state, |
| // iree_allocator_t allocator, |
| // iree_vm_module_state_t** out_child_state |
| // ) |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| // const int parentStateArgIndex = 1; |
| // const int allocatorArgIndex = 2; |
| // const int childStateArgIndex = 3; |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| {emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")), |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t")), |
| emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| emitc::PointerType::get(emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t")))}, |
| {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto funcOp = mlir::emitc::FuncOp::create( |
| builder, loc, moduleName + "_fork_state", funcType); |
| funcOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| funcOp.setPrivate(); |
| |
| moduleAnalysis.addDummy(funcOp, /*emitAtEnd=*/false); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| // TODO: someone will need to do what the bytecode module does in order to |
| // support forking. For now we don't support forking emitc contexts. |
| auto statusOp = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/"iree_make_status", |
| /*operands=*/ArrayRef<Value>{}, |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, {emitc::OpaqueAttr::get(ctx, "IREE_STATUS_UNIMPLEMENTED")})); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, statusOp.getResult(0)); |
| } |
| |
| // iree_status_t resolve_import( |
| // void*, |
| // iree_vm_module_state_t*, |
| // iree_host_size_t, |
| // const iree_vm_function_t*, |
| // const iree_vm_function_signature_t* |
| // ) |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| const int moduleStateArgIndex = 1; |
| const int ordinalArgIndex = 2; |
| const int functionArgIndex = 3; |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| { |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")), |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t")), |
| emitc::OpaqueType::get(ctx, "iree_host_size_t"), |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "const iree_vm_function_t")), |
| emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, "const iree_vm_function_signature_t")), |
| }, |
| {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto funcOp = mlir::emitc::FuncOp::create( |
| builder, loc, moduleName + "_resolve_import", funcType); |
| funcOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| funcOp.setPrivate(); |
| |
| moduleAnalysis.addDummy(funcOp, /*emitAtEnd=*/false); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| const BlockArgument moduleStateArg = |
| funcOp.getArgument(moduleStateArgIndex); |
| const BlockArgument ordinalArg = funcOp.getArgument(ordinalArgIndex); |
| const BlockArgument functionArg = funcOp.getArgument(functionArgIndex); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleStateTypeName = |
| std::string("struct ") + moduleName + "_state_t"; |
| |
| auto stateOp = |
| emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, moduleStateTypeName)), |
| /*operand=*/moduleStateArg); |
| |
| auto stateOpLValue = |
| emitc_builders::asLValue(builder, loc, stateOp.getResult()); |
| |
| auto imports = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| builder, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_function_t")), |
| /*memberName=*/"imports", |
| /*operand=*/stateOpLValue)); |
| |
| auto import = emitc_builders::arrayElementAddress( |
| builder, loc, /*index=*/ordinalArg, /*operand=*/imports); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"EMITC_DEREF_ASSIGN_PTR", |
| /*operands=*/ArrayRef<Value>{import, functionArg}); |
| |
| auto status = emitc_builders::ireeOkStatus(builder, loc); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, status); |
| } |
| |
| // iree_status_t create( |
| // iree_vm_instance_t*, |
| // iree_allocator_t, |
| // iree_vm_module_t** |
| // ); |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| const int instanceArgIndex = 0; |
| const int allocatorArgIndex = 1; |
| const int moduleArgIndex = 2; |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| { |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_instance_t")), |
| emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| emitc::PointerType::get(emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_t"))), |
| }, |
| { |
| emitc::OpaqueType::get(ctx, "iree_status_t"), |
| }); |
| |
| auto funcOp = mlir::emitc::FuncOp::create(builder, loc, |
| moduleName + "_create", funcType); |
| funcOp.setPublic(); |
| |
| // This function needs an iree_vm_native_module_descriptor_t that is emitted |
| // by the CModuleTarget at the moment. So we add a marker to this function |
| // and delay the printing of it. |
| moduleAnalysis.addDummy(funcOp, /*emitAtEnd=*/true); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| const BlockArgument instanceArg = funcOp.getArgument(instanceArgIndex); |
| const BlockArgument allocatorArg = funcOp.getArgument(allocatorArgIndex); |
| const BlockArgument moduleArg = funcOp.getArgument(moduleArgIndex); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleTypeName = std::string("struct ") + moduleName + "_t"; |
| |
| auto module = emitc_builders::allocateVariable( |
| builder, loc, |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, moduleTypeName)), |
| {"NULL"}); |
| |
| Value moduleSize = emitc_builders::sizeOf( |
| builder, loc, emitc::OpaqueAttr::get(ctx, moduleTypeName)); |
| |
| Value modulePtr = emitc_builders::addressOf(builder, loc, module); |
| |
| auto voidPtr = |
| emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "void"))), |
| /*operand=*/modulePtr); |
| |
| returnIfError(builder, loc, "iree_allocator_malloc", {}, |
| {allocatorArg, moduleSize, voidPtr.getResult()}, |
| moduleAnalysis); |
| |
| auto moduleRValue = emitc_builders::asRValue(builder, loc, module); |
| emitc_builders::memset(builder, loc, moduleRValue, 0, moduleSize); |
| |
| emitc_builders::structPtrMemberAssign(builder, loc, |
| /*memberName=*/"allocator", |
| /*operand=*/module, |
| /*value=*/allocatorArg); |
| |
| auto &typeTable = moduleAnalysis.typeTable; |
| if (!typeTable.empty()) { |
| Type typeRefType = emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"); |
| Type typeRefArrayType = emitc::PointerType::get(typeRefType); |
| auto moduleTypes = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| builder, loc, typeRefArrayType, "types", module)); |
| |
| std::string listType = "!vm.list"; |
| for (auto [index, typeDef] : llvm::enumerate(typeTable)) { |
| std::string typeName = typeDef.full_name; |
| std::string listPrefix = typeName.substr(0, listType.size()); |
| if (listType == listPrefix) { |
| typeName = listPrefix; |
| } |
| |
| // Remove leading '!' and wrap in quotes |
| if (typeName[0] == '!') { |
| typeName = typeName.substr(1); |
| } |
| |
| Value stringView = |
| emitc_builders::ireeMakeCstringView(builder, loc, typeName); |
| Value refType = emitc_builders::ireeVmInstanceLookupType( |
| builder, loc, instanceArg, stringView); |
| emitc_builders::arrayElementAssign(builder, loc, moduleTypes, index, |
| refType); |
| } |
| } |
| |
| auto vmModule = emitc_builders::allocateVariable( |
| builder, loc, emitc::OpaqueType::get(ctx, "iree_vm_module_t")); |
| |
| Value vmModulePtr = emitc_builders::addressOf(builder, loc, vmModule); |
| |
| auto vmInitializeStatus = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/"iree_vm_module_initialize", |
| /*operands=*/ |
| ArrayRef<Value>{vmModulePtr, |
| emitc_builders::asRValue(builder, loc, module)}); |
| |
| Type boolType = builder.getIntegerType(1); |
| |
| auto vmInitializeIsOk = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*callee=*/"iree_status_is_ok", |
| /*operands=*/ArrayRef<Value>{vmInitializeStatus.getResult(0)}); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_allocator_free", |
| /*operands=*/ |
| ArrayRef<Value>{allocatorArg, |
| emitc_builders::asRValue(builder, loc, module)}); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, |
| vmInitializeStatus.getResult(0)); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| |
| mlir::cf::CondBranchOp::create(builder, loc, vmInitializeIsOk.getResult(0), |
| continuationBlock, failureBlock); |
| |
| builder.setInsertionPointToStart(continuationBlock); |
| |
| // Set function pointers |
| for (std::string funcName : {"destroy", "alloc_state", "free_state", |
| "fork_state", "resolve_import"}) { |
| // The type doesn't matter, the result gets inlined into it's uses anyway. |
| Type type = emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")); |
| |
| Value funcPtr = emitc::LiteralOp::create(builder, loc, type, |
| moduleName + "_" + funcName); |
| emitc_builders::structMemberAssign(builder, loc, |
| /*memberName=*/funcName, |
| /*operand=*/vmModule, |
| /*value=*/funcPtr); |
| } |
| |
| std::string descriptorPtr = "&" + moduleName + "_descriptor_"; |
| |
| auto status = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/"iree_vm_native_module_create", |
| /*operands=*/ |
| ArrayRef<Value>{vmModulePtr, instanceArg, allocatorArg, moduleArg}, |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, descriptorPtr), |
| builder.getIndexAttr(1), builder.getIndexAttr(2), |
| builder.getIndexAttr(3)})); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, status.getResult(0)); |
| } |
| |
| return success(); |
| } |
| |
| /// Generate boilerplate code like includes for the IREE C API, include guards, |
| /// structures to hold the module state, functions and global variables to |
| /// create a module instance etc. |
| LogicalResult |
| createModuleStructure(IREE::VM::ModuleOp moduleOp, |
| IREE::VM::EmitCTypeConverter &typeConverter) { |
| if (failed(createAPIFunctions(moduleOp, typeConverter.analysis))) { |
| return failure(); |
| } |
| |
| auto loc = moduleOp.getLoc(); |
| |
| OpBuilder builder(moduleOp); |
| |
| SmallVector<Operation *> opsToRemove; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| builder.setInsertionPointToStart(&moduleOp.getBlock()); |
| |
| std::string includeGuard = moduleOp.getName().upper() + "_H_"; |
| |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::IFNDEF, |
| includeGuard); |
| |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::DEFINE, |
| includeGuard); |
| |
| emitc::IncludeOp::create(builder, loc, "iree/vm/api.h"); |
| |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::IFDEF, |
| "__cplusplus"); |
| emitc::VerbatimOp::create(builder, loc, "extern \"C\" {"); |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::ENDIF, |
| "// __cplusplus"); |
| |
| // Emit declarations for public functions. |
| for (auto funcOp : moduleOp.getOps<mlir::emitc::FuncOp>()) { |
| if (funcOp.isPublic()) { |
| emitc::DeclareFuncOp::create(builder, loc, funcOp.getName()); |
| } |
| } |
| |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::IFDEF, |
| "__cplusplus"); |
| emitc::VerbatimOp::create(builder, loc, "} // extern \"C\""); |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::ENDIF, |
| "// __cplusplus"); |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::ENDIF, |
| std::string("// ") + includeGuard); |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::IF, |
| "defined(EMITC_IMPLEMENTATION)"); |
| emitc::IncludeOp::create(builder, loc, "iree/vm/ops.h"); |
| emitc::IncludeOp::create(builder, loc, "iree/vm/ops_emitc.h"); |
| emitc::IncludeOp::create(builder, loc, "iree/vm/shims_emitc.h"); |
| |
| // Rodata ops. |
| for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) { |
| auto value = |
| dyn_cast<IREE::Util::SerializableAttrInterface>(rodataOp.getValue()); |
| assert(value && "expected a serializable rodata value"); |
| SmallVector<char> byteBuffer; |
| if (failed(value.serializeToVector( |
| rodataOp.getLoc(), llvm::endianness::little, byteBuffer))) { |
| return rodataOp.emitError() << "error during serialization"; |
| } |
| |
| constexpr size_t kDefaultRodataAlignment = 16; |
| size_t alignment = |
| rodataOp.getAlignment() |
| ? static_cast<size_t>(rodataOp.getAlignment().value()) |
| : 0; |
| if (alignment == 0) |
| alignment = kDefaultRodataAlignment; |
| |
| std::string bufferName = |
| moduleOp.getName().str() + "_" + rodataOp.getName().str(); |
| |
| std::string stmt = "iree_alignas(" + std::to_string(alignment) + |
| ") static const uint8_t " + bufferName + "[] = {"; |
| size_t index = 0; |
| for (char value : byteBuffer) { |
| if (index++ > 0) |
| stmt += ", "; |
| stmt += std::to_string( |
| static_cast<unsigned int>(static_cast<unsigned char>(value))); |
| } |
| stmt += "};"; |
| emitc::VerbatimOp::create(builder, loc, stmt); |
| opsToRemove.push_back(rodataOp.getOperation()); |
| } |
| |
| // structs |
| // Returns |count| or 1 if |count| == 0. |
| // Some compilers (MSVC) don't support zero-length struct fields on the |
| // interior of structs (just VLA at the tail). |
| auto countOrEmpty = [](uint32_t count) { return count ? count : 1; }; |
| |
| const int64_t numTypes = typeConverter.analysis.typeTable.size(); |
| |
| std::string moduleStructName = moduleOp.getName().str() + "_t"; |
| SmallVector<emitc_builders::StructField> moduleStructFields{ |
| {"iree_allocator_t", "allocator"}, |
| {"iree_vm_ref_type_t", "types", countOrEmpty(numTypes)}}; |
| |
| emitc_builders::structDefinition(builder, loc, moduleStructName, |
| moduleStructFields); |
| |
| auto ordinalCounts = moduleOp.getOrdinalCountsAttr(); |
| |
| std::string moduleStructStateName = moduleOp.getName().str() + "_state_t"; |
| SmallVector<emitc_builders::StructField> moduleStructStateFields{ |
| {"iree_allocator_t", "allocator"}, |
| {"uint8_t", "rwdata", countOrEmpty(ordinalCounts.getGlobalBytes())}, |
| {"iree_vm_ref_t", "refs", countOrEmpty(ordinalCounts.getGlobalRefs())}, |
| {"iree_vm_buffer_t", "rodata_buffers", |
| countOrEmpty(ordinalCounts.getRodatas())}, |
| {"iree_vm_function_t", "imports", |
| countOrEmpty(ordinalCounts.getImportFuncs())}, |
| }; |
| |
| emitc_builders::structDefinition(builder, loc, moduleStructStateName, |
| moduleStructStateFields); |
| |
| // Create a typedef for the begin_call member of the `iree_vm_module_t` |
| // type. See `vm/module.h` for the definition. The EmitC dialect doesn't |
| // handle function pointer types currently. Neither member lookups nor |
| // calls can be fused into one expression at the moment. |
| StringRef beginCallTypedef = |
| "typedef iree_status_t(*begin_call_t)(void*, iree_vm_stack_t*, " |
| "iree_vm_function_call_t);"; |
| emitc::VerbatimOp::create(builder, loc, beginCallTypedef); |
| |
| // Emit declarations for private functions. |
| for (auto funcOp : moduleOp.getOps<mlir::emitc::FuncOp>()) { |
| if (funcOp.isPrivate()) { |
| emitc::DeclareFuncOp::create(builder, loc, funcOp.getName()); |
| } |
| } |
| |
| // TODO(simon-camp): Move these to a structured helper |
| // global descriptors |
| // - define structs for each entity etc. |
| auto printStringView = [](StringRef s) -> std::string { |
| // We can't use iree_make_string_view because function calls are not |
| // allowed for constant expressions in C. |
| // TODO(#7605): Switch to IREE_SVL. We can't use IREE_SVL today because it |
| // uses designated initializers, which cause issues when compiled as C++. |
| return ("{\"" + s + "\", " + std::to_string(s.size()) + "}").str(); |
| }; |
| |
| // dependencies |
| std::string dependenciesName = moduleOp.getName().str() + "_dependencies_"; |
| std::string deps; |
| deps += "static const iree_vm_module_dependency_t " + dependenciesName + |
| "[] = {"; |
| auto dependencies = moduleOp.getDependencies(); |
| if (dependencies.empty()) { |
| // Empty list placeholder. |
| deps += "{{0}},"; |
| } else { |
| for (auto &dependency : dependencies) { |
| deps += "{" + printStringView(dependency.name) + ", " + |
| std::to_string(dependency.minimumVersion) + ", " + |
| (dependency.isOptional |
| ? "IREE_VM_MODULE_DEPENDENCY_FLAG_OPTIONAL" |
| : "IREE_VM_MODULE_DEPENDENCY_FLAG_REQUIRED") + |
| "},"; |
| } |
| } |
| deps += "};"; |
| emitc::VerbatimOp::create(builder, loc, deps); |
| |
| // Imports. |
| SmallVector<IREE::VM::ImportOp> importOps( |
| moduleOp.getOps<IREE::VM::ImportOp>()); |
| std::string importName = moduleOp.getName().str() + "_imports_"; |
| std::string imports; |
| imports += "static const iree_vm_native_import_descriptor_t " + importName + |
| "[] = {"; |
| if (importOps.empty()) { |
| // Empty list placeholder. |
| imports += "{0},"; |
| } else { |
| // Sort import ops by ordinal. |
| llvm::sort(importOps, [](auto &lhs, auto &rhs) { |
| return lhs.getOrdinal()->getZExtValue() < |
| rhs.getOrdinal()->getZExtValue(); |
| }); |
| for (auto importOp : importOps) { |
| imports += |
| std::string("{") + |
| (importOp.getIsOptional() ? "IREE_VM_NATIVE_IMPORT_OPTIONAL" |
| : "IREE_VM_NATIVE_IMPORT_REQUIRED") + |
| ", " + printStringView(importOp.getName()) + "},"; |
| } |
| } |
| imports += "};"; |
| emitc::VerbatimOp::create(builder, loc, imports); |
| |
| for (auto op : moduleOp.getOps<IREE::VM::ImportOp>()) { |
| opsToRemove.push_back(op); |
| } |
| |
| // Exports. |
| SmallVector<emitc::FuncOp> exportedFunctions; |
| for (auto func : moduleOp.getOps<emitc::FuncOp>()) { |
| if (typeConverter.analysis.lookupFunction(func).isExported()) { |
| exportedFunctions.push_back(func); |
| } |
| } |
| auto extractExportName = [&typeConverter](emitc::FuncOp funcOp) { |
| return typeConverter.analysis.lookupFunction(funcOp).getExportName(); |
| }; |
| std::string exportName = moduleOp.getName().str() + "_exports_"; |
| std::string exports; |
| exports += "static const iree_vm_native_export_descriptor_t " + exportName + |
| "[] = {"; |
| if (exportedFunctions.empty()) { |
| // Empty list placeholder. |
| exports += "{{0}},"; |
| } else { |
| // Sort export ops. |
| llvm::sort(exportedFunctions, [&extractExportName](auto &lhs, auto &rhs) { |
| return extractExportName(lhs).compare(extractExportName(rhs)) < 0; |
| }); |
| for (auto funcOp : exportedFunctions) { |
| StringRef exportName = extractExportName(funcOp); |
| StringRef callingConvention = |
| typeConverter.analysis.lookupFunction(funcOp) |
| .getCallingConvention(); |
| |
| // TODO(simon-camp): support function-level reflection attributes |
| exports += "{" + printStringView(exportName) + ", " + |
| printStringView(callingConvention) + ", 0, NULL},"; |
| } |
| } |
| exports += "};"; |
| emitc::VerbatimOp::create(builder, loc, exports); |
| |
| // Functions. |
| std::string functionName = moduleOp.getName().str() + "_funcs_"; |
| std::string functions; |
| functions += |
| "static const iree_vm_native_function_ptr_t " + functionName + "[] = {"; |
| if (exportedFunctions.empty()) { |
| // Empty list placeholder. |
| functions += "{0},"; |
| } else { |
| // We only add exported functions to the table, as calls to internal |
| // functions are directly mapped to C function calls of the generated |
| // implementation. |
| for (auto funcOp : exportedFunctions) { |
| auto funcName = funcOp.getName(); |
| functions += std::string("{") + |
| "(iree_vm_native_function_shim_t)iree_emitc_shim, " + |
| "(iree_vm_native_function_target_t)" + funcName.str() + |
| "},"; |
| } |
| } |
| functions += "};"; |
| emitc::VerbatimOp::create(builder, loc, functions); |
| |
| // Module descriptor. |
| // TODO(simon-camp): support module-level reflection attributes |
| std::string descriptorName = moduleOp.getName().str() + "_descriptor_"; |
| std::string descriptor; |
| descriptor += |
| "static const iree_vm_native_module_descriptor_t " + descriptorName + |
| " = {" |
| // name: |
| + printStringView(moduleOp.getName()) + |
| "," |
| // version: |
| + std::to_string(moduleOp.getVersion().value_or(0u)) + |
| "," |
| // attrs: |
| + "0," + |
| "NULL," |
| // dependencies: |
| + std::to_string(dependencies.size()) + "," + dependenciesName + |
| "," |
| // imports: |
| + std::to_string(importOps.size()) + "," + importName + |
| "," |
| // exports: |
| + std::to_string(exportedFunctions.size()) + "," + exportName + |
| "," |
| // functions: |
| + std::to_string(exportedFunctions.size()) + "," + functionName + "," + |
| "};"; |
| |
| emitc::VerbatimOp::create(builder, loc, descriptor); |
| |
| // Move functions marked as `emitAtEnd` to the end of the module. |
| auto funcs = |
| SmallVector<emitc::FuncOp>(moduleOp.getOps<mlir::emitc::FuncOp>()); |
| for (auto func : funcs) { |
| if (typeConverter.analysis.lookupFunction(func).shouldEmitAtEnd()) { |
| func->moveBefore(moduleOp.getBlock().getTerminator()); |
| } |
| } |
| |
| builder.setInsertionPoint(moduleOp.getBlock().getTerminator()); |
| emitc_builders::preprocessorDirective(builder, loc, emitc_builders::ENDIF, |
| " // EMITC_IMPLEMENTATION"); |
| } |
| |
| for (auto op : opsToRemove) { |
| op->erase(); |
| } |
| |
| // TODO(simon-camp): The Cpp Emitter expects a builtin.module as the |
| // outer container of the supported operations. Instead of nesting a |
| // builtin.module inside the vm.module as in the current implementation, |
| // the conversion pass should be changed to replace the vm.module |
| // with a builtin.module. |
| builder.setInsertionPointToStart(&moduleOp.getBlock()); |
| auto innerModule = mlir::ModuleOp::create(builder, loc); |
| |
| IRRewriter rewriter(moduleOp.getContext()); |
| |
| for (Operation &op : llvm::make_early_inc_range(moduleOp.getBlock())) { |
| if (isa<IREE::VM::ModuleTerminatorOp>(op) || |
| &op == innerModule.getOperation()) { |
| continue; |
| } |
| rewriter.moveOpBefore(&op, innerModule.getBody(), |
| innerModule.getBody()->end()); |
| } |
| |
| return success(); |
| } |
| |
| SmallVector<Attribute> indexSequence(int64_t n, MLIRContext *ctx) { |
| return llvm::map_to_vector(llvm::seq<int64_t>(0, n), |
| [&ctx](int64_t i) -> Attribute { |
| return IntegerAttr::get(IndexType::get(ctx), i); |
| }); |
| } |
| |
| template <typename ResultOpTy> |
| ResultOpTy lookupSymbolRef(Operation *accessOp, StringRef attrName) { |
| FlatSymbolRefAttr globalAttr = |
| accessOp->getAttrOfType<FlatSymbolRefAttr>(attrName); |
| ResultOpTy globalOp = |
| accessOp->getParentOfType<IREE::VM::ModuleOp>().lookupSymbol<ResultOpTy>( |
| globalAttr.getValue()); |
| return globalOp; |
| } |
| |
| template <typename OpTy> |
| class EmitCConversionPattern : public OpConversionPattern<OpTy> { |
| public: |
| using Adaptor = typename OpTy::Adaptor; |
| using OpConversionPattern<OpTy>::OpConversionPattern; |
| |
| EmitCConversionPattern(const TypeConverter &typeConverter, |
| MLIRContext *context, PatternBenefit benefit = 1) |
| : OpConversionPattern<OpTy>(typeConverter, context, benefit) {} |
| |
| protected: |
| IREE::VM::ModuleAnalysis &getModuleAnalysis() const { |
| return this->template getTypeConverter<IREE::VM::EmitCTypeConverter>() |
| ->analysis; |
| } |
| }; |
| |
| // Convert vm operations to emitc opaque_calls. The resultiong opaque_call has |
| // the ops operands as arguments followed by an argument for every attribute. |
| template <typename OpTy> |
| class GenericOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| public: |
| GenericOpConversion(const TypeConverter &typeConverter, MLIRContext *context, |
| StringRef funcName) |
| : EmitCConversionPattern<OpTy>(typeConverter, context), |
| funcName(funcName) {} |
| |
| private: |
| LogicalResult |
| matchAndRewrite(OpTy op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| |
| // Default to an empty args attribute, which results in the operands being |
| // printed as the arguments to the function call. |
| SmallVector<Attribute> args_; |
| |
| // If the operation has attributes, we need to explicitely build the args |
| // attribute of the emitc opaque_call op. This consists of index attributes |
| // for the operands, followed by the source op attributes themselves. |
| if (op->getAttrs().size() > 0) { |
| args_ = indexSequence(adaptor.getOperands().size(), op.getContext()); |
| |
| for (NamedAttribute attr : op->getAttrs()) { |
| args_.push_back(attr.getValue()); |
| } |
| } |
| |
| ArrayAttr args = |
| args_.size() > 0 ? ArrayAttr::get(ctx, args_) : ArrayAttr{}; |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/op, |
| /*type=*/op.getOperation()->getResultTypes(), |
| /*callee=*/funcName, |
| /*operands=*/adaptor.getOperands(), |
| /*args=*/args); |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| template <typename OpTy> |
| class DeleteOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| class FuncOpConversion : public EmitCConversionPattern<mlir::emitc::FuncOp> { |
| using Adaptor = mlir::emitc::FuncOp::Adaptor; |
| using EmitCConversionPattern<mlir::emitc::FuncOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::emitc::FuncOp funcOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Entry block arguments, i.e. function arguments get converted 1:1. |
| // VM::RefType arguments get replaced by iree_vm_ref_t*. |
| { |
| Block &block = funcOp.getBlocks().front(); |
| TypeConverter::SignatureConversion signatureConversion( |
| block.getNumArguments()); |
| |
| for (const auto &[index, arg] : llvm::enumerate(block.getArguments())) { |
| Type convertedType = getTypeConverter()->convertType(arg.getType()); |
| signatureConversion.addInputs(index, convertedType); |
| } |
| |
| rewriter.applySignatureConversion(&block, signatureConversion); |
| |
| rewriter.modifyOpInPlace(funcOp, [&] { |
| funcOp.setType( |
| rewriter.getFunctionType(signatureConversion.getConvertedTypes(), |
| funcOp.getFunctionType().getResults())); |
| }); |
| } |
| |
| // Non-entry block arguments are handled differently between numeric types |
| // and VM::RefType. |
| { |
| for (Block &block : llvm::make_early_inc_range( |
| llvm::drop_begin(funcOp.getBlocks(), 1))) { |
| TypeConverter::SignatureConversion signatureConversion( |
| block.getNumArguments()); |
| |
| for (const auto &[index, arg] : llvm::enumerate(block.getArguments())) { |
| if (isa<IREE::VM::RefType>(arg.getType())) { |
| // VM::RefType arguments are dropped and their uses are replaced. |
| // The replacement values are determined by the register allocation |
| // pass. |
| Value ref = getModuleAnalysis().lookupRef(arg); |
| signatureConversion.remapInput(index, ref); |
| } else { |
| // Numerically typed arguments are kept as block arguments. These |
| // are automatically handled later in the emitter. |
| signatureConversion.addInputs(index, arg.getType()); |
| } |
| } |
| |
| Block *newBlock = |
| rewriter.applySignatureConversion(&block, signatureConversion); |
| |
| // The signatureConversion stores a mapping from the original block |
| // argument index to the replacement value. This information is needed |
| // in the conversion of branch ops to correctly map from branch operands |
| // to the replacement values. |
| getModuleAnalysis().lookupFunction(funcOp).cacheBlockConversion( |
| newBlock, signatureConversion); |
| } |
| } |
| |
| return success(); |
| } |
| }; |
| |
| class ExportOpConversion : public EmitCConversionPattern<IREE::VM::ExportOp> { |
| using Adaptor = IREE::VM::ExportOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::ExportOp>::EmitCConversionPattern; |
| |
| struct GeneratedStruct { |
| std::optional<TypedValue<emitc::LValueType>> value = std::nullopt; |
| std::optional<std::string> name = std::nullopt; |
| SmallVector<Value> callArguments; |
| }; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::ExportOp exportOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = exportOp.getContext(); |
| auto loc = exportOp.getLoc(); |
| |
| mlir::emitc::FuncOp funcOp = lookupSymbolRef<mlir::emitc::FuncOp>( |
| exportOp.getOperation(), "function_ref"); |
| |
| std::string newFuncName = funcOp.getName().str() + "_export_shim"; |
| |
| Type stackType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_stack_t")); |
| Type flagsType = emitc::OpaqueType::get(ctx, "uint32_t"); |
| Type spanType = emitc::OpaqueType::get(ctx, "iree_byte_span_t"); |
| Type moduleType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")); |
| Type moduleStateType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")); |
| |
| SmallVector<Type> inputTypes = { |
| stackType, // SHIM_ARGUMENT_STACK |
| flagsType, // SHIM_ARGUMENT_FLAGS |
| spanType, // SHIM_ARGUMENT_ARGS_STORAGE |
| spanType, // SHIM_ARGUMENT_RETS_STORAGE |
| moduleType, // SHIM_ARGUMENT_MODULE |
| moduleStateType, // SHIM_ARGUMENT_MODULE_STATE |
| }; |
| |
| auto newFuncType = mlir::FunctionType::get( |
| ctx, {inputTypes}, {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto newFuncOp = |
| mlir::emitc::FuncOp::create(rewriter, loc, newFuncName, newFuncType); |
| newFuncOp.setSpecifiersAttr( |
| rewriter.getArrayAttr({rewriter.getStringAttr("static")})); |
| newFuncOp.setPrivate(); |
| |
| getModuleAnalysis().addFromExport(newFuncOp, exportOp); |
| |
| // Populate newly generated function. |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| Block *block = rewriter.createBlock(&newFuncOp.getFunctionBody(), |
| newFuncOp.getFunctionBody().end()); |
| |
| // Insert arguments into block. |
| block->addArgument(stackType, loc); // SHIM_ARGUMENT_STACK |
| block->addArgument(flagsType, loc); // SHIM_ARGUMENT_FLAGS |
| block->addArgument(spanType, loc); // SHIM_ARGUMENT_ARGS_STORAGE |
| block->addArgument(spanType, loc); // SHIM_ARGUMENT_RETS_STORAGE |
| block->addArgument(moduleType, loc); // SHIM_ARGUMENT_MODULE |
| block->addArgument(moduleStateType, loc); // SHIM_ARGUMENT_MODULE_STATE |
| |
| rewriter.setInsertionPointToStart(block); |
| |
| // Create typedefs for argument and result structs. |
| auto typedefs = |
| typedefArgumentAndResultStructs(rewriter, exportOp, newFuncOp); |
| |
| if (failed(typedefs)) { |
| return exportOp.emitError() << "struct typedef failed."; |
| } |
| |
| GeneratedStruct argumentStruct; |
| GeneratedStruct resultStruct; |
| |
| std::tie(argumentStruct, resultStruct) = typedefs.value(); |
| |
| // Cast module and module state structs. |
| auto moduleStructs = |
| castModuleAndStateStructs(rewriter, exportOp, newFuncOp); |
| |
| if (failed(moduleStructs)) { |
| return exportOp.emitError() << "module struct casting failed."; |
| } |
| |
| Value moduleStruct; |
| Value moduleStateStruct; |
| |
| std::tie(moduleStruct, moduleStateStruct) = moduleStructs.value(); |
| |
| // Cast argument and result structs. |
| castArgumentAndResultStructs(rewriter, exportOp, newFuncOp, |
| argumentStruct, resultStruct); |
| |
| // Unpack arguments from struct. |
| auto arguments = unpackArguments(rewriter, exportOp, argumentStruct); |
| |
| if (failed(arguments)) { |
| return exportOp.emitError() << "failed to unpack arguments."; |
| } |
| |
| // Unpack result pointers from struct. |
| auto results = unpackResults(rewriter, exportOp, resultStruct); |
| |
| if (failed(results)) { |
| return exportOp.emitError() << "failed to unpack results."; |
| } |
| |
| // Call internal function and return on error. |
| SmallVector<Value> operands{block->getArgument(SHIM_ARGUMENT_STACK), |
| moduleStruct, moduleStateStruct}; |
| |
| for (auto &argument : argumentStruct.callArguments) { |
| operands.push_back(argument); |
| } |
| for (auto &result : resultStruct.callArguments) { |
| operands.push_back(result); |
| } |
| |
| returnIfError(rewriter, loc, funcOp, operands, getModuleAnalysis()); |
| |
| auto status = emitc_builders::ireeOkStatus(rewriter, loc); |
| |
| mlir::emitc::ReturnOp::create(rewriter, loc, status); |
| } |
| |
| rewriter.eraseOp(exportOp); |
| |
| return success(); |
| } |
| |
| FailureOr<std::pair<Value, Value>> |
| castModuleAndStateStructs(ConversionPatternRewriter &rewriter, |
| IREE::VM::ExportOp &exportOp, |
| mlir::emitc::FuncOp &newFuncOp) const { |
| auto ctx = exportOp.getContext(); |
| auto loc = exportOp.getLoc(); |
| |
| auto module = newFuncOp.getArgument(SHIM_ARGUMENT_MODULE); |
| auto moduleState = newFuncOp.getArgument(SHIM_ARGUMENT_MODULE_STATE); |
| |
| auto moduleOp = |
| newFuncOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| std::string moduleTypeName = |
| std::string("struct ") + moduleOp.getName().str() + "_t"; |
| std::string moduleStateTypeName = |
| std::string("struct ") + moduleOp.getName().str() + "_state_t"; |
| |
| auto moduleCasted = emitc::CastOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, moduleTypeName)), |
| /*operand=*/module); |
| |
| auto moduleStateCasted = |
| emitc::CastOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, moduleStateTypeName)), |
| /*operand=*/moduleState); |
| |
| return {{moduleCasted.getResult(), moduleStateCasted.getResult()}}; |
| } |
| |
| FailureOr<std::pair<GeneratedStruct, GeneratedStruct>> |
| typedefArgumentAndResultStructs(ConversionPatternRewriter &rewriter, |
| IREE::VM::ExportOp &exportOp, |
| mlir::emitc::FuncOp &newFuncOp) const { |
| auto loc = exportOp.getLoc(); |
| |
| mlir::emitc::FuncOp funcOp = lookupSymbolRef<mlir::emitc::FuncOp>( |
| exportOp.getOperation(), "function_ref"); |
| |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| |
| auto generateStructFields = [this](ArrayRef<Type> types, StringRef prefix) |
| -> FailureOr<SmallVector<emitc_builders::StructField>> { |
| SmallVector<emitc_builders::StructField> result; |
| |
| for (auto pair : llvm::enumerate(types)) { |
| emitc::OpaqueType cType = |
| getTypeConverter<IREE::VM::EmitCTypeConverter>() |
| ->convertTypeAsCType(pair.value()); |
| |
| if (!cType) { |
| return failure(); |
| } |
| |
| auto fieldName = prefix.str() + std::to_string(pair.index()); |
| result.push_back({cType.getValue().str(), fieldName}); |
| } |
| |
| return result; |
| }; |
| |
| // TODO(simon-camp): Clean up; We generate calls to a macro that defines |
| // a struct. As we declare all variables at the start of the function, |
| // the macro call cannot be inlined into the function. |
| |
| // To prevent scoping issues we prefix the struct name with module and |
| // function name. |
| auto typedefStruct = [&rewriter, &newFuncOp, |
| &loc](std::string structName, |
| ArrayRef<emitc_builders::StructField> fields) { |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(newFuncOp.getOperation()); |
| |
| emitc_builders::structDefinition(/*builder=*/rewriter, /*location=*/loc, |
| /*structName=*/structName, |
| /*fields=*/fields); |
| }; |
| |
| FunctionType funcType = funcAnalysis.getFunctionType(); |
| |
| GeneratedStruct argumentStruct; |
| GeneratedStruct resultStruct; |
| |
| const bool needArgumentStruct = funcType.getNumInputs() > 0; |
| |
| if (needArgumentStruct) { |
| auto structBody = generateStructFields(funcType.getInputs(), "arg"); |
| if (failed(structBody)) { |
| return exportOp.emitError() |
| << "failed to emit C type for struct definition"; |
| } |
| |
| std::string structName = funcOp.getName().str() + "_args_t"; |
| argumentStruct.name = structName; |
| typedefStruct(structName, structBody.value()); |
| } |
| |
| const bool needResultStruct = funcType.getNumResults() > 0; |
| |
| if (needResultStruct) { |
| auto structBody = generateStructFields(funcType.getResults(), "res"); |
| |
| if (failed(structBody)) { |
| return failure(); |
| } |
| |
| std::string structName = funcOp.getName().str() + "_result_t"; |
| resultStruct.name = structName; |
| typedefStruct(structName, structBody.value()); |
| } |
| |
| return {{argumentStruct, resultStruct}}; |
| } |
| |
| void castArgumentAndResultStructs(ConversionPatternRewriter &rewriter, |
| IREE::VM::ExportOp &exportOp, |
| mlir::emitc::FuncOp &newFuncOp, |
| GeneratedStruct &argumentStruct, |
| GeneratedStruct &resultStruct) const { |
| auto ctx = exportOp.getContext(); |
| auto loc = exportOp.getLoc(); |
| |
| const bool haveArgumentStruct = argumentStruct.name.has_value(); |
| |
| if (haveArgumentStruct) { |
| auto argumentsLValue = emitc_builders::asLValue( |
| rewriter, loc, newFuncOp.getArgument(SHIM_ARGUMENT_ARGS_STORAGE)); |
| |
| // args_t* args = (args_t*)call->arguments.data; |
| |
| // arguments.data |
| auto argumentsData = emitc_builders::structMember( |
| rewriter, loc, |
| /*type=*/emitc::PointerType::get(rewriter.getIntegerType(8, false)), |
| /*memberName=*/"data", |
| /*operand=*/argumentsLValue); |
| |
| // cast |
| std::string argumentsType = |
| std::string("struct ") + argumentStruct.name.value(); |
| auto arguments = emitc::CastOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, argumentsType)), |
| /*operand=*/argumentsData); |
| |
| argumentsLValue = |
| emitc_builders::asLValue(rewriter, loc, arguments.getResult()); |
| |
| argumentStruct.value = argumentsLValue; |
| } |
| |
| const bool haveResultStruct = resultStruct.name.has_value(); |
| if (haveResultStruct) { |
| auto resultsLValue = emitc_builders::asLValue( |
| rewriter, loc, newFuncOp.getArgument(SHIM_ARGUMENT_RETS_STORAGE)); |
| // results_t* results = (results_t*)call->results.data; |
| |
| // results.data |
| auto resultsData = cast< |
| TypedValue<emitc::PointerType>>(emitc_builders::structMember( |
| rewriter, loc, |
| /*type=*/emitc::PointerType::get(rewriter.getIntegerType(8, false)), |
| /*memberName=*/"data", /*operand=*/resultsLValue)); |
| |
| // cast |
| std::string resultType = |
| std::string("struct ") + resultStruct.name.value(); |
| auto results = emitc::CastOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, resultType)), |
| /*operand=*/resultsData); |
| |
| resultsLValue = |
| emitc_builders::asLValue(rewriter, loc, results.getResult()); |
| |
| resultStruct.value = resultsLValue; |
| } |
| } |
| |
| LogicalResult unpackArguments(ConversionPatternRewriter &rewriter, |
| IREE::VM::ExportOp &exportOp, |
| GeneratedStruct &argumentStruct) const { |
| auto ctx = exportOp.getContext(); |
| auto loc = exportOp.getLoc(); |
| |
| // The struct is empty, nothing to do. |
| if (!argumentStruct.value.has_value()) { |
| return success(); |
| } |
| |
| mlir::emitc::FuncOp funcOp = lookupSymbolRef<mlir::emitc::FuncOp>( |
| exportOp.getOperation(), "function_ref"); |
| |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| |
| FunctionType funcType = funcAnalysis.getFunctionType(); |
| |
| for (const auto &input : llvm::enumerate(funcType.getInputs())) { |
| assert(argumentStruct.value.has_value()); |
| auto value = argumentStruct.value.value(); |
| |
| if (isa<IREE::VM::RefType>(input.value())) { |
| auto ptrType = emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); |
| std::string memberName = "arg" + std::to_string(input.index()); |
| auto memberPtr = emitc_builders::structPtrMemberAddress( |
| rewriter, loc, ptrType, memberName, value); |
| emitc::CallOpaqueOp::create(rewriter, |
| /*location=*/memberPtr.getLoc(), |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_vm_ref_retain_inplace", |
| /*operands=*/ArrayRef<Value>{memberPtr}); |
| argumentStruct.callArguments.push_back(memberPtr); |
| } else { |
| Type memberType = input.value(); |
| std::string memberName = "arg" + std::to_string(input.index()); |
| auto member = emitc_builders::structPtrMember(rewriter, loc, |
| /*type=*/memberType, |
| /*memberName=*/memberName, |
| /*operand=*/value); |
| argumentStruct.callArguments.push_back(member); |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult unpackResults(ConversionPatternRewriter &rewriter, |
| IREE::VM::ExportOp &exportOp, |
| GeneratedStruct &resultStruct) const { |
| auto loc = exportOp.getLoc(); |
| |
| // The struct is empty, nothing to do. |
| if (!resultStruct.value.has_value()) { |
| return success(); |
| } |
| |
| const auto typeConverter = getTypeConverter<IREE::VM::EmitCTypeConverter>(); |
| |
| mlir::emitc::FuncOp funcOp = lookupSymbolRef<mlir::emitc::FuncOp>( |
| exportOp.getOperation(), "function_ref"); |
| |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| |
| FunctionType funcType = funcAnalysis.getFunctionType(); |
| |
| for (const auto &result : llvm::enumerate(funcType.getResults())) { |
| assert(resultStruct.value.has_value()); |
| auto value = resultStruct.value.value(); |
| |
| auto ptrType = typeConverter->convertTypeAsPointer(result.value()); |
| |
| std::string memberName = "res" + std::to_string(result.index()); |
| Value memberPtr = emitc_builders::structPtrMemberAddress( |
| rewriter, loc, ptrType, memberName, value); |
| resultStruct.callArguments.push_back(memberPtr); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| class ImportOpConverter { |
| public: |
| ImportOpConverter(IREE::VM::EmitCTypeConverter &typeConverter, |
| SmallVector<std::string> &importShims) |
| : typeConverter(typeConverter), importShims(importShims) {} |
| |
| LogicalResult operator()(IREE::VM::ImportOp importOp) const { |
| OpBuilder builder(importOp); |
| |
| auto key = makeImportCallingConventionString(importOp); |
| if (!key.has_value()) { |
| return importOp.emitError() |
| << "Failed to build key for import shim cache."; |
| } |
| |
| // The needed shim already exists. |
| if (llvm::find(importShims, key) != std::end(importShims)) { |
| return success(); |
| } |
| |
| if (importOp.isVariadic()) { |
| if (failed(createVariadicImportShims(importOp, builder))) { |
| return failure(); |
| } |
| } else { |
| if (failed(createImportShim(importOp, nullptr, builder))) { |
| return failure(); |
| } |
| } |
| |
| importShims.push_back(key.value()); |
| return success(); |
| } |
| |
| private: |
| struct MaybeZeroValue { |
| Value value; |
| bool isZero; |
| }; |
| |
| LogicalResult createVariadicImportShims(IREE::VM::ImportOp &importOp, |
| OpBuilder &builder) const { |
| SetVector<const void *> arities; |
| |
| for (auto caller : getCallers(importOp)) { |
| DenseIntElementsAttr segmentSizes = caller.getSegmentSizes(); |
| const void *p = segmentSizes.getAsOpaquePointer(); |
| if (arities.insert(p)) { |
| if (failed(createImportShim(importOp, segmentSizes, builder))) { |
| return failure(); |
| } |
| } |
| } |
| return success(); |
| } |
| |
| void failIfImportUnresolved(OpBuilder &builder, Location location, |
| TypedValue<emitc::LValueType> import) const { |
| auto *ctx = builder.getContext(); |
| Type boolType = builder.getIntegerType(1); |
| |
| // (iree_vm_function_t*)->module |
| auto importModule = emitc_builders::structPtrMember( |
| builder, location, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_t")), |
| /*memberName=*/"module", |
| /*operand=*/import); |
| |
| auto conditionI1 = |
| emitc::LogicalNotOp::create(builder, |
| /*location=*/location, /*type=*/boolType, |
| /*operands=*/importModule) |
| .getResult(); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| mlir::emitc::FuncOp funcOp = |
| cast<mlir::emitc::FuncOp>(failureBlock->getParentOp()); |
| releaseRefs(builder, location, funcOp, typeConverter.analysis); |
| |
| auto statusOp = emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/location, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/"iree_make_status", |
| /*operands=*/ArrayRef<Value>{}, |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, {emitc::OpaqueAttr::get(ctx, "IREE_STATUS_NOT_FOUND")})); |
| mlir::emitc::ReturnOp::create(builder, location, statusOp.getResult(0)); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| cf::CondBranchOp::create(builder, location, conditionI1, failureBlock, |
| continuationBlock); |
| |
| builder.setInsertionPointToStart(continuationBlock); |
| } |
| |
| LogicalResult createImportShim(IREE::VM::ImportOp &importOp, |
| DenseIntElementsAttr segmentSizes, |
| OpBuilder &builder) const { |
| auto loc = importOp.getLoc(); |
| |
| auto moduleOp = |
| importOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| auto newFuncName = |
| importOp.isVariadic() |
| ? buildVariadicFunctionName(moduleOp, importOp, segmentSizes) |
| : buildFunctionName(moduleOp, importOp); |
| |
| if (!newFuncName.has_value()) { |
| return importOp.emitError() << "failed to build import shim name."; |
| } |
| |
| auto newFuncType = buildFuncType(importOp, segmentSizes, builder, loc); |
| |
| if (failed(newFuncType)) { |
| return importOp.emitError() |
| << "Failed to build function type for wrapper"; |
| } |
| |
| auto newFuncOp = mlir::emitc::FuncOp::create( |
| builder, loc, newFuncName.value(), newFuncType.value()); |
| newFuncOp.setSpecifiersAttr( |
| builder.getArrayAttr({builder.getStringAttr("static")})); |
| newFuncOp.setPrivate(); |
| |
| typeConverter.analysis.addFromImport(newFuncOp, importOp); |
| |
| // Populate newly generated function. |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Block *block = builder.createBlock(&newFuncOp.getFunctionBody(), |
| newFuncOp.getFunctionBody().end()); |
| |
| for (Type type : newFuncOp.getFunctionType().getInputs()) { |
| block->addArgument(type, loc); |
| } |
| |
| builder.setInsertionPointToStart(block); |
| |
| MaybeZeroValue argumentSize = buildSizeExpression( |
| flattenInputTypes(importOp, segmentSizes, builder), builder, loc); |
| MaybeZeroValue resultSize = |
| buildSizeExpression(importOp.getResultTypes(), builder, loc); |
| |
| const int importArgIndex = 1; |
| const BlockArgument importArg = newFuncOp.getArgument(importArgIndex); |
| auto importArgLValue = emitc_builders::asLValue(builder, loc, importArg); |
| failIfImportUnresolved(builder, loc, importArgLValue); |
| |
| auto call = buildIreeVmFunctionCallStruct(importArg, argumentSize, |
| resultSize, builder, loc); |
| |
| if (failed(call)) { |
| return importOp.emitError() << "failed to create call struct"; |
| } |
| |
| if (failed(packArgumentBuffer( |
| flattenInputTypes(importOp, segmentSizes, builder), newFuncOp, |
| call.value(), builder, loc))) { |
| return importOp.emitError() << "failed to pack argument struct"; |
| } |
| |
| const BlockArgument stackArg = |
| newFuncOp.getArgument(CCONV_ARGUMENT_STACK); |
| if (failed(createCall(call.value(), importArgLValue, stackArg, builder, |
| loc))) { |
| return importOp.emitError() << "failed to create call"; |
| } |
| |
| if (failed(unpackResultBuffer(importOp.getResultTypes(), newFuncOp, |
| call.value(), builder, loc))) { |
| return importOp.emitError() << "failed to unpack result struct"; |
| } |
| |
| auto status = emitc_builders::ireeOkStatus(builder, loc); |
| |
| mlir::emitc::ReturnOp::create(builder, loc, status); |
| } |
| |
| return success(); |
| } |
| |
| FailureOr<FunctionType> buildFuncType(IREE::VM::ImportOp importOp, |
| DenseIntElementsAttr segmentSizes, |
| OpBuilder &builder, |
| Location loc) const { |
| auto ctx = builder.getContext(); |
| |
| Type stackType = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_stack_t")); |
| Type funcType = emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_function_t")); |
| |
| SmallVector<Type> types{stackType, funcType}; |
| |
| for (Type type : flattenInputTypes(importOp, segmentSizes, builder)) { |
| auto convertedType = typeConverter.convertType(type); |
| types.push_back(convertedType); |
| } |
| |
| for (auto &resultType : importOp.getResultTypes()) { |
| Type ptrType = typeConverter.convertTypeAsPointer(resultType); |
| |
| types.push_back(ptrType); |
| } |
| |
| FunctionType result = mlir::FunctionType::get( |
| ctx, {types}, {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| return {result}; |
| } |
| |
| MaybeZeroValue buildSizeExpression(ArrayRef<Type> types, OpBuilder &builder, |
| Location loc) const { |
| auto ctx = builder.getContext(); |
| |
| Type hostSizeType = emitc::OpaqueType::get(ctx, "iree_host_size_t"); |
| |
| Value result = |
| emitc::ConstantOp::create(builder, |
| /*location=*/loc, |
| /*resultType=*/hostSizeType, |
| /*value=*/emitc::OpaqueAttr::get(ctx, "0")) |
| .getResult(); |
| bool isZero = true; |
| for (Type type : types) { |
| Type valueType = typeConverter.convertTypeAsNonPointer(type); |
| Value size = |
| emitc_builders::sizeOf(builder, loc, TypeAttr::get(valueType)); |
| |
| result = emitc::AddOp::create(builder, |
| /*location=*/loc, |
| /*type=*/hostSizeType, |
| /*operands=*/ArrayRef<Value>{result, size}) |
| .getResult(); |
| isZero = false; |
| } |
| |
| return MaybeZeroValue{result, isZero}; |
| } |
| |
| FailureOr<TypedValue<emitc::LValueType>> |
| buildIreeVmFunctionCallStruct(Value import, MaybeZeroValue argumentSize, |
| MaybeZeroValue resultSize, OpBuilder &builder, |
| Location loc) const { |
| auto ctx = builder.getContext(); |
| |
| // iree_vm_function_call_t call; |
| auto call = emitc_builders::allocateVariable( |
| builder, loc, emitc::OpaqueType::get(ctx, "iree_vm_function_call_t")); |
| |
| // importValue = *import; |
| auto importValue = emitc_builders::contentsOf(builder, loc, import); |
| |
| // call.function = importValue; |
| emitc_builders::structMemberAssign(builder, loc, |
| /*memberName=*/"function", |
| /*operand=*/call, |
| /*value=*/importValue); |
| |
| allocateByteSpan(call, argumentSize, "arguments", builder, loc); |
| allocateByteSpan(call, resultSize, "results", builder, loc); |
| |
| return {call}; |
| } |
| |
| Value allocateByteSpan(TypedValue<emitc::LValueType> call, |
| MaybeZeroValue size, StringRef memberName, |
| OpBuilder &builder, Location loc) const { |
| |
| auto ctx = builder.getContext(); |
| |
| // byteSpan = call.<memberName>; |
| TypedValue<mlir::Type> byteSpan = |
| emitc::MemberOp::create(builder, loc, |
| emitc::LValueType::get(emitc::OpaqueType::get( |
| ctx, "iree_byte_span_t")), |
| memberName, call) |
| .getResult(); |
| |
| // alloca_(0) returns NULL in some configurations on Windows. Make sure to |
| // allocate at least one byte to get a valid pointer. |
| Value allocaSize; |
| if (size.isZero) { |
| Type hostSizeType = emitc::OpaqueType::get(ctx, "iree_host_size_t"); |
| |
| allocaSize = |
| emitc::ConstantOp::create(builder, |
| /*location=*/loc, |
| /*resultType=*/hostSizeType, |
| /*value=*/emitc::OpaqueAttr::get(ctx, "1")) |
| .getResult(); |
| } else { |
| allocaSize = size.value; |
| } |
| |
| // void *byteSpan_data_void = iree_alloca(size); |
| auto byteSpanDataVoid = |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "void")), |
| /*callee=*/"iree_alloca", |
| /*operands=*/ArrayRef<Value>{allocaSize}) |
| .getResult(0); |
| |
| // uint8_t *byteSpan_data = (uint8_t*)byteSpan_data_void; |
| Type bytePtr = emitc::PointerType::get(builder.getIntegerType(8, false)); |
| auto byteSpanData = emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/bytePtr, |
| /*operand=*/byteSpanDataVoid) |
| .getResult(); |
| |
| // byteSpan.data_length = SIZE; |
| emitc_builders::structMemberAssign(builder, loc, |
| /*memberName=*/"data_length", |
| /*operand=*/byteSpan, |
| /*value=*/size.value); |
| |
| // byteSpan.data = byteSpan_data |
| emitc_builders::structMemberAssign(builder, loc, |
| /*memberName=*/"data", |
| /*operand=*/byteSpan, |
| /*value=*/byteSpanData); |
| |
| // memset(byteSpanData, 0, SIZE); |
| emitc_builders::memset(builder, loc, byteSpanData, 0, allocaSize); |
| |
| return byteSpan; |
| } |
| |
| LogicalResult packArgumentBuffer(ArrayRef<Type> inputTypes, |
| mlir::emitc::FuncOp &funcOp, |
| TypedValue<emitc::LValueType> call, |
| OpBuilder &builder, Location loc) const { |
| if (inputTypes.empty()) { |
| return success(); |
| } |
| |
| auto ctx = builder.getContext(); |
| |
| size_t inputOffset = 2; |
| |
| auto arguments = |
| emitc::MemberOp::create(builder, loc, |
| /*type=*/ |
| emitc::LValueType::get(emitc::OpaqueType::get( |
| ctx, "iree_byte_span_t")), |
| /*memberName=*/"arguments", |
| /*operand=*/call) |
| .getResult(); |
| |
| Type bytePtrType = |
| emitc::PointerType::get(builder.getIntegerType(8, false)); |
| auto uint8Ptr = emitc_builders::structMember(builder, loc, |
| /*type=*/bytePtrType, |
| /*memberName=*/"data", |
| /*operand=*/arguments); |
| for (size_t i = 0; i < inputTypes.size(); i++) { |
| BlockArgument arg = funcOp.getArgument(i + inputOffset); |
| |
| Type argType = arg.getType(); |
| assert(!isa<IREE::VM::RefType>(argType)); |
| |
| if (argType == emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t"))) { |
| Type refPtrType = emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); |
| Value refPtr = emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/refPtrType, |
| /*operand=*/uint8Ptr) |
| .getResult(); |
| |
| emitc::CallOpaqueOp::create(builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_vm_ref_assign", |
| /*operands=*/ArrayRef<Value>{arg, refPtr}); |
| } else { |
| auto argLValue = emitc_builders::asLValue(builder, loc, arg); |
| assert(!isa<emitc::PointerType>(argType)); |
| Value size = |
| emitc_builders::sizeOf(builder, loc, TypeAttr::get(argType)); |
| |
| // memcpy(uint8Ptr, &arg, size); |
| Value argPtr = emitc_builders::addressOf(builder, loc, argLValue); |
| emitc_builders::memcpy(builder, loc, uint8Ptr, argPtr, size); |
| } |
| |
| // Skip the addition in the last iteration. |
| if (i < inputTypes.size() - 1) { |
| Type valueType = typeConverter.convertTypeAsNonPointer(argType); |
| Value size = |
| emitc_builders::sizeOf(builder, loc, TypeAttr::get(valueType)); |
| |
| uint8Ptr = |
| emitc::AddOp::create(builder, |
| /*location=*/loc, /*type=*/bytePtrType, |
| /*operands=*/ArrayRef<Value>{uint8Ptr, size}) |
| .getResult(); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult unpackResultBuffer(ArrayRef<Type> resultTypes, |
| mlir::emitc::FuncOp &funcOp, |
| TypedValue<emitc::LValueType> call, |
| OpBuilder &builder, Location loc) const { |
| if (resultTypes.empty()) { |
| return success(); |
| } |
| |
| auto ctx = builder.getContext(); |
| |
| // The last N arguments are the results. |
| size_t resultOffset = funcOp.getNumArguments() - resultTypes.size(); |
| |
| auto results = emitc::MemberOp::create( |
| builder, loc, |
| /*type=*/ |
| emitc::LValueType::get(emitc::OpaqueType::get(ctx, "iree_byte_span_t")), |
| /*memberName=*/"results", |
| /*operand=*/call); |
| |
| Type bytePtrType = |
| emitc::PointerType::get(builder.getIntegerType(8, false)); |
| auto uint8Ptr = emitc_builders::structMember(builder, loc, |
| /*type=*/bytePtrType, |
| /*memberName=*/"data", |
| /*operand=*/results); |
| |
| for (size_t i = 0; i < resultTypes.size(); i++) { |
| BlockArgument arg = funcOp.getArgument(i + resultOffset); |
| Type argType = arg.getType(); |
| assert(!isa<IREE::VM::RefType>(argType)); |
| |
| if (argType == emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t"))) { |
| Type refPtrType = emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); |
| Value refPtr = emitc::CastOp::create(builder, |
| /*location=*/loc, |
| /*type=*/refPtrType, |
| /*operand=*/uint8Ptr) |
| .getResult(); |
| |
| emitc::CallOpaqueOp::create(builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_vm_ref_move", |
| /*operands=*/ArrayRef<Value>{refPtr, arg}); |
| } else { |
| Type valueType = cast<emitc::PointerType>(argType).getPointee(); |
| Value size = |
| emitc_builders::sizeOf(builder, loc, TypeAttr::get(valueType)); |
| |
| // memcpy(arg, uint8Ptr, size); |
| emitc_builders::memcpy(builder, loc, arg, uint8Ptr, size); |
| } |
| |
| // Skip the addition in the last iteration. |
| if (i < resultTypes.size() - 1) { |
| Type valueType = cast<emitc::PointerType>(argType).getPointee(); |
| Value size = |
| emitc_builders::sizeOf(builder, loc, TypeAttr::get(valueType)); |
| uint8Ptr = |
| emitc::AddOp::create(builder, |
| /*location=*/loc, |
| /*type=*/bytePtrType, |
| /*operands=*/ArrayRef<Value>{uint8Ptr, size}) |
| .getResult(); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult createCall(TypedValue<emitc::LValueType> call, |
| TypedValue<emitc::LValueType> import, Value stack, |
| OpBuilder &builder, Location loc) const { |
| auto ctx = builder.getContext(); |
| |
| // RETURN_IF_ERROR(import->module->begin_call(import->module, stack, call)); |
| auto im = emitc::MemberOfPtrOp::create( |
| builder, loc, |
| /*type=*/ |
| emitc::LValueType::get(emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_t"))), |
| /*memberName=*/"module", |
| /*operand=*/import); |
| auto importModule = dyn_cast<TypedValue<emitc::LValueType>>(im.getResult()); |
| if (!importModule) { |
| return failure(); |
| } |
| |
| // EmitC can't emit function pointers, so we need to fallback to a typedef |
| // currently. This and the `EMITC_CALL_INDIRECT` macro can be replaced with |
| // a new `emitc.call_indirect` op once it has been added upstream. |
| emitc::OpaqueType type = emitc::OpaqueType::get(ctx, "begin_call_t"); |
| |
| auto bc = |
| emitc::MemberOfPtrOp::create(builder, loc, emitc::LValueType::get(type), |
| "begin_call", importModule) |
| .getResult(); |
| auto beginCall = dyn_cast<TypedValue<emitc::LValueType>>(bc); |
| if (!beginCall) { |
| return failure(); |
| } |
| returnIfError( |
| /*rewriter=*/builder, |
| /*location=*/loc, |
| /*callee=*/"EMITC_CALL_INDIRECT", |
| /*args=*/{}, |
| /*operands=*/ |
| ArrayRef<Value>{emitc_builders::asRValue(builder, loc, beginCall), |
| emitc_builders::asRValue(builder, loc, importModule), |
| stack, emitc_builders::asRValue(builder, loc, call)}, |
| typeConverter.analysis); |
| |
| return success(); |
| } |
| |
| // A span count of -1 means a non variadic call |
| SmallVector<Type> flattenInputTypes(IREE::VM::ImportOp importOp, |
| DenseIntElementsAttr segmentSizes, |
| OpBuilder &builder) const { |
| assert(!segmentSizes || |
| (importOp.getNumArguments() == segmentSizes.size())); |
| |
| SmallVector<Type> result; |
| auto expandType = [&result](Type type) { |
| if (auto tupleType = dyn_cast<TupleType>(type)) { |
| for (Type inner : tupleType) { |
| result.push_back(inner); |
| } |
| } else { |
| result.push_back(type); |
| } |
| }; |
| |
| for (size_t i = 0; i < importOp.getNumArguments(); i++) { |
| Type type = importOp.getFunctionType().getInput(i); |
| |
| if (importOp.isFuncArgumentVariadic(i)) { |
| assert(segmentSizes && "segmentSizes must not be nullptr"); |
| APInt segmentSize = *(segmentSizes.begin() + i); |
| int64_t size = segmentSize.getSExtValue(); |
| result.push_back(builder.getI32Type()); |
| |
| assert(size >= 0); |
| for (int j = 0; j < size; j++) { |
| expandType(type); |
| } |
| } else { |
| expandType(type); |
| } |
| } |
| |
| return result; |
| } |
| |
| SmallVector<IREE::VM::CallVariadicOp> |
| getCallers(IREE::VM::ImportOp &importOp) const { |
| SmallVector<IREE::VM::CallVariadicOp> result; |
| |
| auto moduleOp = |
| importOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| moduleOp.walk([&result, &importOp](Operation *op) { |
| if (auto callOp = dyn_cast<IREE::VM::CallVariadicOp>(op)) { |
| if (importOp == lookupSymbolRef<IREE::VM::ImportOp>( |
| callOp.getOperation(), "callee")) { |
| result.push_back(callOp); |
| } |
| } |
| }); |
| |
| return result; |
| } |
| |
| IREE::VM::EmitCTypeConverter &typeConverter; |
| SmallVector<std::string> &importShims; |
| }; |
| |
| template <typename OpTy> |
| class CallOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| mlir::emitc::FuncOp funcOp = |
| lookupSymbolRef<mlir::emitc::FuncOp>(op.getOperation(), "callee"); |
| IREE::VM::ImportOp importOp = |
| lookupSymbolRef<IREE::VM::ImportOp>(op.getOperation(), "callee"); |
| |
| if (!funcOp && !importOp) |
| return op.emitError() << "lookup of callee failed"; |
| |
| if (funcOp && importOp) |
| return op.emitError() << "lookup of callee ambiguous"; |
| |
| const bool isImported = importOp != nullptr; |
| |
| return isImported ? rewriteImportedCall(op.getOperation(), adaptor, |
| rewriter, importOp) |
| : rewriteInternalCall(op.getOperation(), adaptor, |
| rewriter, funcOp); |
| } |
| |
| LogicalResult rewriteInternalCall(Operation *op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, |
| mlir::emitc::FuncOp funcOp) const { |
| auto loc = op->getLoc(); |
| |
| SmallVector<Value> updatedOperands; |
| SmallVector<Value> resultOperands; |
| |
| auto parentFuncOp = op->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stackArg = |
| parentFuncOp.getArgument(CCONV_ARGUMENT_STACK); |
| const BlockArgument moduleArg = |
| parentFuncOp.getArgument(CCONV_ARGUMENT_MODULE); |
| const BlockArgument moduleStateArg = |
| parentFuncOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| |
| updatedOperands = {stackArg, moduleArg, moduleStateArg}; |
| |
| if (failed(updateOperands(op, nullptr, rewriter, updatedOperands, |
| resultOperands))) { |
| return failure(); |
| }; |
| |
| returnIfError( |
| /*rewriter=*/rewriter, /*location=*/loc, /*callee=*/funcOp, |
| /*operands=*/updatedOperands, this->getModuleAnalysis()); |
| |
| emitc_builders::asRValues(rewriter, loc, resultOperands); |
| rewriter.replaceOp(op, resultOperands); |
| |
| return success(); |
| } |
| |
| LogicalResult rewriteImportedCall(Operation *op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, |
| IREE::VM::ImportOp importOp) const { |
| auto ctx = op->getContext(); |
| auto loc = op->getLoc(); |
| |
| SmallVector<Value> updatedOperands; |
| SmallVector<Value> resultOperands; |
| |
| auto moduleOp = |
| importOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| int importOrdinal = importOp.getOrdinal()->getZExtValue(); |
| |
| auto funcOp = op->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stackArg = funcOp.getArgument(CCONV_ARGUMENT_STACK); |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto imports = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_function_t")), |
| /*memberName=*/"imports", |
| /*operand=*/stateArgLValue)); |
| |
| auto import = emitc_builders::arrayElementAddress( |
| rewriter, loc, /*index=*/importOrdinal, /*operand=*/imports); |
| |
| updatedOperands = {stackArg, import}; |
| |
| std::optional<std::string> funcName; |
| if (auto variadicOp = dyn_cast<IREE::VM::CallVariadicOp>(op)) { |
| funcName = buildVariadicFunctionName(moduleOp, importOp, |
| variadicOp.getSegmentSizes()); |
| } else { |
| funcName = buildFunctionName(moduleOp, importOp); |
| } |
| |
| if (failed(updateOperands(op, importOp, rewriter, updatedOperands, |
| resultOperands))) { |
| return failure(); |
| } |
| |
| if (!funcName.has_value()) |
| return op->emitError() << "Couldn't build name to imported function"; |
| |
| auto callee = moduleOp.lookupSymbol<mlir::emitc::FuncOp>(funcName.value()); |
| if (callee == nullptr) { |
| return op->emitError() |
| << "Couldn't find function with name `" << funcName.value() << "`"; |
| } |
| |
| returnIfError(rewriter, loc, callee, updatedOperands, |
| this->getModuleAnalysis()); |
| |
| emitc_builders::asRValues(rewriter, loc, resultOperands); |
| rewriter.replaceOp(op, resultOperands); |
| |
| return success(); |
| } |
| |
| LogicalResult updateOperands(Operation *op, IREE::VM::ImportOp importOp, |
| ConversionPatternRewriter &rewriter, |
| SmallVector<Value> &updatedOperands, |
| SmallVector<Value> &resultOperands) const { |
| auto loc = op->getLoc(); |
| |
| OperandRange operands = op->getOperands(); |
| |
| int operandIndex = 0; |
| int numInputs = |
| importOp ? importOp.getFunctionType().getNumInputs() : operands.size(); |
| for (int i = 0; i < numInputs; i++) { |
| if (importOp && importOp.isFuncArgumentVariadic(i)) { |
| auto variadicCallOp = cast<IREE::VM::CallVariadicOp>(op); |
| APInt segment = *(variadicCallOp.getSegmentSizes().begin() + i); |
| int64_t size = segment.getSExtValue(); |
| |
| Value segmentSize = emitc::ConstantOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*resultType=*/rewriter.getI32Type(), |
| /*value=*/rewriter.getI32IntegerAttr(size)) |
| .getResult(); |
| updatedOperands.push_back(segmentSize); |
| |
| Type type = importOp.getFunctionType().getInput(i); |
| int numOps = isa<TupleType>(type) ? cast<TupleType>(type).size() : 1; |
| for (int i = 0; i < size; i++) { |
| for (int j = 0; j < numOps; j++) { |
| FailureOr<Value> updatedOperand = |
| updateOperand(operands[operandIndex], rewriter, loc); |
| if (failed(updatedOperand)) { |
| return failure(); |
| } |
| updatedOperands.push_back(updatedOperand.value()); |
| operandIndex++; |
| } |
| } |
| } else { |
| FailureOr<Value> updatedOperand = |
| updateOperand(operands[operandIndex], rewriter, loc); |
| if (failed(updatedOperand)) { |
| return failure(); |
| } |
| updatedOperands.push_back(updatedOperand.value()); |
| operandIndex++; |
| } |
| } |
| |
| // Create a variable for every result and a pointer to it as output |
| // parameter to the call. |
| for (OpResult result : op->getResults()) { |
| if (isa<IREE::VM::RefType>(result.getType())) { |
| Value ref = this->getModuleAnalysis().lookupRef(result); |
| |
| resultOperands.push_back(ref); |
| updatedOperands.push_back(ref); |
| } else { |
| auto resultValue = |
| emitc_builders::allocateVariable(rewriter, loc, result.getType()); |
| Value resultPtr = emitc_builders::addressOf(rewriter, loc, resultValue); |
| |
| resultOperands.push_back(resultValue); |
| updatedOperands.push_back(resultPtr); |
| } |
| } |
| return success(); |
| } |
| |
| FailureOr<Value> updateOperand(Value operand, OpBuilder &builder, |
| Location loc) const { |
| auto ctx = builder.getContext(); |
| |
| assert(operand.getType() != emitc::PointerType::get(emitc::OpaqueType::get( |
| ctx, "iree_vm_ref_t"))); |
| if (!isa<IREE::VM::RefType>(operand.getType())) { |
| return operand; |
| } |
| |
| Value operandRef = this->getModuleAnalysis().lookupRef(operand); |
| |
| auto [ref, refPtr] = emitc_builders::allocZeroInitializedVar( |
| builder, loc, emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); |
| |
| emitc::CallOpaqueOp::create( |
| builder, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"iree_vm_ref_assign", |
| /*operands=*/ArrayRef<Value>{operandRef, refPtr}); |
| |
| return refPtr; |
| } |
| }; |
| |
| template <typename OpTy> |
| class CompareRefOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| public: |
| CompareRefOpConversion(const TypeConverter &typeConverter, |
| MLIRContext *context, StringRef funcName) |
| : EmitCConversionPattern<OpTy>(typeConverter, context), |
| funcName(funcName) {} |
| |
| private: |
| LogicalResult |
| matchAndRewrite(OpTy cmpOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = cmpOp.getContext(); |
| auto loc = cmpOp.getLoc(); |
| |
| auto funcOp = |
| cmpOp.getOperation()->template getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const auto typeConverter = |
| this->template getTypeConverter<IREE::VM::EmitCTypeConverter>(); |
| |
| auto analysis = this->getModuleAnalysis().lookupFunction(funcOp); |
| if (failed(analysis)) { |
| return cmpOp.emitError() << "parent func op not found in cache."; |
| } |
| |
| bool moveLhs = |
| analysis.value().get().isMove(cmpOp.getLhs(), cmpOp.getOperation()); |
| bool moveRhs = |
| analysis.value().get().isMove(cmpOp.getRhs(), cmpOp.getOperation()); |
| |
| Value refLhs = this->getModuleAnalysis().lookupRef(cmpOp.getLhs()); |
| Value refRhs = this->getModuleAnalysis().lookupRef(cmpOp.getRhs()); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/cmpOp, |
| /*type=*/cmpOp.getType(), |
| /*callee=*/funcName, |
| /*operands=*/ArrayRef<Value>{refLhs, refRhs}); |
| |
| if (moveLhs) { |
| emitc_builders::ireeVmRefRelease(rewriter, loc, refLhs); |
| } |
| |
| // NOTE: If lhs and rhs alias we call release twice on the same |
| // argument. |
| if (moveRhs) { |
| emitc_builders::ireeVmRefRelease(rewriter, loc, refRhs); |
| } |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| class CompareRefNotZeroOpConversion |
| : public EmitCConversionPattern<IREE::VM::CmpNZRefOp> { |
| using Adaptor = IREE::VM::CmpNZRefOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::CmpNZRefOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::CmpNZRefOp cmpOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = cmpOp.getLoc(); |
| |
| auto funcOp = cmpOp.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| |
| bool move = funcAnalysis.isMove(cmpOp.getOperand(), cmpOp.getOperation()); |
| |
| Value ref = getModuleAnalysis().lookupRef(cmpOp.getOperand()); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/cmpOp, |
| /*type=*/cmpOp.getType(), |
| /*callee=*/"vm_cmp_nz_ref", |
| /*operands=*/ArrayRef<Value>{ref}); |
| |
| if (move) { |
| emitc_builders::ireeVmRefRelease(rewriter, loc, ref); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| class SelectRefOpConversion |
| : public EmitCConversionPattern<IREE::VM::SelectRefOp> { |
| using Adaptor = typename IREE::VM::SelectRefOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::SelectRefOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::SelectRefOp selectOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = selectOp.getContext(); |
| auto loc = selectOp.getLoc(); |
| |
| auto moduleOp = |
| selectOp.getOperation()->template getParentOfType<IREE::VM::ModuleOp>(); |
| auto funcOp = selectOp.getOperation() |
| ->template getParentOfType<mlir::emitc::FuncOp>(); |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| |
| const BlockArgument moduleArg = funcOp.getArgument(CCONV_ARGUMENT_MODULE); |
| |
| auto resultTypePtr = |
| createVmTypeDefPtr(rewriter, loc, this->getModuleAnalysis(), moduleOp, |
| moduleArg, selectOp.getType()); |
| if (!resultTypePtr.has_value()) { |
| return selectOp->emitError() << "generating iree_vm_type_def_t* failed"; |
| } |
| auto resultTypeAsRef = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/"iree_vm_type_def_as_ref", |
| /*operands=*/ArrayRef<Value>{resultTypePtr.value()}) |
| .getResult(0); |
| |
| bool moveTrue = |
| funcAnalysis.isMove(selectOp.getTrueValue(), selectOp.getOperation()); |
| bool moveFalse = |
| funcAnalysis.isMove(selectOp.getFalseValue(), selectOp.getOperation()); |
| |
| Value refTrue = |
| this->getModuleAnalysis().lookupRef(selectOp.getTrueValue()); |
| Value refFalse = |
| this->getModuleAnalysis().lookupRef(selectOp.getFalseValue()); |
| Value refResult = this->getModuleAnalysis().lookupRef(selectOp.getResult()); |
| |
| Type boolType = rewriter.getI1Type(); |
| auto condition = IREE::VM::CmpNZI32Op::create( |
| rewriter, loc, rewriter.getI32Type(), selectOp.getCondition()); |
| auto conditionI1 = emitc::CastOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*operand=*/condition.getResult()); |
| |
| auto *continueBlock = |
| rewriter.splitBlock(selectOp->getBlock(), Block::iterator(selectOp)); |
| |
| Block *trueBlock = nullptr; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| trueBlock = rewriter.createBlock(continueBlock); |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_ref_retain_or_move_checked", |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, {rewriter.getBoolAttr(moveTrue), rewriter.getIndexAttr(0), |
| rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}), |
| /*operands=*/ArrayRef<Value>{refTrue, resultTypeAsRef, refResult}, |
| this->getModuleAnalysis()); |
| IREE::VM::BranchOp::create(rewriter, loc, continueBlock); |
| } |
| |
| Block *falseBlock = nullptr; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| falseBlock = rewriter.createBlock(continueBlock); |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_ref_retain_or_move_checked", |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, {rewriter.getBoolAttr(moveFalse), rewriter.getIndexAttr(0), |
| rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}), |
| /*operands=*/ArrayRef<Value>{refFalse, resultTypeAsRef, refResult}, |
| this->getModuleAnalysis()); |
| IREE::VM::BranchOp::create(rewriter, loc, continueBlock); |
| } |
| |
| rewriter.setInsertionPointAfterValue(conditionI1); |
| mlir::cf::CondBranchOp::create(rewriter, loc, conditionI1.getResult(), |
| trueBlock, falseBlock); |
| rewriter.replaceOp(selectOp, refResult); |
| |
| return success(); |
| } |
| }; |
| |
| template <typename OpTy> |
| class ConstOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy constOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<emitc::ConstantOp>(constOp, constOp.getType(), |
| constOp.getValue()); |
| return success(); |
| } |
| }; |
| |
| template <typename OpTy> |
| class ConstZeroOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy constZeroOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto type = constZeroOp.getType(); |
| |
| Attribute value = rewriter.getZeroAttr(type); |
| |
| rewriter.replaceOpWithNewOp<emitc::ConstantOp>(constZeroOp, type, value); |
| return success(); |
| } |
| }; |
| |
| class ConstRefZeroOpConversion |
| : public EmitCConversionPattern<IREE::VM::ConstRefZeroOp> { |
| using Adaptor = IREE::VM::ConstRefZeroOp::Adaptor; |
| using EmitCConversionPattern< |
| IREE::VM::ConstRefZeroOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::ConstRefZeroOp constRefZeroOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = constRefZeroOp.getLoc(); |
| |
| Value ref = getModuleAnalysis().lookupRef(constRefZeroOp.getResult()); |
| |
| emitc_builders::ireeVmRefRelease(rewriter, loc, ref); |
| |
| rewriter.replaceOp(constRefZeroOp, ref); |
| |
| return success(); |
| } |
| }; |
| |
| class ConstRefRodataOpConversion |
| : public EmitCConversionPattern<IREE::VM::ConstRefRodataOp> { |
| using Adaptor = IREE::VM::ConstRefRodataOp::Adaptor; |
| using EmitCConversionPattern< |
| IREE::VM::ConstRefRodataOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::ConstRefRodataOp constRefRodataOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = constRefRodataOp.getContext(); |
| auto loc = constRefRodataOp.getLoc(); |
| |
| auto rodataOp = lookupSymbolRef<IREE::VM::RodataOp>( |
| constRefRodataOp.getOperation(), "rodata"); |
| if (!rodataOp) { |
| return constRefRodataOp.emitError() << "Unable to find RodataOp"; |
| } |
| |
| auto funcOp = |
| constRefRodataOp.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto rodataBuffersPtr = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_buffer_t")), |
| /*memberName=*/"rodata_buffers", |
| /*operand=*/stateArgLValue)); |
| |
| auto byteBufferPtrOp = emitc_builders::arrayElementAddress( |
| rewriter, loc, /*index=*/rodataOp.getOrdinal()->getZExtValue(), |
| /*operand=*/rodataBuffersPtr); |
| |
| auto typeIdOp = emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/"iree_vm_buffer_type", |
| /*operands=*/ArrayRef<Value>{}); |
| |
| Value ref = getModuleAnalysis().lookupRef(constRefRodataOp.getResult()); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_ref_wrap_retain", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{byteBufferPtrOp, typeIdOp.getResult(0), ref}, |
| getModuleAnalysis()); |
| |
| rewriter.replaceOp(constRefRodataOp, ref); |
| |
| return success(); |
| } |
| }; |
| |
| class BranchOpConversion : public EmitCConversionPattern<IREE::VM::BranchOp> { |
| using Adaptor = IREE::VM::BranchOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::BranchOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::BranchOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| |
| assert(op.getOperands().size() == adaptor.getOperands().size()); |
| |
| auto isNotRefOperand = [](Value operand) { |
| return !isa<IREE::VM::RefType>(operand.getType()); |
| }; |
| |
| SmallVector<Value> nonRefOperands; |
| for (Value operand : op.getOperands()) { |
| if (isNotRefOperand(operand)) { |
| nonRefOperands.push_back(operand); |
| } |
| } |
| |
| Block *dest = op.getDest(); |
| |
| // If we don't have ref block arguments, we can convert the operation |
| // directly. |
| if (adaptor.getOperands().size() == nonRefOperands.size()) { |
| rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, op.getDest(), |
| op.getOperands()); |
| return success(); |
| } |
| |
| Block *destDispatch; |
| { |
| auto funcOp = |
| op.getOperation()->template getParentOfType<mlir::emitc::FuncOp>(); |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| auto &signatureConversion = funcAnalysis.lookupBlockConversion(dest); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| destDispatch = rewriter.createBlock(dest); |
| |
| IRMapping refMapping; |
| for (auto [index, operand] : llvm::enumerate(op.getOperands())) { |
| if (isNotRefOperand(operand)) { |
| continue; |
| } |
| |
| Value blockArgRef = |
| signatureConversion.getInputMapping(index)->replacementValues[0]; |
| |
| assert(isa<IREE::VM::RefType>(operand.getType())); |
| assert(isa<emitc::PointerType>(blockArgRef.getType())); |
| |
| Value operandRef = getModuleAnalysis().lookupRef(operand); |
| |
| refMapping.map(operandRef, blockArgRef); |
| } |
| if (failed(retainOrMoveRefs(rewriter, loc, refMapping, |
| /*isMove=*/false))) { |
| return op.emitError() << "moving of multiple refs failed"; |
| } |
| mlir::cf::BranchOp::create(rewriter, loc, op.getDest(), nonRefOperands); |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, destDispatch); |
| |
| return success(); |
| } |
| }; |
| |
| // Basic block arguments are emitted as variable assignments in EmitC. Because |
| // of that we need to treat ref operands separately here. We remove ref |
| // arguments from the basic blocks and use the ref C API to set the ref |
| // variables. The generated IR looks roughly as follows: |
| |
| // clang-format off |
| // vm.cond_br %cond, ^bb1(%ref : !vm.ref<?>, %int : i32), ^bb2(%ref : !vm.ref<?>, %int : i32) |
| // ^bb1(%ref_arg_1 : !vm.ref<?>, %int_arg : i32): |
| // ... |
| // ^bb2(%ref_arg_2 : !vm.ref<?>, %int_arg : i32): |
| // ... |
| // => |
| // cond_br %cond, ^bb1_dispatch, ^bb2_dispatch |
| // ^bb1_dispatch: |
| // // populate the variable corresponding to ordinal(%ref_arg_1) |
| // br ^bb1(%int : i32) |
| // ^bb2_dispatch: |
| // // populate the variable corresponding to ordinal(%ref_arg_2) |
| // br ^bb2(%int : i32) |
| // ^bb1(%int_arg : i32): |
| // ... |
| // ^bb2(%int_arg : i32): |
| // ... |
| // clang-format on |
| class CondBranchOpConversion |
| : public EmitCConversionPattern<IREE::VM::CondBranchOp> { |
| using Adaptor = IREE::VM::CondBranchOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::CondBranchOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::CondBranchOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| |
| assert(op.getOperands().size() == adaptor.getOperands().size()); |
| |
| auto isNotRefOperand = [](Value operand) { |
| return !isa<IREE::VM::RefType>(operand.getType()); |
| }; |
| |
| SmallVector<Value> nonRefOperands; |
| for (Value operand : op.getOperands()) { |
| if (isNotRefOperand(operand)) { |
| nonRefOperands.push_back(operand); |
| } |
| } |
| |
| Block *trueDest = op.getTrueDest(); |
| Block *falseDest = op.getFalseDest(); |
| |
| Type boolType = rewriter.getI1Type(); |
| |
| auto condition = IREE::VM::CmpNZI32Op::create( |
| rewriter, loc, rewriter.getI32Type(), op.getCondition()); |
| auto conditionI1 = emitc::CastOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*operand=*/condition.getResult()); |
| |
| // If we don't have ref block arguments, we can convert the operation |
| // directly. |
| if (adaptor.getOperands().size() == nonRefOperands.size()) { |
| rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>( |
| op, conditionI1.getResult(), op.getTrueDest(), op.getTrueOperands(), |
| op.getFalseDest(), op.getFalseOperands()); |
| return success(); |
| } |
| |
| Block *trueDestDispatch; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| trueDestDispatch = rewriter.createBlock(trueDest); |
| |
| // Let the BranchOpConversion handle ref block arguments. |
| IREE::VM::BranchOp::create(rewriter, loc, op.getTrueDest(), |
| op.getTrueOperands()); |
| } |
| |
| Block *falseDestDispatch; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| falseDestDispatch = rewriter.createBlock(falseDest); |
| |
| // Let the BranchOpConversion handle ref block arguments. |
| IREE::VM::BranchOp::create(rewriter, loc, op.getFalseDest(), |
| op.getFalseOperands()); |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>( |
| op, conditionI1.getResult(), trueDestDispatch, falseDestDispatch); |
| |
| return success(); |
| } |
| }; |
| |
| // EmitC does not support cf.switch so we turn the branch table into a long |
| // sequence of conditional branches. |
| class BranchTableOpConversion |
| : public EmitCConversionPattern<IREE::VM::BranchTableOp> { |
| using Adaptor = IREE::VM::BranchTableOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::BranchTableOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::BranchTableOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto caseDestinations = op.getCaseDestinations(); |
| SmallVector<Block *> caseBlocks; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto *nextBlock = rewriter.getInsertionBlock()->getNextNode(); |
| for (size_t i = 0; i < caseDestinations.size(); ++i) |
| caseBlocks.push_back(rewriter.createBlock(nextBlock)); |
| caseBlocks.push_back(rewriter.createBlock(nextBlock)); // default |
| } |
| IREE::VM::BranchOp::create(rewriter, op.getLoc(), caseBlocks.front()); |
| for (size_t i = 0; i < caseDestinations.size(); ++i) { |
| rewriter.setInsertionPointToStart(caseBlocks[i]); |
| Value cmp = IREE::VM::CmpEQI32Op::create( |
| rewriter, op.getLoc(), rewriter.getI32Type(), adaptor.getIndex(), |
| IREE::VM::ConstI32Op::create(rewriter, op.getLoc(), i)); |
| auto caseOperands = adaptor.getCaseOperands(); |
| IREE::VM::CondBranchOp::create(rewriter, op.getLoc(), cmp, |
| caseDestinations[i], caseOperands[i], |
| caseBlocks[i + 1], ValueRange{}); |
| } |
| rewriter.setInsertionPointToStart(caseBlocks.back()); |
| rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>( |
| op, op.getDefaultDestination(), adaptor.getDefaultOperands()); |
| return success(); |
| } |
| }; |
| |
| class ReturnOpConversion : public EmitCConversionPattern<IREE::VM::ReturnOp> { |
| using Adaptor = IREE::VM::ReturnOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::ReturnOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::ReturnOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| |
| auto funcOp = op.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| // The result variables are the last N arguments of the function. |
| unsigned int firstOutputArgumentIndex = |
| funcOp.getNumArguments() - op.getOperands().size(); |
| |
| IRMapping refMapping; |
| for (const auto &pair : llvm::enumerate(op.getOperands())) { |
| Value operand = pair.value(); |
| size_t index = pair.index(); |
| |
| unsigned int argumentIndex = firstOutputArgumentIndex + index; |
| BlockArgument resultArgument = funcOp.getArgument(argumentIndex); |
| |
| if (isa<IREE::VM::RefType>(operand.getType())) { |
| assert(operand.getType() != |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(op.getContext(), "iree_vm_ref_t"))); |
| |
| Value operandRef = getModuleAnalysis().lookupRef(operand); |
| |
| refMapping.map(operandRef, resultArgument); |
| } else { |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/"EMITC_DEREF_ASSIGN_VALUE", |
| /*operands=*/ArrayRef<Value>{resultArgument, operand}); |
| } |
| } |
| |
| if (failed(retainOrMoveRefs(rewriter, loc, refMapping, /*isMove=*/true))) { |
| return op.emitError() << "moving of multiple refs failed"; |
| } |
| releaseRefs(rewriter, loc, funcOp, getModuleAnalysis()); |
| |
| auto status = emitc_builders::ireeOkStatus(rewriter, loc); |
| |
| rewriter.replaceOpWithNewOp<mlir::emitc::ReturnOp>(op, status); |
| |
| return success(); |
| } |
| }; |
| |
| class ImportResolvedOpConversion |
| : public EmitCConversionPattern<IREE::VM::ImportResolvedOp> { |
| using Adaptor = IREE::VM::ImportResolvedOp::Adaptor; |
| using EmitCConversionPattern< |
| IREE::VM::ImportResolvedOp>::EmitCConversionPattern; |
| |
| private: |
| LogicalResult |
| matchAndRewrite(IREE::VM::ImportResolvedOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| IREE::VM::ImportOp importOp = |
| lookupSymbolRef<IREE::VM::ImportOp>(op.getOperation(), "import"); |
| int importOrdinal = importOp.getOrdinal()->getZExtValue(); |
| |
| auto funcOp = op->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto imports = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_function_t")), |
| /*memberName=*/"imports", |
| /*operand=*/stateArgLValue)); |
| |
| auto import = emitc_builders::arrayElementAddress( |
| rewriter, loc, /*index=*/importOrdinal, /*operand=*/imports); |
| auto importLValue = emitc_builders::asLValue(rewriter, loc, import); |
| |
| // (iree_vm_function_t*)->module |
| auto importModule = emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_module_t")), |
| /*memberName=*/"module", |
| /*operand=*/importLValue); |
| |
| Type boolType = rewriter.getIntegerType(1); |
| auto conditionI1 = emitc::LogicalNotOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*operands=*/importModule) |
| .getResult(); |
| auto invConditionI1 = emitc::LogicalNotOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*operands=*/conditionI1) |
| .getResult(); |
| |
| auto i32Type = rewriter.getIntegerType(32); |
| auto conditionI32 = emitc::CastOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/i32Type, |
| /*operand=*/invConditionI1); |
| |
| rewriter.replaceOp(op, {conditionI32.getResult()}); |
| |
| return success(); |
| } |
| }; |
| |
| class FailOpConversion : public EmitCConversionPattern<IREE::VM::FailOp> { |
| using Adaptor = IREE::VM::FailOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::FailOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::FailOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| Block *block = rewriter.getInsertionBlock(); |
| Region *parentRegion = block->getParent(); |
| Block *passthroughBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| passthroughBlock = |
| rewriter.createBlock(parentRegion, parentRegion->end()); |
| |
| auto funcOp = op.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| releaseRefs(rewriter, loc, funcOp, getModuleAnalysis()); |
| |
| auto status = emitc_builders::ireeOkStatus(rewriter, loc); |
| |
| mlir::emitc::ReturnOp::create(rewriter, loc, status); |
| } |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| failureBlock = rewriter.createBlock(parentRegion, parentRegion->end()); |
| |
| auto funcOp = op.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| releaseRefs(rewriter, loc, funcOp, getModuleAnalysis()); |
| |
| Value message = emitc_builders::ireeMakeCstringView( |
| rewriter, loc, op.getMessage().value_or("").str()); |
| auto messageLValue = emitc_builders::asLValue(rewriter, loc, message); |
| Type type = emitc::OpaqueType::get(ctx, "iree_host_size_t"); |
| Value messageSize = |
| emitc_builders::structMember(rewriter, loc, |
| /*type=*/type, |
| /*memberName=*/"size", |
| /*operand=*/messageLValue); |
| |
| auto messageSizeIntOp = |
| emitc::CastOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "int"), |
| /*operand=*/messageSize); |
| |
| Type charPtr = |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "const char")); |
| auto messageDataOp = |
| emitc_builders::structMember(rewriter, loc, |
| /*type=*/charPtr, |
| /*memberName=*/"data", |
| /*operand=*/messageLValue); |
| |
| auto status = emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/"iree_status_allocate_f", |
| /*operands=*/ |
| ArrayRef<Value>{messageSizeIntOp.getResult(), messageDataOp}, |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, |
| {emitc::OpaqueAttr::get(ctx, "IREE_STATUS_FAILED_PRECONDITION"), |
| emitc::OpaqueAttr::get(ctx, "\"<vm>\""), |
| rewriter.getI32IntegerAttr(0), |
| emitc::OpaqueAttr::get(ctx, "\"%.*s\""), |
| rewriter.getIndexAttr(0), rewriter.getIndexAttr(1)})); |
| |
| mlir::emitc::ReturnOp::create(rewriter, loc, status.getResult(0)); |
| } |
| |
| Type boolType = rewriter.getIntegerType(1); |
| auto condition = emitc::CastOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*operand=*/op.getStatus()); |
| |
| rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>( |
| op, condition.getResult(), failureBlock, passthroughBlock); |
| |
| return success(); |
| } |
| }; |
| |
| template <typename OpTy, typename GlobalOpTy> |
| class GlobalLoadOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| public: |
| GlobalLoadOpConversion(const TypeConverter &typeConverter, |
| MLIRContext *context, StringRef funcName) |
| : EmitCConversionPattern<OpTy>(typeConverter, context), |
| funcName(funcName) {} |
| |
| private: |
| LogicalResult |
| matchAndRewrite(OpTy loadOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = loadOp.getLoc(); |
| |
| GlobalOpTy globalOp = |
| lookupSymbolRef<GlobalOpTy>(loadOp.getOperation(), "global"); |
| if (!globalOp) { |
| return loadOp.emitError() << "Unable to find GlobalOp"; |
| } |
| |
| auto funcOp = |
| loadOp.getOperation()->template getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto rwDataPtr = emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/emitc::PointerType::get(rewriter.getIntegerType(8, false)), |
| /*memberName=*/"rwdata", |
| /*operand=*/stateArgLValue); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/loadOp, |
| /*type=*/loadOp.getOperation()->getResultTypes(), |
| /*callee=*/funcName, |
| /*operands=*/ArrayRef<Value>{rwDataPtr}, |
| /*args=*/ |
| rewriter.getArrayAttr( |
| {rewriter.getIndexAttr(0), |
| rewriter.getUI32IntegerAttr(static_cast<uint32_t>( |
| globalOp.getOrdinal()->getZExtValue()))})); |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| template <typename OpTy> |
| class GlobalLoadStoreRefOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (isa<IREE::VM::GlobalLoadRefOp>(op)) { |
| return rewriteOp(op.getOperation(), adaptor, rewriter, true); |
| } else if (isa<IREE::VM::GlobalStoreRefOp>(op)) { |
| return rewriteOp(op.getOperation(), adaptor, rewriter, false); |
| } |
| |
| return op.emitError() << "op must be one of `vm.global.load.ref` or " |
| "`vm.global.store.ref`"; |
| } |
| |
| LogicalResult rewriteOp(Operation *op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, |
| bool isLoad) const { |
| auto ctx = op->getContext(); |
| auto loc = op->getLoc(); |
| |
| IREE::VM::GlobalRefOp globalOp = |
| lookupSymbolRef<IREE::VM::GlobalRefOp>(op, "global"); |
| if (!globalOp) { |
| return op->emitError() << "Unable to find GlobalOp"; |
| } |
| |
| auto globalOrdinal = globalOp.getOrdinal()->getZExtValue(); |
| |
| auto funcOp = op->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| auto &funcAnalysis = this->getModuleAnalysis().lookupFunction(funcOp); |
| |
| Value localValue = isLoad ? op->getResult(0) : op->getOperand(0); |
| |
| Value localRef = this->getModuleAnalysis().lookupRef(localValue); |
| |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto refs = |
| cast<TypedValue<emitc::PointerType>>(emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_t")), |
| /*memberName=*/"refs", |
| /*operand=*/stateArgLValue)); |
| |
| auto stateRef = emitc_builders::arrayElementAddress( |
| rewriter, loc, /*index=*/globalOrdinal, /*operand=*/refs); |
| |
| auto moduleOp = op->getParentOfType<IREE::VM::ModuleOp>(); |
| auto parentFuncOp = op->getParentOfType<mlir::emitc::FuncOp>(); |
| const BlockArgument moduleArg = |
| parentFuncOp.getArgument(CCONV_ARGUMENT_MODULE); |
| |
| Type elementType = localValue.getType(); |
| |
| auto elementTypePtr = |
| createVmTypeDefPtr(rewriter, op->getLoc(), this->getModuleAnalysis(), |
| moduleOp, moduleArg, elementType); |
| |
| if (!elementTypePtr.has_value()) { |
| return op->emitError() << "generating iree_vm_type_def_t* failed"; |
| } |
| |
| auto typedefAsRef = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/"iree_vm_type_def_as_ref", |
| /*operands=*/ArrayRef<Value>{elementTypePtr.value()}) |
| .getResult(0); |
| |
| Value srcRef = isLoad ? stateRef : localRef; |
| Value destRef = isLoad ? localRef : stateRef; |
| |
| bool move = funcAnalysis.isMove(localValue, op); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_ref_retain_or_move_checked", |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {rewriter.getBoolAttr(move), rewriter.getIndexAttr(0), |
| rewriter.getIndexAttr(1), rewriter.getIndexAttr(2)}), |
| /*operands=*/ArrayRef<Value>{srcRef, typedefAsRef, destRef}, |
| this->getModuleAnalysis()); |
| |
| if (isLoad) { |
| rewriter.replaceOp(op, localRef); |
| } else { |
| rewriter.eraseOp(op); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| template <typename OpTy, typename GlobalOpTy> |
| class GlobalStoreOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| public: |
| GlobalStoreOpConversion(const TypeConverter &typeConverter, |
| MLIRContext *context, StringRef funcName) |
| : EmitCConversionPattern<OpTy>(typeConverter, context), |
| funcName(funcName) {} |
| |
| private: |
| LogicalResult |
| matchAndRewrite(OpTy storeOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = storeOp.getLoc(); |
| |
| GlobalOpTy globalOp = |
| lookupSymbolRef<GlobalOpTy>(storeOp.getOperation(), "global"); |
| if (!globalOp) { |
| return storeOp.emitError() << "Unable to find GlobalOp"; |
| } |
| |
| auto funcOp = |
| storeOp.getOperation()->template getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto rwDataPtr = emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/emitc::PointerType::get(rewriter.getIntegerType(8, false)), |
| /*memberName=*/"rwdata", /*operand=*/stateArgLValue); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/storeOp, |
| /*type=*/storeOp.getOperation()->getResultTypes(), |
| /*callee=*/funcName, |
| /*operands=*/ArrayRef<Value>{rwDataPtr, storeOp.getValue()}, |
| /*args=*/ |
| rewriter.getArrayAttr( |
| {rewriter.getIndexAttr(0), |
| rewriter.getUI32IntegerAttr( |
| static_cast<uint32_t>(globalOp.getOrdinal()->getZExtValue())), |
| rewriter.getIndexAttr(1)})); |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| // Convert vm operations with wrapped containers to multiple emitc opaque_calls. |
| // The wrapping ref pointers are first dereferenced and the results are used as |
| // the arguments of the specified function name. |
| template <typename OpTy> |
| class ContainerOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| public: |
| ContainerOpConversion(const TypeConverter &typeConverter, |
| MLIRContext *context, StringRef funcName, |
| DenseSet<size_t> refArgumentIndices, bool failable) |
| : EmitCConversionPattern<OpTy>(typeConverter, context), |
| funcName(funcName), refArgumentIndices(refArgumentIndices), |
| failable(failable) {} |
| |
| private: |
| LogicalResult |
| matchAndRewrite(OpTy op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| SmallVector<Value> unwrappedOperands; |
| for (const auto &operand : llvm::enumerate(adaptor.getOperands())) { |
| if (refArgumentIndices.contains(operand.index())) { |
| Type originalType = |
| op.getOperation()->getOperand(operand.index()).getType(); |
| |
| assert(isa<IREE::VM::RefType>(originalType) && "expected ref type"); |
| |
| Type objectType = cast<IREE::VM::RefType>(originalType).getObjectType(); |
| |
| std::optional<std::pair<StringRef, StringRef>> vmNames = |
| TypeSwitch<Type, std::optional<std::pair<StringRef, StringRef>>>( |
| objectType) |
| .Case<IREE::VM::ListType>([&](auto t) { |
| return std::make_pair(StringRef("iree_vm_list_t"), |
| StringRef("iree_vm_list_deref")); |
| }) |
| .template Case<IREE::VM::BufferType>([&](auto t) { |
| return std::make_pair(StringRef("iree_vm_buffer_t"), |
| StringRef("iree_vm_buffer_deref")); |
| }) |
| .Default([](Type) { return std::nullopt; }); |
| |
| if (!vmNames.has_value()) { |
| return op.emitOpError() << "object type not handled"; |
| } |
| |
| StringRef vmType = std::get<0>(vmNames.value()); |
| StringRef vmDerefCallee = std::get<1>(vmNames.value()); |
| |
| Value refValue = |
| emitc_builders::contentsOf(rewriter, loc, operand.value()); |
| auto derefOp = failContainerNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, vmType)), |
| /*callee=*/vmDerefCallee, |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refValue}, this->getModuleAnalysis()); |
| unwrappedOperands.push_back(derefOp.getResult(0)); |
| } else { |
| unwrappedOperands.push_back(operand.value()); |
| } |
| } |
| |
| SmallVector<Value> resultOperands; |
| if (failed(patchOperands(op, adaptor, rewriter, unwrappedOperands, |
| resultOperands))) { |
| return op.emitError() << "failed to patch operands"; |
| } |
| |
| if (failable) { |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/funcName, |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>(unwrappedOperands), |
| this->getModuleAnalysis()); |
| |
| emitc_builders::asRValues(rewriter, loc, resultOperands); |
| rewriter.replaceOp(op, resultOperands); |
| } else { |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/op, |
| /*type=*/op.getOperation()->getResultTypes(), |
| /*callee=*/funcName, |
| /*operands=*/ArrayRef<Value>(unwrappedOperands)); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult patchOperands(Operation *op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, |
| SmallVector<Value> &operands, |
| SmallVector<Value> &results) const { |
| if (failable) { |
| if (failed(createOutOperands(op, rewriter, operands, results))) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult createOutOperands(Operation *op, |
| ConversionPatternRewriter &rewriter, |
| SmallVector<Value> &operands, |
| SmallVector<Value> &results) const { |
| auto loc = op->getLoc(); |
| |
| for (OpResult result : op->getResults()) { |
| if (isa<IREE::VM::RefType>(result.getType())) { |
| Value ref = this->getModuleAnalysis().lookupRef(result); |
| |
| results.push_back(ref); |
| operands.push_back(ref); |
| } else { |
| Type type = result.getType(); |
| auto resultValue = emitc_builders::allocateVariable( |
| rewriter, loc, type, rewriter.getZeroAttr(type)); |
| Value resultPtr = emitc_builders::addressOf(rewriter, loc, resultValue); |
| |
| results.push_back(resultValue); |
| operands.push_back(resultPtr); |
| } |
| } |
| return success(); |
| } |
| |
| StringRef funcName; |
| |
| // The indices of the wrapped arguments. |
| DenseSet<size_t> refArgumentIndices; |
| |
| // Whether the function call can fail, i.e. it returns an iree_status_t. |
| bool failable; |
| }; |
| |
| template <typename OpTy> |
| class ContainerAllocOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| // Bundle function and type names based on the container type |
| struct CNames { |
| std::string type; |
| std::string typeId; |
| std::string constructor; |
| }; |
| |
| LogicalResult |
| matchAndRewrite(OpTy op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| Type objectType = cast<IREE::VM::RefType>(op.getType()).getObjectType(); |
| std::optional<Type> elementType = extractElementType(ctx, objectType); |
| std::optional<CNames> cNames = extractCNames(op); |
| |
| if (!elementType.has_value() || !cNames.has_value()) { |
| return op.emitError() << "unknown container type"; |
| } |
| |
| auto container = emitc_builders::allocateVariable( |
| rewriter, loc, |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, cNames.value().type)), |
| {"NULL"}); |
| |
| Value containerPtr = emitc_builders::addressOf(rewriter, loc, container); |
| |
| auto funcOp = |
| op.getOperation()->template getParentOfType<mlir::emitc::FuncOp>(); |
| |
| const BlockArgument stateArg = |
| funcOp.getArgument(CCONV_ARGUMENT_MODULE_STATE); |
| auto stateArgLValue = emitc_builders::asLValue(rewriter, loc, stateArg); |
| |
| auto allocator = emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*memberName=*/"allocator", |
| /*operand=*/stateArgLValue); |
| |
| std::optional<SmallVector<Value>> operands = getOperands( |
| op, adaptor, rewriter, elementType.value(), containerPtr, allocator); |
| |
| if (!operands.has_value()) { |
| return op.emitError() << "failed to build operands"; |
| } |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/cNames.value().constructor, |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{operands.value()}, |
| this->getModuleAnalysis()); |
| |
| Value ref = this->getModuleAnalysis().lookupRef(op.getResult()); |
| |
| Value refType = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/cNames.value().typeId, |
| /*operands=*/ArrayRef<Value>{}) |
| .getResult(0); |
| |
| auto containerRValue = emitc_builders::asRValue(rewriter, loc, container); |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_ref_wrap_assign", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{containerRValue, refType, ref}, |
| this->getModuleAnalysis()); |
| |
| rewriter.replaceOp(op, ref); |
| |
| return success(); |
| } |
| |
| std::optional<Type> extractElementType(MLIRContext *ctx, Type t) const { |
| if (auto listType = dyn_cast<IREE::VM::ListType>(t)) { |
| return listType.getElementType(); |
| } else if (auto bufferType = dyn_cast<IREE::VM::BufferType>(t)) { |
| return NoneType::get(ctx); |
| } |
| return std::nullopt; |
| } |
| |
| std::optional<CNames> extractCNames(OpTy op) const { |
| if (isa<IREE::VM::ListAllocOp>(op)) { |
| return CNames{"iree_vm_list_t", "iree_vm_list_type", |
| "iree_vm_list_create"}; |
| } else if (isa<IREE::VM::BufferAllocOp>(op)) { |
| return CNames{"iree_vm_buffer_t", "iree_vm_buffer_type", |
| "iree_vm_buffer_create"}; |
| } else if (isa<IREE::VM::BufferCloneOp>(op)) { |
| return CNames{"iree_vm_buffer_t", "iree_vm_buffer_type", |
| "iree_vm_buffer_clone"}; |
| } |
| return std::nullopt; |
| } |
| |
| std::optional<SmallVector<Value>> |
| getOperands(IREE::VM::ListAllocOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, Type elementType, |
| Value containerPtr, Value allocator) const { |
| |
| auto moduleOp = op.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| auto parentFuncOp = |
| op.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| const BlockArgument moduleArg = |
| parentFuncOp.getArgument(CCONV_ARGUMENT_MODULE); |
| auto elementTypePtr = |
| createVmTypeDefPtr(rewriter, op.getLoc(), this->getModuleAnalysis(), |
| moduleOp, moduleArg, elementType); |
| |
| if (!elementTypePtr.has_value()) { |
| return std::nullopt; |
| } |
| |
| Value capacity = adaptor.getOperands()[0]; |
| SmallVector<Value> result = {elementTypePtr.value(), capacity, allocator, |
| containerPtr}; |
| return result; |
| } |
| |
| std::optional<SmallVector<Value>> |
| getOperands(IREE::VM::BufferAllocOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, Type elementType, |
| Value containerPtr, Value allocator) const { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| Value access = |
| emitc::ConstantOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*resultType=*/ |
| emitc::OpaqueType::get(ctx, "iree_vm_buffer_access_t"), |
| /*value=*/ |
| emitc::OpaqueAttr::get(ctx, "IREE_VM_BUFFER_ACCESS_MUTABLE | " |
| "IREE_VM_BUFFER_ACCESS_ORIGIN_GUEST")) |
| .getResult(); |
| Value length = adaptor.getOperands()[0]; |
| Value alignment = adaptor.getOperands()[1]; |
| |
| SmallVector<Value> result = {access, length, alignment, allocator, |
| containerPtr}; |
| return result; |
| } |
| |
| std::optional<SmallVector<Value>> |
| getOperands(IREE::VM::BufferCloneOp op, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter, Type elementType, |
| Value containerPtr, Value allocator) const { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| Value access = |
| emitc::ConstantOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*resultType=*/ |
| emitc::OpaqueType::get(ctx, "iree_vm_buffer_access_t"), |
| /*value=*/ |
| emitc::OpaqueAttr::get(ctx, "IREE_VM_BUFFER_ACCESS_MUTABLE | " |
| "IREE_VM_BUFFER_ACCESS_ORIGIN_GUEST")) |
| .getResult(); |
| |
| Value refPtr = adaptor.getOperands()[0]; |
| Value refValue = emitc_builders::contentsOf(rewriter, loc, refPtr); |
| |
| Value source = |
| failContainerNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get( |
| emitc::OpaqueType::get(ctx, "iree_vm_buffer_t")), |
| /*callee=*/"iree_vm_buffer_deref", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refValue}, this->getModuleAnalysis()) |
| .getResult(0); |
| |
| Value offset = adaptor.getOperands()[1]; |
| Value length = adaptor.getOperands()[2]; |
| Value alignment = adaptor.getOperands()[3]; |
| |
| SmallVector<Value> result = {access, source, offset, length, |
| alignment, allocator, containerPtr}; |
| return result; |
| } |
| }; |
| |
| template <typename OpTy> |
| class ListGetOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy getOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = getOp.getContext(); |
| auto loc = getOp.getLoc(); |
| |
| std::optional<StringRef> valueTypeEnum; |
| std::optional<StringRef> valueExtractor; |
| |
| std::tie(valueTypeEnum, valueExtractor) = |
| TypeSwitch<Operation *, std::pair<std::optional<StringRef>, |
| std::optional<StringRef>>>( |
| getOp.getOperation()) |
| .Case<IREE::VM::ListGetI32Op>([&](auto op) { |
| return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I32"), |
| StringRef("iree_vm_value_get_i32")); |
| }) |
| .template Case<IREE::VM::ListGetI64Op>([&](auto op) { |
| return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I64"), |
| StringRef("iree_vm_value_get_i64")); |
| }) |
| .Default([](Operation *) { |
| return std::make_pair(std::nullopt, std::nullopt); |
| }); |
| |
| if (!valueTypeEnum.has_value() || !valueExtractor.has_value()) { |
| return getOp.emitOpError() << "element type not handled"; |
| } |
| |
| auto value = emitc_builders::allocateVariable( |
| rewriter, loc, emitc::OpaqueType::get(ctx, "iree_vm_value_t")); |
| |
| Value valuePtr = emitc_builders::addressOf(rewriter, loc, value); |
| |
| Value refValue = |
| emitc_builders::contentsOf(rewriter, loc, adaptor.getOperands()[0]); |
| |
| auto listDerefOp = failContainerNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_list_t")), |
| /*callee=*/"iree_vm_list_deref", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refValue}, this->getModuleAnalysis()); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_list_get_value_as", |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), |
| emitc::OpaqueAttr::get(ctx, valueTypeEnum.value()), |
| rewriter.getIndexAttr(2)}), |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), getOp.getIndex(), valuePtr}, |
| this->getModuleAnalysis()); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
| /*op=*/getOp, |
| /*type=*/getOp.getType(), |
| /*callee=*/valueExtractor.value(), |
| /*operands=*/ArrayRef<Value>{valuePtr}); |
| |
| return success(); |
| } |
| }; |
| |
| class ListGetRefOpConversion |
| : public EmitCConversionPattern<IREE::VM::ListGetRefOp> { |
| using Adaptor = IREE::VM::ListGetRefOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::ListGetRefOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::ListGetRefOp getOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = getOp.getContext(); |
| auto loc = getOp.getLoc(); |
| |
| Value listRefValue = |
| emitc_builders::contentsOf(rewriter, loc, adaptor.getOperands()[0]); |
| |
| auto listDerefOp = failContainerNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_list_t")), |
| /*callee=*/"iree_vm_list_deref", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{listRefValue}, getModuleAnalysis()); |
| |
| Value ref = getModuleAnalysis().lookupRef(getOp.getResult()); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_list_get_ref_retain", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), getOp.getIndex(), ref}, |
| getModuleAnalysis()); |
| |
| auto moduleOp = getOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| auto parentFuncOp = |
| getOp.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| const BlockArgument moduleArg = |
| parentFuncOp.getArgument(CCONV_ARGUMENT_MODULE); |
| |
| Type elementType = getOp.getResult().getType(); |
| |
| auto elementTypePtr = |
| createVmTypeDefPtr(rewriter, getOp.getLoc(), getModuleAnalysis(), |
| moduleOp, moduleArg, elementType); |
| |
| if (!elementTypePtr.has_value()) { |
| return getOp.emitError() << "generating iree_vm_type_def_t* failed"; |
| } |
| |
| // Build the following expression: |
| // (ref->type != IREE_VM_REF_TYPE_NULL && |
| // (iree_vm_type_def_is_value(type_def) || ref->type != |
| // iree_vm_type_def_as_ref(type_def))) |
| Value invalidType; |
| { |
| auto refLValue = emitc_builders::asLValue(rewriter, loc, ref); |
| // ref->type |
| auto refType = emitc_builders::structPtrMember( |
| rewriter, loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*memberName=*/"type", |
| /*operand=*/refLValue); |
| |
| // IREE_VM_REF_TYPE_NULL |
| auto refTypeNull = |
| emitc::ConstantOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*resultType=*/ |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*value=*/ |
| emitc::OpaqueAttr::get(ctx, "IREE_VM_REF_TYPE_NULL")) |
| .getResult(); |
| |
| // ref->type != IREE_VM_REF_TYPE_NULL |
| auto refTypeIsNotNull = |
| emitc::CmpOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*predicate*/ emitc::CmpPredicate::ne, |
| /*lhs*/ refType, |
| /*rhs*/ refTypeNull) |
| .getResult(); |
| |
| // (iree_vm_type_def_is_value(type_def) |
| auto typedefIsValue = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*callee=*/"iree_vm_type_def_is_value", |
| /*operands=*/ArrayRef<Value>{elementTypePtr.value()}) |
| .getResult(0); |
| |
| // iree_vm_type_def_as_ref(type_def) |
| auto typedefAsRef = |
| emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/"iree_vm_type_def_as_ref", |
| /*operands=*/ArrayRef<Value>{elementTypePtr.value()}) |
| .getResult(0); |
| |
| // ref->type != iree_vm_type_def_as_ref(type_def) |
| auto refTypesDontMatch = |
| emitc::CmpOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*predicate*/ emitc::CmpPredicate::ne, |
| /*lhs*/ refType, |
| /*rhs*/ typedefAsRef) |
| .getResult(); |
| |
| auto invalidRefType = |
| emitc::LogicalOrOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*lhs*/ typedefIsValue, |
| /*rhs*/ refTypesDontMatch) |
| .getResult(); |
| |
| invalidType = emitc::LogicalAndOp::create(rewriter, |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*lhs*/ refTypeIsNotNull, |
| /*rhs*/ invalidRefType) |
| .getResult(); |
| } |
| |
| // Start by splitting the block into two. The part before will contain |
| // the condition, and the part after will contain the continuation |
| // point. |
| Block *condBlock = rewriter.getInsertionBlock(); |
| Block::iterator opPosition = rewriter.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = rewriter.createBlock(parentRegion, parentRegion->end()); |
| |
| emitc_builders::ireeVmRefRelease(rewriter, loc, ref); |
| |
| mlir::cf::BranchOp::create(rewriter, loc, continuationBlock); |
| } |
| |
| rewriter.setInsertionPointToEnd(condBlock); |
| cf::CondBranchOp::create(rewriter, loc, invalidType, failureBlock, |
| continuationBlock); |
| |
| rewriter.replaceOp(getOp, ref); |
| |
| return success(); |
| } |
| }; |
| |
| template <typename OpTy> |
| class ListSetOpConversion : public EmitCConversionPattern<OpTy> { |
| using Adaptor = typename OpTy::Adaptor; |
| using EmitCConversionPattern<OpTy>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(OpTy setOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = setOp.getContext(); |
| auto loc = setOp.getLoc(); |
| |
| std::optional<StringRef> valueConstructor = |
| TypeSwitch<Operation *, std::optional<StringRef>>(setOp.getOperation()) |
| .Case<IREE::VM::ListSetI32Op>( |
| [&](auto op) { return StringRef("iree_vm_value_make_i32"); }) |
| .template Case<IREE::VM::ListSetI64Op>( |
| [&](auto op) { return StringRef("iree_vm_value_make_i64"); }) |
| .Default([](Operation *) { return std::nullopt; }); |
| |
| if (!valueConstructor.has_value()) { |
| return setOp.emitOpError() << " not handled"; |
| } |
| |
| auto valueOp = emitc::CallOpaqueOp::create( |
| rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"), |
| /*callee=*/valueConstructor.value(), |
| /*operands=*/ArrayRef<Value>{setOp.getValue()}); |
| |
| auto value = emitc_builders::asLValue(rewriter, loc, valueOp.getResult(0)); |
| |
| Value valuePtr = emitc_builders::addressOf(rewriter, loc, value); |
| |
| Value refValue = |
| emitc_builders::contentsOf(rewriter, loc, adaptor.getOperands()[0]); |
| |
| auto listDerefOp = failContainerNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_list_t")), |
| /*callee=*/"iree_vm_list_deref", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refValue}, this->getModuleAnalysis()); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/"iree_vm_list_set_value", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), setOp.getIndex(), valuePtr}, |
| this->getModuleAnalysis()); |
| |
| rewriter.eraseOp(setOp); |
| |
| return success(); |
| } |
| }; |
| |
| class ListSetRefOpConversion |
| : public EmitCConversionPattern<IREE::VM::ListSetRefOp> { |
| using Adaptor = IREE::VM::ListSetRefOp::Adaptor; |
| using EmitCConversionPattern<IREE::VM::ListSetRefOp>::EmitCConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(IREE::VM::ListSetRefOp setOp, Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = setOp.getContext(); |
| auto loc = setOp.getLoc(); |
| |
| Value refValue = |
| emitc_builders::contentsOf(rewriter, loc, adaptor.getOperands()[0]); |
| |
| auto listDerefOp = failContainerNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/ |
| emitc::PointerType::get(emitc::OpaqueType::get(ctx, "iree_vm_list_t")), |
| /*callee=*/"iree_vm_list_deref", |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refValue}, getModuleAnalysis()); |
| |
| auto funcOp = setOp.getOperation()->getParentOfType<mlir::emitc::FuncOp>(); |
| |
| auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp); |
| |
| bool move = funcAnalysis.isMove(setOp.getValue(), setOp.getOperation()); |
| |
| StringRef callee = |
| move ? "iree_vm_list_set_ref_move" : "iree_vm_list_set_ref_retain"; |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/callee, |
| /*args=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), setOp.getIndex(), |
| adaptor.getValue()}, |
| getModuleAnalysis()); |
| |
| rewriter.eraseOp(setOp); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void populateVMToEmitCPatterns(ConversionTarget &conversionTarget, |
| IREE::VM::EmitCTypeConverter &typeConverter, |
| RewritePatternSet &patterns) { |
| auto context = patterns.getContext(); |
| populateUtilConversionPatterns(context, conversionTarget, typeConverter, |
| patterns); |
| |
| // Patterns |
| #define ADD_GENERIC_PATTERN(Op, FuncName) \ |
| patterns.add<GenericOpConversion<Op>>(typeConverter, context, FuncName) |
| #define ADD_CONTAINER_PATTERN(Op, FuncName, IndexSet, Failable) \ |
| patterns.add<ContainerOpConversion<Op>>(typeConverter, context, FuncName, \ |
| IndexSet, Failable); |
| #define ADD_GLOBAL_LOAD_PATTERN(Op, GlobalOp, FuncName) \ |
| patterns.add<GlobalLoadOpConversion<Op, GlobalOp>>(typeConverter, context, \ |
| FuncName); |
| #define ADD_GLOBAL_STORE_PATTERN(Op, GlobalOp, FuncName) \ |
| patterns.add<GlobalStoreOpConversion<Op, GlobalOp>>(typeConverter, context, \ |
| FuncName); |
| |
| // argument free patterns |
| // clang-format off |
| patterns.add< |
| BranchOpConversion, |
| CallOpConversion<IREE::VM::CallOp>, |
| CallOpConversion<IREE::VM::CallVariadicOp>, |
| CompareRefNotZeroOpConversion, |
| SelectRefOpConversion, |
| CondBranchOpConversion, |
| BranchTableOpConversion, |
| ConstOpConversion<IREE::VM::ConstF32Op>, |
| ConstOpConversion<IREE::VM::ConstF64Op>, |
| ConstOpConversion<IREE::VM::ConstI32Op>, |
| ConstOpConversion<IREE::VM::ConstI64Op>, |
| ConstRefRodataOpConversion, |
| ConstRefZeroOpConversion, |
| ConstZeroOpConversion<IREE::VM::ConstF32ZeroOp>, |
| ConstZeroOpConversion<IREE::VM::ConstF64ZeroOp>, |
| ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>, |
| ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>, |
| ContainerAllocOpConversion<IREE::VM::BufferAllocOp>, |
| ContainerAllocOpConversion<IREE::VM::BufferCloneOp>, |
| ContainerAllocOpConversion<IREE::VM::ListAllocOp>, |
| DeleteOpConversion<IREE::VM::GlobalF32Op>, |
| DeleteOpConversion<IREE::VM::GlobalI32Op>, |
| DeleteOpConversion<IREE::VM::GlobalI64Op>, |
| DeleteOpConversion<IREE::VM::GlobalRefOp>, |
| ExportOpConversion, |
| FailOpConversion, |
| FuncOpConversion, |
| GlobalLoadStoreRefOpConversion<IREE::VM::GlobalLoadRefOp>, |
| GlobalLoadStoreRefOpConversion<IREE::VM::GlobalStoreRefOp>, |
| ImportResolvedOpConversion, |
| ListGetOpConversion<IREE::VM::ListGetI32Op>, |
| ListGetOpConversion<IREE::VM::ListGetI64Op>, |
| ListGetRefOpConversion, |
| ListSetOpConversion<IREE::VM::ListSetI32Op>, |
| ListSetOpConversion<IREE::VM::ListSetI64Op>, |
| ListSetRefOpConversion, |
| ReturnOpConversion |
| >(typeConverter, context); |
| // clang-format on |
| |
| // generic conversions |
| ADD_GENERIC_PATTERN(IREE::VM::AbsF32Op, "vm_abs_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::AbsI32Op, "vm_abs_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::AbsI64Op, "vm_abs_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::AddF32Op, "vm_add_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::AddI32Op, "vm_add_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::AddI64Op, "vm_add_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::AndI32Op, "vm_and_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::AndI64Op, "vm_and_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::Atan2F32Op, "vm_atan2_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::AtanF32Op, "vm_atan_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::BitcastF32I32Op, "vm_bitcast_f32i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::BitcastI32F32Op, "vm_bitcast_i32f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastF32SI32Op, "vm_cast_f32si32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastF32SI64Op, "vm_cast_f32si64"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastF32UI32Op, "vm_cast_f32ui32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastSI64F32Op, "vm_cast_si64f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CastUI64F32Op, "vm_cast_ui64f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpEQF32OOp, "vm_cmp_eq_f32o"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpEQF32UOp, "vm_cmp_eq_f32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpEQI32Op, "vm_cmp_eq_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpEQI64Op, "vm_cmp_eq_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpEQRefOp, "vm_cmp_eq_ref"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTEF32OOp, "vm_cmp_lte_f32o"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTEF32UOp, "vm_cmp_lte_f32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTF32OOp, "vm_cmp_lt_f32o"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTF32UOp, "vm_cmp_lt_f32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTI32SOp, "vm_cmp_lt_i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTI32UOp, "vm_cmp_lt_i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTI64SOp, "vm_cmp_lt_i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpLTI64UOp, "vm_cmp_lt_i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNaNF32Op, "vm_cmp_nan_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNEF32OOp, "vm_cmp_ne_f32o"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNEF32UOp, "vm_cmp_ne_f32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNEI32Op, "vm_cmp_ne_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNEI64Op, "vm_cmp_ne_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNERefOp, "vm_cmp_ne_ref"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNZI32Op, "vm_cmp_nz_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CmpNZI64Op, "vm_cmp_nz_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::CosF32Op, "vm_cos_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CtlzI32Op, "vm_ctlz_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::CtlzI64Op, "vm_ctlz_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::DivF32Op, "vm_div_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::DivI32SOp, "vm_div_i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::DivI32UOp, "vm_div_i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::DivI64SOp, "vm_div_i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::DivI64UOp, "vm_div_i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::ErfF32Op, "vm_erf_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::Exp2F32Op, "vm_exp2_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExpF32Op, "vm_exp_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExpM1F32Op, "vm_expm1_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExtI8I32SOp, "vm_ext_i8i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExtI8I32UOp, "vm_ext_i8i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExtI16I32SOp, "vm_ext_i16i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExtI16I32UOp, "vm_ext_i16i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExtI32I64SOp, "vm_ext_i32i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::ExtI32I64UOp, "vm_ext_i32i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::FloorF32Op, "vm_floor_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::FMAF32Op, "vm_fma_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::FMAI32Op, "vm_fma_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::FMAI64Op, "vm_fma_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::Log1pF32Op, "vm_log1p_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::Log2F32Op, "vm_log2_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::Log10F32Op, "vm_log10_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::LogF32Op, "vm_log_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::MaxF32Op, "vm_max_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::MaxI32SOp, "vm_max_i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::MaxI32UOp, "vm_max_i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::MaxI64SOp, "vm_max_i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::MaxI64UOp, "vm_max_i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::MinF32Op, "vm_min_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::MinI32SOp, "vm_min_i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::MinI32UOp, "vm_min_i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::MinI64SOp, "vm_min_i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::MinI64UOp, "vm_min_i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::MulF32Op, "vm_mul_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::MulI32Op, "vm_mul_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::MulI64Op, "vm_mul_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::NegF32Op, "vm_neg_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::NotI32Op, "vm_not_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::NotI64Op, "vm_not_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::OrI32Op, "vm_or_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::OrI64Op, "vm_or_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::PowF32Op, "vm_pow_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::RemF32Op, "vm_rem_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::RemI32SOp, "vm_rem_i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::RemI32UOp, "vm_rem_i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::RemI64SOp, "vm_rem_i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::RemI64UOp, "vm_rem_i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::RoundF32EvenOp, "vm_round_f32_even"); |
| ADD_GENERIC_PATTERN(IREE::VM::RoundF32Op, "vm_round_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::RsqrtF32Op, "vm_rsqrt_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SelectF32Op, "vm_select_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SelectI32Op, "vm_select_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SelectI64Op, "vm_select_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::ShlI32Op, "vm_shl_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::ShlI64Op, "vm_shl_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::ShrI32SOp, "vm_shr_i32s"); |
| ADD_GENERIC_PATTERN(IREE::VM::ShrI32UOp, "vm_shr_i32u"); |
| ADD_GENERIC_PATTERN(IREE::VM::ShrI64SOp, "vm_shr_i64s"); |
| ADD_GENERIC_PATTERN(IREE::VM::ShrI64UOp, "vm_shr_i64u"); |
| ADD_GENERIC_PATTERN(IREE::VM::SinF32Op, "vm_sin_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SqrtF32Op, "vm_sqrt_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SubF32Op, "vm_sub_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SubI32Op, "vm_sub_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::SubI64Op, "vm_sub_i64"); |
| ADD_GENERIC_PATTERN(IREE::VM::TanhF32Op, "vm_tanh_f32"); |
| ADD_GENERIC_PATTERN(IREE::VM::TruncI32I8Op, "vm_trunc_i32i8"); |
| ADD_GENERIC_PATTERN(IREE::VM::TruncI32I16Op, "vm_trunc_i32i16"); |
| ADD_GENERIC_PATTERN(IREE::VM::TruncI64I32Op, "vm_trunc_i64i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::XorI32Op, "vm_xor_i32"); |
| ADD_GENERIC_PATTERN(IREE::VM::XorI64Op, "vm_xor_i64"); |
| |
| // containers wrapped in ref types |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferCompareOp, "vm_buffer_compare", |
| DenseSet<size_t>({0, 2}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferCopyOp, "iree_vm_buffer_copy_bytes", |
| DenseSet<size_t>({0, 2}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferFillF32Op, "vm_buffer_fill_f32", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferFillF64Op, "vm_buffer_fill_f64", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferFillI8Op, "vm_buffer_fill_i8", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferFillI16Op, "vm_buffer_fill_i16", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferFillI32Op, "vm_buffer_fill_i32", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferFillI64Op, "vm_buffer_fill_i64", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLengthOp, "iree_vm_buffer_length", |
| DenseSet<size_t>({0}), false); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadF32Op, "vm_buffer_load_f32", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadF64Op, "vm_buffer_load_f64", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadI8SOp, "vm_buffer_load_i8s", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadI8UOp, "vm_buffer_load_i8u", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadI16SOp, "vm_buffer_load_i16s", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadI16UOp, "vm_buffer_load_i16u", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadI32Op, "vm_buffer_load_i32", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferLoadI64Op, "vm_buffer_load_i64", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferStoreF32Op, "vm_buffer_store_f32", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferStoreF64Op, "vm_buffer_store_f64", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferStoreI8Op, "vm_buffer_store_i8", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferStoreI16Op, "vm_buffer_store_i16", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferStoreI32Op, "vm_buffer_store_i32", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferStoreI64Op, "vm_buffer_store_i64", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::BufferHashOp, "iree_vm_buffer_hash", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::ListReserveOp, "iree_vm_list_reserve", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::ListResizeOp, "iree_vm_list_resize", |
| DenseSet<size_t>({0}), true); |
| ADD_CONTAINER_PATTERN(IREE::VM::ListSizeOp, "iree_vm_list_size", |
| DenseSet<size_t>({0}), false); |
| // global patterns |
| ADD_GLOBAL_LOAD_PATTERN(IREE::VM::GlobalLoadF32Op, IREE::VM::GlobalF32Op, |
| "vm_global_load_f32"); |
| ADD_GLOBAL_LOAD_PATTERN(IREE::VM::GlobalLoadI32Op, IREE::VM::GlobalI32Op, |
| "vm_global_load_i32"); |
| ADD_GLOBAL_LOAD_PATTERN(IREE::VM::GlobalLoadI64Op, IREE::VM::GlobalI64Op, |
| "vm_global_load_i64"); |
| ADD_GLOBAL_STORE_PATTERN(IREE::VM::GlobalStoreF32Op, IREE::VM::GlobalF32Op, |
| "vm_global_store_f32"); |
| ADD_GLOBAL_STORE_PATTERN(IREE::VM::GlobalStoreI32Op, IREE::VM::GlobalI32Op, |
| "vm_global_store_i32"); |
| ADD_GLOBAL_STORE_PATTERN(IREE::VM::GlobalStoreI64Op, IREE::VM::GlobalI64Op, |
| "vm_global_store_i64"); |
| |
| #undef ADD_GENERIC_PATTERN |
| #undef ADD_CONTAINER_PATTERN |
| #undef ADD_GLOBAL_LOAD_PATTERN |
| #undef ADD_GLOBAL_STORE_PATTERN |
| } |
| |
| namespace IREE::VM { |
| |
| namespace { |
| |
| // A pass converting IREE VM operations into the EmitC dialect. |
| // vm.func ops get converted to emitc.func with the calling convention used by |
| // EmitC. Each function gets three additional arguments a `iree_vm_stack_t*` as |
| // well as two module specific struct pointers (`{module_name}_t*` and |
| // `{module_name}_state_t`). These are followed by the original function |
| // arguments and out arguments for the vm.func results. The result type of the |
| // function is `iree_status_t`. Ref types are always passed as pointers. |
| // |
| // Examples: |
| // () -> () => (iree_vm_stack_t*, module_t*, module_state_t*) -> iree_status_t |
| // |
| // (i) -> () => (iree_vm_stack_t*, module_t*, module_state_t*, int32_t) -> |
| // iree_status_t |
| // |
| // (r) -> () => (iree_vm_stack_t*, module_t*, module_state_t*, iree_vm_ref_t*) |
| // -> iree_status_t |
| // |
| // () -> (r) => (iree_vm_stack_t*, module_t*, module_state_t*, iree_vm_ref_t*) |
| // -> iree_status_t |
| // |
| // (iir) -> (ri) => (iree_vm_stack_t*, module_t*, module_state_t*, int32_t, |
| // int32_t, iree_vm_ref_t*, iree_vm_ref_t*, int32_t*) -> |
| // iree_status_t |
| class ConvertVMToEmitCPass |
| : public PassWrapper<ConvertVMToEmitCPass, |
| OperationPass<IREE::VM::ModuleOp>> { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertVMToEmitCPass) |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<mlir::BuiltinDialect, mlir::cf::ControlFlowDialect, |
| mlir::emitc::EmitCDialect, IREE::Util::UtilDialect>(); |
| } |
| |
| StringRef getArgument() const override { return "iree-convert-vm-to-emitc"; } |
| |
| StringRef getDescription() const override { |
| return "Convert VM Ops to the EmitC dialect"; |
| } |
| |
| void runOnOperation() override { |
| IREE::VM::ModuleOp moduleOp = getOperation(); |
| |
| ConversionTarget target(getContext()); |
| EmitCTypeConverter typeConverter(moduleOp); |
| |
| // Convert vm.func ops to emitc.func with the calling convention used by |
| // EmitC. We convert these upfront to make sure vm.call ops always |
| // reference emitc.func ops with the correct calling convention during the |
| // conversion. |
| SmallVector<IREE::VM::FuncOp> funcsToRemove; |
| for (auto funcOp : moduleOp.getOps<IREE::VM::FuncOp>()) { |
| if (failed(convertFuncOp(funcOp, typeConverter))) { |
| return signalPassFailure(); |
| } |
| funcsToRemove.push_back(funcOp); |
| } |
| |
| for (auto &funcOp : funcsToRemove) { |
| funcOp.erase(); |
| } |
| |
| SmallVector<std::string> importShims; |
| |
| // The conversion of `call/call.variadic` ops on imported functions expects |
| // import ops to be rewritten to compiler generated shim functions. To |
| // ensure this we only rewrite `import` ops first. |
| ImportOpConverter importOpConverter(typeConverter, importShims); |
| for (auto importOp : moduleOp.getOps<IREE::VM::ImportOp>()) { |
| if (failed(importOpConverter(importOp))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| RewritePatternSet patterns(&getContext()); |
| populateVMToEmitCPatterns(target, typeConverter, patterns); |
| |
| target.addLegalDialect<emitc::EmitCDialect, mlir::BuiltinDialect, |
| mlir::cf::ControlFlowDialect>(); |
| |
| target.addDynamicallyLegalOp<mlir::emitc::FuncOp>( |
| [&](mlir::emitc::FuncOp op) { |
| return typeConverter.isSignatureLegal(op.getFunctionType()) && |
| typeConverter.isLegal(&op.getFunctionBody()); |
| }); |
| |
| // Structural ops |
| target.addLegalOp<IREE::VM::ModuleOp>(); |
| target.addLegalOp<IREE::VM::ModuleTerminatorOp>(); |
| target.addLegalOp<IREE::VM::ImportOp>(); |
| |
| // This op is needed in the printer to emit an array holding the data. |
| target.addLegalOp<IREE::VM::RodataOp>(); |
| |
| if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| |
| if (failed(createModuleStructure(moduleOp, typeConverter))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> |
| createConvertVMToEmitCPass() { |
| return std::make_unique<ConvertVMToEmitCPass>(); |
| } |
| |
| } // namespace IREE::VM |
| |
| static PassRegistration<IREE::VM::ConvertVMToEmitCPass> pass; |
| |
| } // namespace mlir::iree_compiler |