| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h" |
| |
| #include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "iree/compiler/Dialect/VM/Utils/CallingConvention.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Dialect/EmitC/IR/EmitC.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinDialect.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| // TODO(simon-camp/marbre): Use this function throughout the conversions. |
| Optional<std::string> getCType(Type type) { |
| if (auto iType = type.dyn_cast<IntegerType>()) { |
| switch (iType.getWidth()) { |
| case 32: |
| case 64: |
| return std::string("int") + std::to_string(iType.getWidth()) + |
| std::string("_t"); |
| } |
| } |
| |
| if (auto fType = type.dyn_cast<FloatType>()) { |
| switch (fType.getWidth()) { |
| case 32: |
| return std::string("float"); |
| case 64: |
| return std::string("double"); |
| } |
| } |
| |
| if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { |
| return std::string(oType.getValue()); |
| } |
| |
| if (type.isa<IREE::VM::RefType>()) { |
| return std::string("iree_vm_ref_t*"); |
| } |
| |
| return None; |
| } |
| |
| LogicalResult convertFuncOp(IREE::VM::FuncOp funcOp, |
| VMAnalysisCache &vmAnalysisCache) { |
| auto ctx = funcOp.getContext(); |
| auto loc = funcOp.getLoc(); |
| |
| OpBuilder builder(funcOp); |
| |
| auto moduleOp = funcOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| FunctionType funcType = funcOp.getType(); |
| std::string name = |
| std::string(moduleOp.getName()) + "_" + std::string(funcOp.getName()); |
| std::string moduleTypeName = (moduleOp.getName() + "_t*").str(); |
| std::string moduleStateTypeName = (moduleOp.getName() + "_state_t*").str(); |
| |
| Type stackType = emitc::OpaqueType::get(ctx, "iree_vm_stack_t*"); |
| Type moduleType = emitc::OpaqueType::get(ctx, moduleTypeName); |
| Type moduleStateType = emitc::OpaqueType::get(ctx, moduleStateTypeName); |
| |
| SmallVector<Type, 3> inputTypes = {stackType, moduleType, moduleStateType}; |
| SmallVector<Type, 1> outputTypes; |
| |
| for (auto &inputType : funcType.getInputs()) { |
| inputTypes.push_back(inputType); |
| } |
| |
| for (auto &resultType : funcType.getResults()) { |
| Optional<std::string> cType = getCType(resultType); |
| if (!cType.hasValue()) { |
| return funcOp.emitError() << "unable to emit C type"; |
| } |
| std::string cPtrType; |
| // We pass refs as iree_vm_ref_t* regardless of whether it is an in or out |
| // parameter |
| if (resultType.isa<IREE::VM::RefType>()) { |
| cPtrType = cType.getValue(); |
| } else { |
| cPtrType = cType.getValue() + std::string("*"); |
| } |
| Type type = emitc::OpaqueType::get(ctx, cPtrType); |
| inputTypes.push_back(type); |
| outputTypes.push_back(type); |
| } |
| |
| auto newFuncType = mlir::FunctionType::get( |
| ctx, {inputTypes}, {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto newFuncOp = builder.create<mlir::FuncOp>(loc, name, newFuncType); |
| newFuncOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx)); |
| |
| Optional<std::string> callingConvention = makeCallingConventionString(funcOp); |
| |
| // Annotate new function with calling convention string which gets used in |
| // the CModuleTarget. |
| newFuncOp.getOperation()->setAttr( |
| "vm.calling_convention", |
| StringAttr::get(ctx, callingConvention.getValue())); |
| |
| // This call shold be equivalent to rewriter.inlineRegionBefore() |
| newFuncOp.getBody().getBlocks().splice(newFuncOp.end(), |
| funcOp.getBody().getBlocks()); |
| |
| Block &entryBlock = newFuncOp.getBlocks().front(); |
| |
| entryBlock.insertArgument(static_cast<unsigned>(0), stackType); |
| entryBlock.insertArgument(static_cast<unsigned>(1), moduleType); |
| entryBlock.insertArgument(static_cast<unsigned>(2), moduleStateType); |
| |
| entryBlock.addArguments(outputTypes); |
| |
| auto ptr = vmAnalysisCache.find(funcOp.getOperation()); |
| if (ptr == vmAnalysisCache.end()) { |
| return funcOp.emitError() << "parent func op not found in cache."; |
| } |
| |
| // Add constant ops for local refs |
| const int numRefArgs = llvm::count_if(inputTypes, [](Type inputType) { |
| return inputType.isa<IREE::VM::RefType>(); |
| }); |
| const int numRefs = ptr->second.getNumRefRegisters() - numRefArgs; |
| |
| builder.setInsertionPointToStart(&entryBlock); |
| |
| for (int i = 0; i < numRefs; i++) { |
| auto refOp = builder.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "")); |
| |
| // Mark local refs so that we can release them before a return operation. |
| // Here we rely on the fact that the register allocation maps arguments in |
| // the first slots. |
| refOp.getOperation()->setAttr("ref_ordinal", |
| builder.getIndexAttr(i + numRefArgs)); |
| |
| auto refPtrOp = builder.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/refOp.getResult()); |
| |
| auto refSizeOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/builder.getI32Type(), |
| /*callee=*/StringAttr::get(ctx, "sizeof"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refOp.getResult()}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "memset"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {builder.getIndexAttr(0), builder.getUI32IntegerAttr(0), |
| builder.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{refPtrOp.getResult(), refSizeOp.getResult(0)}); |
| } |
| |
| vmAnalysisCache.insert( |
| std::make_pair(newFuncOp.getOperation(), std::move(ptr->second))); |
| |
| if (failed(funcOp.replaceAllSymbolUses(name, moduleOp))) |
| return funcOp.emitError() << "unable to update symbol name in module"; |
| |
| return success(); |
| } |
| |
| Optional<std::string> buildFunctionName(IREE::VM::ModuleOp &moduleOp, |
| IREE::VM::ImportOp &importOp) { |
| auto callingConvention = makeImportCallingConventionString(importOp); |
| if (!callingConvention.hasValue()) { |
| return None; |
| } |
| return std::string("call_") + callingConvention.getValue() + "_import"; |
| } |
| |
| template <typename SrcOpTy> |
| Optional<emitc::ApplyOp> createVmTypeDefPtr(ConversionPatternRewriter &rewriter, |
| SrcOpTy srcOp, Type elementType) { |
| auto ctx = srcOp.getContext(); |
| auto loc = srcOp.getLoc(); |
| |
| // TODO(simon-camp): Cleanup this up |
| StringRef elementTypeConstructor; |
| std::string elementTypeConstructorArg; |
| if (auto intType = elementType.template dyn_cast<IntegerType>()) { |
| unsigned int bitWidth = intType.getIntOrFloatBitWidth(); |
| elementTypeConstructor = "iree_vm_type_def_make_value_type"; |
| elementTypeConstructorArg = |
| std::string("IREE_VM_VALUE_TYPE_I") + std::to_string(bitWidth); |
| } else if (auto refType = |
| elementType.template dyn_cast<IREE::VM::RefType>()) { |
| auto objType = refType.getObjectType(); |
| |
| elementTypeConstructor = "iree_vm_type_def_make_ref_type"; |
| if (objType.template isa<IREE::VM::BufferType>()) { |
| elementTypeConstructorArg = "iree_vm_buffer_type_id()"; |
| } else if (objType.template isa<IREE::VM::ListType>()) { |
| elementTypeConstructorArg = "iree_vm_list_type_id()"; |
| } else { |
| srcOp.emitError() << "Unhandled ref object type " << objType; |
| return None; |
| } |
| } else if (auto opaqueType = |
| elementType.template dyn_cast<IREE::VM::OpaqueType>()) { |
| elementTypeConstructor = "iree_vm_type_def_make_variant_type"; |
| elementTypeConstructorArg = ""; |
| } else { |
| srcOp.emitError() << "Unhandled element type " << elementType; |
| return None; |
| } |
| |
| auto elementTypeOp = rewriter.template create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), |
| /*callee=*/StringAttr::get(ctx, elementTypeConstructor), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {emitc::OpaqueAttr::get(ctx, elementTypeConstructorArg)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto elementTypePtrOp = rewriter.template create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/elementTypeOp.getResult(0)); |
| |
| return elementTypePtrOp; |
| } |
| |
| Optional<Value> findRef(mlir::FuncOp &parentFuncOp, |
| VMAnalysisCache &vmAnalysisCache, Value refResult) { |
| assert(refResult.getType().isa<IREE::VM::RefType>()); |
| |
| auto ptr = vmAnalysisCache.find(parentFuncOp.getOperation()); |
| if (ptr == vmAnalysisCache.end()) { |
| parentFuncOp.emitError() << "parent func op not found in cache."; |
| return None; |
| } |
| |
| int32_t ordinal = ptr->second.getRefRegisterOrdinal(refResult); |
| |
| for (auto constantOp : parentFuncOp.getOps<emitc::ConstantOp>()) { |
| Operation *op = constantOp.getOperation(); |
| if (!op->hasAttr("ref_ordinal")) continue; |
| if (op->getAttr("ref_ordinal") |
| .cast<IntegerAttr>() |
| .getValue() |
| .getZExtValue() == ordinal) { |
| return constantOp.getResult(); |
| } |
| } |
| |
| return None; |
| } |
| |
| void releaseLocalRefs(OpBuilder &builder, Location location, |
| mlir::FuncOp funcOp) { |
| auto ctx = funcOp.getContext(); |
| |
| // Release local refs |
| for (auto constantOp : funcOp.getOps<emitc::ConstantOp>()) { |
| Operation *op = constantOp.getOperation(); |
| if (!op->hasAttr("ref_ordinal")) continue; |
| |
| auto refPtrOp = builder.create<emitc::ApplyOp>( |
| /*location=*/location, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/constantOp.getResult()); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/location, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refPtrOp.getResult()}); |
| } |
| } |
| |
| /// Generate an emitc.call op with one result and split the current block into a |
| /// continuation and failure block based on the truthiness of the result |
| /// value, i.e. a truthy value branches to the continuation block when |
| /// `negateCondition` is false. |
| emitc::CallOp failableCall( |
| OpBuilder &builder, Location location, Type type, StringAttr callee, |
| ArrayAttr args, ArrayAttr templateArgs, ArrayRef<Value> operands, |
| const std::function<void(emitc::CallOp &)> &failureBlockBuilder, |
| bool negateCondition = false) { |
| auto ctx = builder.getContext(); |
| |
| auto callOp = builder.create<emitc::CallOp>( |
| /*location=*/location, |
| /*type=*/type, |
| /*callee=*/callee, |
| /*args=*/args, |
| /*templateArgs=*/templateArgs, |
| /*operands=*/operands); |
| |
| Type boolType = builder.getIntegerType(1); |
| |
| auto conditionI1 = builder.create<emitc::CallOp>( |
| /*location=*/location, |
| /*type=*/boolType, |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), TypeAttr::get(boolType)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{callOp.getResult(0)}); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| failureBlockBuilder(callOp); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| auto branchOp = builder.create<CondBranchOp>( |
| location, conditionI1.getResult(0), |
| negateCondition ? failureBlock : continuationBlock, |
| negateCondition ? continuationBlock : failureBlock); |
| |
| builder.setInsertionPointToStart(continuationBlock); |
| |
| return callOp; |
| } |
| |
| emitc::CallOp returnIfError(OpBuilder &builder, Location location, |
| StringAttr callee, ArrayAttr args, |
| ArrayAttr templateArgs, ArrayRef<Value> operands) { |
| auto blockBuilder = [&builder, &location](emitc::CallOp &callOp) { |
| auto ctx = builder.getContext(); |
| |
| Block *block = builder.getBlock(); |
| mlir::FuncOp funcOp = cast<mlir::FuncOp>(block->getParentOp()); |
| |
| releaseLocalRefs(builder, location, funcOp); |
| |
| builder.create<mlir::ReturnOp>(location, callOp.getResult(0)); |
| }; |
| |
| auto ctx = builder.getContext(); |
| Type type = emitc::OpaqueType::get(ctx, "iree_status_t"); |
| return failableCall(builder, location, type, callee, args, templateArgs, |
| operands, blockBuilder, /*negateResult=*/true); |
| } |
| |
| emitc::CallOp failListNull(OpBuilder &builder, Location location, Type type, |
| StringAttr callee, ArrayAttr args, |
| ArrayAttr templateArgs, ArrayRef<Value> operands) { |
| auto blockBuilder = [&builder, &location](emitc::CallOp &callOp) { |
| auto ctx = builder.getContext(); |
| |
| Block *block = builder.getBlock(); |
| mlir::FuncOp funcOp = cast<mlir::FuncOp>(block->getParentOp()); |
| |
| releaseLocalRefs(builder, location, funcOp); |
| |
| auto statusOp = builder.create<emitc::CallOp>( |
| /*location=*/location, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_make_status"), |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, {emitc::OpaqueAttr::get(ctx, "IREE_STATUS_INVALID_ARGUMENT")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| builder.create<mlir::ReturnOp>(location, statusOp.getResult(0)); |
| }; |
| |
| return failableCall(builder, location, type, callee, args, templateArgs, |
| operands, blockBuilder); |
| } |
| |
| /// Generate a mlir.call op with one result and split the current block into a |
| /// continuation and failure block based on the truthiness of the result |
| /// value, i.e. a truthy value branches to the continuation block when |
| /// `negateCondition` is false. |
| mlir::CallOp failableCall( |
| OpBuilder &builder, Location location, mlir::FuncOp &callee, |
| ArrayRef<Value> operands, |
| const std::function<void(mlir::CallOp &)> &failureBlockBuilder, |
| bool negateCondition = false) { |
| auto ctx = builder.getContext(); |
| |
| auto callOp = builder.create<mlir::CallOp>( |
| /*location=*/location, |
| /*callee=*/callee, |
| /*operands=*/operands); |
| |
| Type boolType = builder.getIntegerType(1); |
| |
| auto conditionI1 = builder.create<emitc::CallOp>( |
| /*location=*/location, |
| /*type=*/boolType, |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), TypeAttr::get(boolType)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{callOp.getResult(0)}); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| failureBlockBuilder(callOp); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| auto branchOp = builder.create<CondBranchOp>( |
| location, conditionI1.getResult(0), |
| negateCondition ? failureBlock : continuationBlock, |
| negateCondition ? continuationBlock : failureBlock); |
| |
| builder.setInsertionPoint(continuationBlock, opPosition); |
| |
| return callOp; |
| } |
| |
| mlir::CallOp returnIfError(OpBuilder &builder, Location location, |
| mlir::FuncOp &callee, ArrayRef<Value> operands) { |
| auto blockBuilder = [&builder, &location](mlir::CallOp &callOp) { |
| auto ctx = builder.getContext(); |
| |
| Block *block = builder.getBlock(); |
| mlir::FuncOp funcOp = cast<mlir::FuncOp>(block->getParentOp()); |
| |
| releaseLocalRefs(builder, location, funcOp); |
| |
| builder.create<mlir::ReturnOp>(location, callOp.getResult(0)); |
| }; |
| |
| return failableCall(builder, location, callee, operands, blockBuilder, |
| /*negateResult=*/true); |
| } |
| |
| LogicalResult createAPIFunctions(IREE::VM::ModuleOp moduleOp) { |
| auto ctx = moduleOp.getContext(); |
| auto loc = moduleOp.getLoc(); |
| |
| OpBuilder builder(moduleOp); |
| builder.setInsertionPoint(moduleOp.getBody()->getTerminator()); |
| |
| std::string moduleName{moduleOp.getName()}; |
| |
| // destroy |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, {emitc::OpaqueType::get(ctx, "void*")}, {}); |
| |
| auto funcOp = |
| builder.create<mlir::FuncOp>(loc, moduleName + "_destroy", funcType); |
| |
| funcOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx)); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleTypeName = moduleName + "_t*"; |
| |
| auto castedModuleOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, moduleTypeName), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, moduleTypeName)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{funcOp.getArgument(0)}); |
| |
| auto allocatorOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "allocator")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{castedModuleOp.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_allocator_free"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{allocatorOp.getResult(0), castedModuleOp.getResult(0)}); |
| |
| builder.create<mlir::ReturnOp>(loc); |
| } |
| |
| // alloc_state |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| {emitc::OpaqueType::get(ctx, "void*"), |
| emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t**")}, |
| {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto funcOp = builder.create<mlir::FuncOp>(loc, moduleName + "_alloc_state", |
| funcType); |
| |
| funcOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx)); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleStateTypeName = moduleName + "_state_t*"; |
| |
| auto stateOp = builder.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, moduleStateTypeName), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "NULL")); |
| |
| auto stateSize = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_host_size_t"), |
| /*callee=*/StringAttr::get(ctx, "sizeof"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {emitc::OpaqueAttr::get(ctx, moduleName + "_state_t")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto statePtr = builder.template create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, moduleStateTypeName + "*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/stateOp.getResult()); |
| |
| auto voidPtr = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "void**"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "void**")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{statePtr.getResult()}); |
| |
| returnIfError( |
| builder, loc, StringAttr::get(ctx, "iree_allocator_malloc"), {}, {}, |
| {funcOp.getArgument(1), stateSize.getResult(0), voidPtr.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "memset"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {builder.getIndexAttr(0), builder.getUI32IntegerAttr(0), |
| builder.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{stateOp.getResult(), stateSize.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER_ASSIGN"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "allocator"), |
| builder.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{stateOp.getResult(), funcOp.getArgument(1)}); |
| |
| // Initialize buffers |
| for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) { |
| auto ordinal = rodataOp.ordinal().getValue().getZExtValue(); |
| |
| std::string bufferName = moduleName + "_" + rodataOp.getName().str(); |
| |
| auto bufferVoid = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "void*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, bufferName), |
| emitc::OpaqueAttr::get(ctx, "void*")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto bufferSize = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_host_size_t"), |
| /*callee=*/StringAttr::get(ctx, "sizeof"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, bufferName)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto byteSpan = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_byte_span_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_make_byte_span"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{bufferVoid.getResult(0), bufferSize.getResult(0)}); |
| |
| auto allocator = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_allocator_null"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{}); |
| |
| auto buffers = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_buffer_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "rodata_buffers")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateOp.getResult()}); |
| |
| auto buffer = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_buffer_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_ARRAY_ELEMENT_ADDRESS"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| builder.getUI32IntegerAttr(ordinal)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{buffers.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_buffer_initialize"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get( |
| ctx, "IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE"), |
| builder.getIndexAttr(0), builder.getIndexAttr(1), |
| builder.getIndexAttr(2)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{byteSpan.getResult(0), allocator.getResult(0), |
| buffer.getResult(0)}); |
| } |
| |
| auto baseStateOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_module_state_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "iree_vm_module_state_t*")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateOp.getResult()}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "EMITC_DEREF_ASSIGN_VALUE"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{funcOp.getArgument(2), baseStateOp.getResult(0)}); |
| |
| auto status = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_ok_status"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| builder.create<mlir::ReturnOp>(loc, status.getResult(0)); |
| } |
| |
| // free_state |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| {emitc::OpaqueType::get(ctx, "void*"), |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t*")}, |
| {}); |
| |
| auto funcOp = |
| builder.create<mlir::FuncOp>(loc, moduleName + "_free_state", funcType); |
| |
| funcOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx)); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleStateTypeName = moduleName + "_state_t*"; |
| |
| auto stateOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, moduleStateTypeName), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, moduleStateTypeName)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{funcOp.getArgument(1)}); |
| |
| auto allocatorOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "allocator")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateOp.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_allocator_free"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{allocatorOp.getResult(0), stateOp.getResult(0)}); |
| |
| builder.create<mlir::ReturnOp>(loc); |
| } |
| |
| // resolve_import |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| { |
| emitc::OpaqueType::get(ctx, "void*"), |
| emitc::OpaqueType::get(ctx, "iree_vm_module_state_t*"), |
| emitc::OpaqueType::get(ctx, "iree_host_size_t"), |
| emitc::OpaqueType::get(ctx, "const iree_vm_function_t*"), |
| emitc::OpaqueType::get(ctx, "const iree_vm_function_signature_t*"), |
| }, |
| {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto funcOp = builder.create<mlir::FuncOp>( |
| loc, moduleName + "_resolve_import", funcType); |
| |
| funcOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx)); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleStateTypeName = moduleName + "_state_t*"; |
| |
| auto stateOp = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, moduleStateTypeName), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, moduleStateTypeName)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{funcOp.getArgument(1)}); |
| |
| auto imports = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_function_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "imports")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateOp.getResult(0)}); |
| |
| auto import = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_function_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_ARRAY_ELEMENT_ADDRESS"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{imports.getResult(0), funcOp.getArgument(2)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "EMITC_DEREF_ASSIGN_PTR"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{import.getResult(0), funcOp.getArgument(3)}); |
| |
| auto status = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_ok_status"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| builder.create<mlir::ReturnOp>(loc, status.getResult(0)); |
| } |
| |
| // create |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| |
| auto funcType = mlir::FunctionType::get( |
| ctx, |
| {emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| emitc::OpaqueType::get(ctx, "iree_vm_module_t**")}, |
| {emitc::OpaqueType::get(ctx, "iree_status_t")}); |
| |
| auto funcOp = |
| builder.create<mlir::FuncOp>(loc, moduleName + "_create", funcType); |
| |
| funcOp.getOperation()->setAttr("emitc.static", UnitAttr::get(ctx)); |
| |
| // This function needs an iree_vm_native_module_descriptor_t that is emitted |
| // by the CModuleTarget at the moment. So we add a marker to this function |
| // and delay the printing of it. |
| funcOp.getOperation()->setAttr("vm.emit_at_end", UnitAttr::get(ctx)); |
| |
| Block *entryBlock = funcOp.addEntryBlock(); |
| |
| builder.setInsertionPointToStart(entryBlock); |
| |
| std::string moduleTypeName = moduleName + "_t*"; |
| |
| auto module = builder.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, moduleTypeName), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "NULL")); |
| |
| auto moduleSize = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_host_size_t"), |
| /*callee=*/StringAttr::get(ctx, "sizeof"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, moduleName + "_t")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto modulePtr = builder.template create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, moduleTypeName + "*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/module.getResult()); |
| |
| auto voidPtr = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "void**"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "void**")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{modulePtr.getResult()}); |
| |
| returnIfError( |
| builder, loc, StringAttr::get(ctx, "iree_allocator_malloc"), {}, {}, |
| {funcOp.getArgument(0), moduleSize.getResult(0), voidPtr.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "memset"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {builder.getIndexAttr(0), builder.getUI32IntegerAttr(0), |
| builder.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{module.getResult(), moduleSize.getResult(0)}); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER_ASSIGN"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "allocator"), |
| builder.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{module.getResult(), funcOp.getArgument(0)}); |
| |
| auto vmModule = builder.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_module_t"), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "")); |
| |
| auto vmModulePtr = builder.template create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_module_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/vmModule.getResult()); |
| |
| auto vmInitializeStatus = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_module_initialize"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{vmModulePtr.getResult(), module.getResult()}); |
| |
| Type boolType = builder.getIntegerType(1); |
| |
| auto vmInitializeIsOk = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*callee=*/StringAttr::get(ctx, "iree_status_is_ok"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{vmInitializeStatus.getResult(0)}); |
| |
| // Start by splitting the block into two. The part before will contain the |
| // condition, and the part after will contain the continuation point. |
| Block *condBlock = builder.getInsertionBlock(); |
| Block::iterator opPosition = builder.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(builder); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = builder.createBlock(parentRegion, parentRegion->end()); |
| |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_allocator_free"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{funcOp.getArgument(0), module.getResult()}); |
| |
| builder.create<mlir::ReturnOp>(loc, vmInitializeStatus.getResult(0)); |
| } |
| |
| builder.setInsertionPointToEnd(condBlock); |
| |
| builder.create<CondBranchOp>(loc, vmInitializeIsOk.getResult(0), |
| continuationBlock, failureBlock); |
| |
| builder.setInsertionPointToStart(continuationBlock); |
| |
| // Set function pointers |
| for (std::string funcName : |
| {"destroy", "alloc_state", "free_state", "resolve_import"}) { |
| builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_MEMBER_ASSIGN"), |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, |
| {builder.getIndexAttr(0), emitc::OpaqueAttr::get(ctx, funcName), |
| emitc::OpaqueAttr::get(ctx, moduleName + "_" + funcName)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{vmModule.getResult()}); |
| } |
| |
| std::string descriptoPtr = "&" + moduleName + "_descriptor_"; |
| |
| auto status = builder.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_native_module_create"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {builder.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, descriptoPtr), |
| builder.getIndexAttr(1), builder.getIndexAttr(2)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{vmModulePtr.getResult(), funcOp.getArgument(0), |
| funcOp.getArgument(1)}); |
| |
| builder.create<mlir::ReturnOp>(loc, status.getResult(0)); |
| } |
| |
| return success(); |
| } |
| |
| SmallVector<Attribute, 4> indexSequence(int64_t n, MLIRContext *ctx) { |
| return llvm::to_vector<4>( |
| llvm::map_range(llvm::seq<int64_t>(0, n), [&ctx](int64_t i) -> Attribute { |
| return IntegerAttr::get(IndexType::get(ctx), i); |
| })); |
| } |
| |
| template <typename AccessOpTy, typename ResultOpTy> |
| ResultOpTy lookupSymbolRef(AccessOpTy accessOp, StringRef attrName) { |
| FlatSymbolRefAttr globalAttr = |
| accessOp.getOperation()->template getAttrOfType<FlatSymbolRefAttr>( |
| attrName); |
| ResultOpTy globalOp = |
| accessOp.getOperation() |
| ->template getParentOfType<IREE::VM::ModuleOp>() |
| .template lookupSymbol<ResultOpTy>(globalAttr.getValue()); |
| return globalOp; |
| } |
| |
| // Convert vm operations to emitc calls. The resultiong call has the ops |
| // operands as arguments followed by an argument for every attribute. |
| template <typename SrcOpTy> |
| class GenericOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| public: |
| GenericOpConversion(MLIRContext *context, StringRef funcName) |
| : OpConversionPattern<SrcOpTy>(context), funcName(funcName) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| SrcOpTy op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| |
| auto type = op.getOperation()->getResultTypes(); |
| StringAttr callee = StringAttr::get(ctx, funcName); |
| |
| // Default to an empty args attribute, which results in the operands being |
| // printed as the arguments to the function call. |
| ArrayAttr args; |
| ArrayAttr templateArgs; |
| |
| // If the operation has attributes, we need to explicitely build the args |
| // attribute of the emitc call op. This consists of index attributes for |
| // the operands, followed by the source op attributes themselves. |
| if (op->getAttrs().size() > 0) { |
| SmallVector<Attribute, 4> args_ = |
| indexSequence(operands.size(), op.getContext()); |
| |
| for (NamedAttribute attr : op->getAttrs()) { |
| args_.push_back(attr.second); |
| } |
| |
| args = rewriter.getArrayAttr(args_); |
| } |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOp>(op, type, callee, args, |
| templateArgs, operands); |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| class FuncOpConversion : public OpConversionPattern<mlir::FuncOp> { |
| public: |
| using OpConversionPattern<mlir::FuncOp>::OpConversionPattern; |
| |
| FuncOpConversion(TypeConverter &typeConverter, MLIRContext *context, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<mlir::FuncOp>(typeConverter, context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| mlir::FuncOp funcOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| TypeConverter::SignatureConversion signatureConverter( |
| funcOp.getType().getNumInputs()); |
| TypeConverter typeConverter; |
| for (const auto &arg : llvm::enumerate(funcOp.getArgumentTypes())) { |
| Type convertedType = getTypeConverter()->convertType(arg.value()); |
| signatureConverter.addInputs(arg.index(), convertedType); |
| } |
| |
| rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); |
| |
| // Creates a new function with the updated signature. |
| rewriter.updateRootInPlace(funcOp, [&] { |
| funcOp.setType( |
| rewriter.getFunctionType(signatureConverter.getConvertedTypes(), |
| funcOp.getType().getResults())); |
| }); |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| class CallOpConversion : public OpConversionPattern<IREE::VM::CallOp> { |
| public: |
| using OpConversionPattern<IREE::VM::CallOp>::OpConversionPattern; |
| |
| CallOpConversion(TypeConverter &typeConverter, MLIRContext *context, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<IREE::VM::CallOp>(typeConverter, context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::CallOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| mlir::FuncOp funcOp = |
| lookupSymbolRef<IREE::VM::CallOp, mlir::FuncOp>(op, "callee"); |
| IREE::VM::ImportOp importOp = |
| lookupSymbolRef<IREE::VM::CallOp, IREE::VM::ImportOp>(op, "callee"); |
| |
| if (!funcOp && !importOp) |
| return op.emitError() << "lookup of callee failed"; |
| |
| if (funcOp && importOp) |
| return op.emitError() << "lookup of callee ambiguous"; |
| |
| const bool isImported = importOp != nullptr; |
| |
| return isImported ? rewriteImportedCall(op, operands, rewriter, importOp) |
| : rewriteInternalCall(op, operands, rewriter, funcOp); |
| } |
| |
| LogicalResult rewriteInternalCall(IREE::VM::CallOp op, |
| ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter, |
| mlir::FuncOp funcOp) const { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| SmallVector<Value, 4> updatedOperands; |
| SmallVector<Value, 4> resultOperands; |
| |
| auto parentFuncOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); |
| |
| BlockArgument stackArg = parentFuncOp.getArgument(0); |
| BlockArgument moduleArg = parentFuncOp.getArgument(1); |
| BlockArgument moduleStateArg = parentFuncOp.getArgument(2); |
| |
| updatedOperands = {stackArg, moduleArg, moduleStateArg}; |
| |
| if (failed(updateOperands(op, operands, rewriter, updatedOperands, |
| resultOperands))) { |
| return failure(); |
| }; |
| |
| auto callOp = returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/funcOp, |
| /*operands=*/updatedOperands); |
| |
| if (failed(updateResults(op, resultOperands))) { |
| return failure(); |
| } |
| |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| |
| LogicalResult rewriteImportedCall(IREE::VM::CallOp op, |
| ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter, |
| IREE::VM::ImportOp importOp) const { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| SmallVector<Value, 4> updatedOperands; |
| SmallVector<Value, 4> resultOperands; |
| |
| auto moduleOp = |
| importOp.getOperation()->getParentOfType<IREE::VM::ModuleOp>(); |
| |
| Optional<std::string> funcName = buildFunctionName(moduleOp, importOp); |
| |
| if (!funcName.hasValue()) |
| return op.emitError() << "Couldn't build name to imported function"; |
| |
| int importOrdinal = importOp.ordinal().getValue().getZExtValue(); |
| |
| auto funcOp = op.getOperation()->template getParentOfType<mlir::FuncOp>(); |
| BlockArgument stackArg = funcOp.getArgument(0); |
| BlockArgument stateArg = funcOp.getArgument(2); |
| |
| auto imports = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_function_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "imports")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateArg}); |
| |
| auto import = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_function_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_ARRAY_ELEMENT_ADDRESS"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| rewriter.getUI32IntegerAttr(importOrdinal)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{imports.getResult(0)}); |
| |
| updatedOperands = {stackArg, import.getResult(0)}; |
| |
| if (failed(updateOperands(op, operands, rewriter, updatedOperands, |
| resultOperands))) { |
| return failure(); |
| } |
| |
| auto callOp = returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, funcName.getValue()), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/updatedOperands); |
| |
| if (failed(updateResults(op, resultOperands))) { |
| return failure(); |
| } |
| |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| |
| LogicalResult updateOperands(IREE::VM::CallOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter, |
| SmallVector<Value, 4> &updatedOperands, |
| SmallVector<Value, 4> &resultOperands) const { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| auto funcOp = op.getOperation()->template getParentOfType<mlir::FuncOp>(); |
| |
| for (const Value &operand : operands) { |
| updatedOperands.push_back(operand); |
| } |
| |
| // Create a variable for every non-ref result and a pointer to it as output |
| // parameter to the call. |
| for (OpResult result : op.getResults()) { |
| emitc::ConstantOp resultOp; |
| |
| if (result.getType().isa<IREE::VM::RefType>()) { |
| auto ref = findRef(funcOp, vmAnalysisCache, result); |
| |
| if (!ref.hasValue()) { |
| return failure(); |
| } |
| |
| // Keep track of the replaced value in the analysis to keep the value |
| // liveness working. |
| auto ptr = vmAnalysisCache.find(funcOp.getOperation()); |
| if (ptr == vmAnalysisCache.end()) { |
| return op.emitError() << "parent func op not found in cache."; |
| } |
| |
| Optional<std::string> cType = getCType(ref.getValue().getType()); |
| if (!cType.hasValue()) { |
| return op.emitError() << "unable to emit C type"; |
| } |
| |
| std::string cPtrType = cType.getValue() + std::string("*"); |
| auto resultPtrOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, cPtrType), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/ref.getValue()); |
| |
| ptr->second.remapValue(result, resultPtrOp.getResult()); |
| |
| resultOperands.push_back(resultPtrOp.getResult()); |
| updatedOperands.push_back(resultPtrOp.getResult()); |
| } else { |
| resultOp = rewriter.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/result.getType(), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "")); |
| |
| Optional<std::string> cType = getCType(resultOp.getType()); |
| if (!cType.hasValue()) { |
| return op.emitError() << "unable to emit C type"; |
| } |
| |
| std::string cPtrType = cType.getValue() + std::string("*"); |
| auto resultPtrOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, cPtrType), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/resultOp.getResult()); |
| |
| resultOperands.push_back(resultOp.getResult()); |
| updatedOperands.push_back(resultPtrOp.getResult()); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult updateResults(IREE::VM::CallOp op, |
| SmallVector<Value, 4> &resultOperands) const { |
| for (auto &pair : llvm::enumerate(op.getResults())) { |
| size_t index = pair.index(); |
| OpResult result = pair.value(); |
| result.replaceAllUsesWith(resultOperands[index]); |
| } |
| |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| template <typename CmpOpTy> |
| class CompareRefOpConversion : public OpConversionPattern<CmpOpTy> { |
| public: |
| using OpConversionPattern<CmpOpTy>::OpConversionPattern; |
| |
| CompareRefOpConversion(MLIRContext *context, StringRef funcName, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<CmpOpTy>(context), |
| funcName(funcName), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| CmpOpTy cmpOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = cmpOp.getContext(); |
| auto loc = cmpOp.getLoc(); |
| |
| auto funcOp = |
| cmpOp.getOperation()->template getParentOfType<mlir::FuncOp>(); |
| auto ptr = vmAnalysisCache.find(funcOp.getOperation()); |
| if (ptr == vmAnalysisCache.end()) { |
| return cmpOp.emitError() << "parent func op not found in cache."; |
| } |
| |
| bool moveLhs = |
| ptr->second.isLastValueUse(cmpOp.lhs(), cmpOp.getOperation()); |
| bool moveRhs = |
| ptr->second.isLastValueUse(cmpOp.rhs(), cmpOp.getOperation()); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOp>( |
| /*op=*/cmpOp, |
| /*type=*/cmpOp.getType(), |
| /*callee=*/StringAttr::get(ctx, funcName), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/operands); |
| |
| if (moveLhs) { |
| rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{operands[0]}); |
| } |
| |
| // NOTE: If lhs and rhs alias we call release twice on the same |
| // argument. |
| if (moveRhs) { |
| rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{operands[1]}); |
| } |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| class CompareRefNotZeroOpConversion |
| : public OpConversionPattern<IREE::VM::CmpNZRefOp> { |
| using OpConversionPattern<IREE::VM::CmpNZRefOp>::OpConversionPattern; |
| |
| public: |
| CompareRefNotZeroOpConversion(MLIRContext *context, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<IREE::VM::CmpNZRefOp>(context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::CmpNZRefOp cmpOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = cmpOp.getContext(); |
| auto loc = cmpOp.getLoc(); |
| |
| auto funcOp = cmpOp.getOperation()->getParentOfType<mlir::FuncOp>(); |
| auto ptr = vmAnalysisCache.find(funcOp.getOperation()); |
| if (ptr == vmAnalysisCache.end()) { |
| return cmpOp.emitError() << "parent func op not found in cache."; |
| } |
| |
| bool move = |
| ptr->second.isLastValueUse(cmpOp.operand(), cmpOp.getOperation()); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOp>( |
| /*op=*/cmpOp, |
| /*type=*/cmpOp.getType(), |
| /*callee=*/StringAttr::get(ctx, "vm_cmp_nz_ref"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/operands); |
| |
| if (move) { |
| rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{operands[0]}); |
| } |
| |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| template <typename ConstOpTy> |
| class ConstOpConversion : public OpRewritePattern<ConstOpTy> { |
| public: |
| using OpRewritePattern<ConstOpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ConstOpTy constOp, |
| PatternRewriter &rewriter) const final { |
| rewriter.replaceOpWithNewOp<emitc::ConstantOp>(constOp, constOp.getType(), |
| constOp.value()); |
| return success(); |
| } |
| }; |
| |
| template <typename ConstZeroOpTy> |
| class ConstZeroOpConversion : public OpRewritePattern<ConstZeroOpTy> { |
| public: |
| using OpRewritePattern<ConstZeroOpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ConstZeroOpTy constZeroOp, |
| PatternRewriter &rewriter) const final { |
| auto type = constZeroOp.getType(); |
| Attribute value; |
| |
| if (type.template isa<IntegerType>()) { |
| value = rewriter.getIntegerAttr(type, 0); |
| } else if (type.template isa<FloatType>()) { |
| value = rewriter.getFloatAttr(type, 0.0); |
| } else { |
| return failure(); |
| } |
| |
| rewriter.replaceOpWithNewOp<emitc::ConstantOp>(constZeroOp, type, value); |
| return success(); |
| } |
| }; |
| |
| class ConstRefZeroOpConversion |
| : public OpRewritePattern<IREE::VM::ConstRefZeroOp> { |
| public: |
| using OpRewritePattern<IREE::VM::ConstRefZeroOp>::OpRewritePattern; |
| |
| ConstRefZeroOpConversion(MLIRContext *context, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpRewritePattern<IREE::VM::ConstRefZeroOp>(context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| LogicalResult matchAndRewrite(IREE::VM::ConstRefZeroOp constRefZeroOp, |
| PatternRewriter &rewriter) const final { |
| auto ctx = constRefZeroOp.getContext(); |
| auto loc = constRefZeroOp.getLoc(); |
| |
| auto funcOp = |
| constRefZeroOp.getOperation()->getParentOfType<mlir::FuncOp>(); |
| |
| auto ref = findRef(funcOp, vmAnalysisCache, constRefZeroOp.getResult()); |
| |
| if (!ref.hasValue()) { |
| return failure(); |
| } |
| |
| auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( |
| /*op=*/constRefZeroOp, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/ref.getValue()); |
| |
| rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refPtrOp.result()}); |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| class ConstRefRodataOpConversion |
| : public OpConversionPattern<IREE::VM::ConstRefRodataOp> { |
| public: |
| using OpConversionPattern<IREE::VM::ConstRefRodataOp>::OpConversionPattern; |
| |
| ConstRefRodataOpConversion(MLIRContext *context, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<IREE::VM::ConstRefRodataOp>(context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| LogicalResult matchAndRewrite( |
| IREE::VM::ConstRefRodataOp constRefRodataOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const final { |
| auto ctx = constRefRodataOp.getContext(); |
| auto loc = constRefRodataOp.getLoc(); |
| |
| auto rodataOp = |
| lookupSymbolRef<IREE::VM::ConstRefRodataOp, IREE::VM::RodataOp>( |
| constRefRodataOp, "rodata"); |
| if (!rodataOp) { |
| return constRefRodataOp.emitError() << "Unable to find RodataOp"; |
| } |
| |
| auto funcOp = constRefRodataOp.getOperation() |
| ->template getParentOfType<mlir::FuncOp>(); |
| |
| BlockArgument stateArg = funcOp.getArgument(2); |
| auto rodataBuffersPtr = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_buffer_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "rodata_buffers")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateArg}); |
| |
| auto byteBufferPtrOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_buffer_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_ARRAY_ELEMENT_ADDRESS"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {rewriter.getIndexAttr(0), |
| rewriter.getUI32IntegerAttr(static_cast<uint32_t>( |
| rodataOp.ordinal().getValue().getZExtValue()))}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{rodataBuffersPtr.getResult(0)}); |
| |
| auto typeIdOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_buffer_type_id"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto ref = findRef(funcOp, vmAnalysisCache, constRefRodataOp.getResult()); |
| |
| if (!ref.hasValue()) return failure(); |
| |
| auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( |
| /*op=*/constRefRodataOp, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/ref.getValue()); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_wrap_retain"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{byteBufferPtrOp.getResult(0), typeIdOp.getResult(0), |
| refPtrOp.getResult()}); |
| |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| class BranchOpConversion : public OpConversionPattern<IREE::VM::BranchOp> { |
| using OpConversionPattern<IREE::VM::BranchOp>::OpConversionPattern; |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::BranchOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, op.dest(), |
| op.destOperands()); |
| |
| return success(); |
| } |
| }; |
| |
| class CondBranchOpConversion |
| : public OpConversionPattern<IREE::VM::CondBranchOp> { |
| using OpConversionPattern<IREE::VM::CondBranchOp>::OpConversionPattern; |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::CondBranchOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| Type boolType = rewriter.getI1Type(); |
| |
| auto condition = rewriter.create<IREE::VM::CmpNZI32Op>( |
| loc, rewriter.getI32Type(), op.condition()); |
| auto conditionI1 = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/boolType, |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {rewriter.getIndexAttr(0), TypeAttr::get(boolType)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{condition.getResult()}); |
| |
| rewriter.replaceOpWithNewOp<mlir::CondBranchOp>( |
| op, conditionI1.getResult(0), op.trueDest(), op.trueDestOperands(), |
| op.falseDest(), op.falseDestOperands()); |
| |
| return success(); |
| } |
| }; |
| |
| class ReturnOpConversion : public OpConversionPattern<IREE::VM::ReturnOp> { |
| using OpConversionPattern<IREE::VM::ReturnOp>::OpConversionPattern; |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::ReturnOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| auto funcOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); |
| |
| releaseLocalRefs(rewriter, loc, funcOp); |
| |
| // The result variables are the last N arguments of the function. |
| unsigned int firstOutputArgumentIndex = |
| funcOp.getNumArguments() - operands.size(); |
| |
| for (auto &operand : llvm::enumerate(operands)) { |
| unsigned int argumentIndex = firstOutputArgumentIndex + operand.index(); |
| BlockArgument resultArgument = funcOp.getArgument(argumentIndex); |
| |
| auto isRef = [&ctx](Type type) { |
| return type == emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"); |
| }; |
| |
| StringRef assignMacro = isRef(operand.value().getType()) |
| ? "EMITC_ASSIGN_VALUE" |
| : "EMITC_DEREF_ASSIGN_VALUE"; |
| |
| rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, assignMacro), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{resultArgument, operand.value()}); |
| } |
| |
| auto status = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_ok_status"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op, status.getResult(0)); |
| |
| return success(); |
| } |
| }; |
| |
| class FailOpConversion : public OpConversionPattern<IREE::VM::FailOp> { |
| using OpConversionPattern<IREE::VM::FailOp>::OpConversionPattern; |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::FailOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| Block *block = rewriter.getInsertionBlock(); |
| Region *parentRegion = block->getParent(); |
| Block *passthroughBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| passthroughBlock = |
| rewriter.createBlock(parentRegion, parentRegion->end()); |
| |
| auto funcOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); |
| |
| releaseLocalRefs(rewriter, loc, funcOp); |
| |
| auto status = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_ok_status"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| rewriter.create<mlir::ReturnOp>(loc, status.getResult(0)); |
| } |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| failureBlock = rewriter.createBlock(parentRegion, parentRegion->end()); |
| |
| auto funcOp = op.getOperation()->getParentOfType<mlir::FuncOp>(); |
| |
| releaseLocalRefs(rewriter, loc, funcOp); |
| |
| std::string message = std::string("\"") + |
| op.message().getValueOr("").str() + |
| std::string("\""); |
| |
| auto messageOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_string_view_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_make_cstring_view"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, message)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| auto messageSizeOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_host_size_t"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "size")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{messageOp.getResult(0)}); |
| |
| auto messageSizeIntOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "int"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_CAST"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "int")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{messageSizeOp.getResult(0)}); |
| |
| auto messageDataOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "const char*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "data")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{messageOp.getResult(0)}); |
| |
| auto status = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_status_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_status_allocate_f"), |
| /*args=*/ |
| ArrayAttr::get( |
| ctx, |
| {emitc::OpaqueAttr::get(ctx, "IREE_STATUS_FAILED_PRECONDITION"), |
| emitc::OpaqueAttr::get(ctx, "\"<vm>\""), |
| rewriter.getI32IntegerAttr(0), |
| emitc::OpaqueAttr::get(ctx, "\"%.*s\""), |
| rewriter.getIndexAttr(0), rewriter.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{messageSizeIntOp.getResult(0), |
| messageDataOp.getResult(0)}); |
| |
| rewriter.create<mlir::ReturnOp>(loc, status.getResult(0)); |
| } |
| |
| rewriter.replaceOpWithNewOp<IREE::VM::CondBranchOp>( |
| op, op.status(), failureBlock, passthroughBlock); |
| |
| return success(); |
| } |
| }; |
| |
| template <typename LoadOpTy, typename GlobalOpTy> |
| class GlobalLoadOpConversion : public OpConversionPattern<LoadOpTy> { |
| using OpConversionPattern<LoadOpTy>::OpConversionPattern; |
| |
| public: |
| GlobalLoadOpConversion(MLIRContext *context, StringRef funcName) |
| : OpConversionPattern<LoadOpTy>(context), funcName(funcName) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| LoadOpTy loadOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = loadOp.getContext(); |
| auto loc = loadOp.getLoc(); |
| |
| GlobalOpTy globalOp = |
| lookupSymbolRef<LoadOpTy, GlobalOpTy>(loadOp, "global"); |
| if (!globalOp) { |
| return loadOp.emitError() << "Unable to find GlobalOp"; |
| } |
| |
| auto funcOp = |
| loadOp.getOperation()->template getParentOfType<mlir::FuncOp>(); |
| |
| BlockArgument stateArg = funcOp.getArgument(2); |
| auto rwDataPtr = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "uint8_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "rwdata")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateArg}); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOp>( |
| /*op=*/loadOp, |
| /*type=*/loadOp.getOperation()->getResultTypes(), |
| /*callee=*/StringAttr::get(ctx, funcName), |
| /*args=*/ |
| rewriter.getArrayAttr( |
| {rewriter.getIndexAttr(0), |
| rewriter.getUI32IntegerAttr(static_cast<uint32_t>( |
| globalOp.ordinal().getValue().getZExtValue()))}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{rwDataPtr.getResult(0)}); |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| template <typename StoreOpTy, typename GlobalOpTy> |
| class GlobalStoreOpConversion : public OpConversionPattern<StoreOpTy> { |
| using OpConversionPattern<StoreOpTy>::OpConversionPattern; |
| |
| public: |
| GlobalStoreOpConversion(MLIRContext *context, StringRef funcName) |
| : OpConversionPattern<StoreOpTy>(context), funcName(funcName) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| StoreOpTy storeOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = storeOp.getContext(); |
| auto loc = storeOp.getLoc(); |
| |
| GlobalOpTy globalOp = |
| lookupSymbolRef<StoreOpTy, GlobalOpTy>(storeOp, "global"); |
| if (!globalOp) { |
| return storeOp.emitError() << "Unable to find GlobalOp"; |
| } |
| |
| auto funcOp = |
| storeOp.getOperation()->template getParentOfType<mlir::FuncOp>(); |
| |
| BlockArgument stateArg = funcOp.getArgument(2); |
| auto rwDataPtr = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "uint8_t*"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "rwdata")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateArg}); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOp>( |
| /*op=*/storeOp, |
| /*type=*/storeOp.getOperation()->getResultTypes(), |
| /*callee=*/StringAttr::get(ctx, funcName), |
| /*args=*/ |
| rewriter.getArrayAttr( |
| {rewriter.getIndexAttr(0), |
| rewriter.getUI32IntegerAttr(static_cast<uint32_t>( |
| globalOp.ordinal().getValue().getZExtValue())), |
| rewriter.getIndexAttr(1)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{rwDataPtr.getResult(0), operands[0]}); |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| }; |
| |
| // Convert vm list operations to two emitc calls. The wrapping ref pointer |
| // is first dereferenced and the result is used as the argument of the |
| // specified function name. |
| template <typename SrcOpTy> |
| class ListOpConversion : public OpConversionPattern<SrcOpTy> { |
| using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| |
| public: |
| ListOpConversion(MLIRContext *context, StringRef funcName, |
| size_t listArgumentIndex, bool failable) |
| : OpConversionPattern<SrcOpTy>(context), |
| funcName(funcName), |
| listArgumentIndex(listArgumentIndex), |
| failable(failable) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| SrcOpTy op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = op.getContext(); |
| auto loc = op.getLoc(); |
| |
| if (listArgumentIndex >= operands.size()) { |
| return op.emitError() << " index for list argument out of range"; |
| } |
| |
| Value listOperand = operands[listArgumentIndex]; |
| |
| // deref |
| auto refOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), |
| /*applicableOperator=*/StringAttr::get(ctx, "*"), |
| /*operand=*/listOperand); |
| |
| auto listDerefOp = failListNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refOp.getResult()}); |
| |
| // Replace the one list argument (which is wrapped in a ref) with the |
| // unwrapped list. |
| SmallVector<Value, 4> updatedOperands; |
| for (auto &operand : llvm::enumerate(operands)) { |
| if (operand.index() == listArgumentIndex) { |
| updatedOperands.push_back(listDerefOp.getResult(0)); |
| } else { |
| updatedOperands.push_back(operand.value()); |
| } |
| } |
| |
| if (failable) { |
| auto callOp = returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, funcName), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>(updatedOperands)); |
| |
| rewriter.replaceOp(op, ArrayRef<Value>{}); |
| } else { |
| rewriter.replaceOpWithNewOp<emitc::CallOp>( |
| /*op=*/op, |
| /*type=*/op.getOperation()->getResultTypes(), |
| /*callee=*/StringAttr::get(ctx, funcName), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>(updatedOperands)); |
| } |
| |
| return success(); |
| } |
| |
| StringRef funcName; |
| |
| // The index of the list argument. This gets replaced in the conversion. |
| size_t listArgumentIndex; |
| |
| // Whether the function call can fail, i.e. it returns an iree_status_t. |
| bool failable; |
| }; |
| |
| class ListAllocOpConversion |
| : public OpConversionPattern<IREE::VM::ListAllocOp> { |
| public: |
| using OpConversionPattern<IREE::VM::ListAllocOp>::OpConversionPattern; |
| |
| ListAllocOpConversion(TypeConverter &typeConverter, MLIRContext *context, |
| VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<IREE::VM::ListAllocOp>(typeConverter, context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::ListAllocOp allocOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = allocOp.getContext(); |
| auto loc = allocOp.getLoc(); |
| |
| Type convertedType = typeConverter->convertType(allocOp.getType()); |
| |
| if (!convertedType) { |
| return allocOp.emitOpError() << "type conversion failed"; |
| } |
| |
| auto elementType = allocOp.getType() |
| .cast<IREE::VM::RefType>() |
| .getObjectType() |
| .cast<IREE::VM::ListType>() |
| .getElementType(); |
| |
| Optional<emitc::ApplyOp> elementTypePtrOp = |
| createVmTypeDefPtr(rewriter, allocOp, elementType); |
| |
| if (!elementTypePtrOp.hasValue()) { |
| return failure(); |
| } |
| |
| auto listOp = rewriter.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "NULL")); |
| |
| auto listPtrOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t**"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/listOp.getResult()); |
| |
| auto funcOp = |
| allocOp.getOperation()->template getParentOfType<mlir::FuncOp>(); |
| |
| BlockArgument stateArg = funcOp.getArgument(2); |
| auto allocatorOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_allocator_t"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "allocator")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{stateArg}); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_create"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{elementTypePtrOp.getValue().getResult(), operands[0], |
| allocatorOp.getResult(0), listPtrOp.getResult()}); |
| |
| auto ref = findRef(funcOp, vmAnalysisCache, allocOp.getResult()); |
| |
| if (!ref.hasValue()) return failure(); |
| |
| auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( |
| /*op=*/allocOp, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/ref.getValue()); |
| |
| auto refTypeOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_type_id"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{}); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_wrap_assign"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listOp.getResult(), refTypeOp.getResult(0), |
| refPtrOp.getResult()}); |
| |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| template <typename GetOpTy> |
| class ListGetOpConversion : public OpConversionPattern<GetOpTy> { |
| using OpConversionPattern<GetOpTy>::OpConversionPattern; |
| |
| private: |
| LogicalResult matchAndRewrite( |
| GetOpTy getOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = getOp.getContext(); |
| auto loc = getOp.getLoc(); |
| |
| Optional<StringRef> valueTypeEnum; |
| Optional<StringRef> valueExtractor; |
| |
| std::tie(valueTypeEnum, valueExtractor) = |
| TypeSwitch<Operation *, |
| std::pair<Optional<StringRef>, Optional<StringRef>>>( |
| getOp.getOperation()) |
| .Case<IREE::VM::ListGetI32Op>([&](auto op) { |
| return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I32"), |
| StringRef("iree_vm_value_get_i32")); |
| }) |
| .template Case<IREE::VM::ListGetI64Op>([&](auto op) { |
| return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I64"), |
| StringRef("iree_vm_value_get_i64")); |
| }) |
| .Default([](Operation *) { return std::make_pair(None, None); }); |
| |
| if (!valueTypeEnum.hasValue() || !valueExtractor.hasValue()) { |
| return getOp.emitOpError() << "element type not handled"; |
| } |
| |
| auto valueOp = rewriter.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "")); |
| |
| auto valuePtrOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/valueOp.getResult()); |
| |
| auto refOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), |
| /*applicableOperator=*/StringAttr::get(ctx, "*"), |
| /*operand=*/operands[0]); |
| |
| auto listDerefOp = failListNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refOp.getResult()}); |
| |
| auto getValueOp = returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_get_value_as"), |
| /*args=*/ |
| ArrayAttr::get(ctx, |
| {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), |
| emitc::OpaqueAttr::get(ctx, valueTypeEnum.getValue()), |
| rewriter.getIndexAttr(2)}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), getOp.index(), |
| valuePtrOp.getResult()}); |
| |
| rewriter.replaceOpWithNewOp<emitc::CallOp>( |
| /*op=*/getOp, |
| /*type=*/getOp.getType(), |
| /*callee=*/StringAttr::get(ctx, valueExtractor.getValue()), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{valuePtrOp.getResult()}); |
| |
| return success(); |
| } |
| }; |
| |
| class ListGetRefOpConversion |
| : public OpConversionPattern<IREE::VM::ListGetRefOp> { |
| public: |
| using OpConversionPattern<IREE::VM::ListGetRefOp>::OpConversionPattern; |
| |
| ListGetRefOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<IREE::VM::ListGetRefOp>(context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::ListGetRefOp getOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = getOp.getContext(); |
| auto loc = getOp.getLoc(); |
| |
| auto listRefOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), |
| /*applicableOperator=*/StringAttr::get(ctx, "*"), |
| /*operand=*/operands[0]); |
| |
| auto listDerefOp = failListNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{listRefOp.getResult()}); |
| |
| auto funcOp = getOp.getOperation()->getParentOfType<mlir::FuncOp>(); |
| |
| auto ref = findRef(funcOp, vmAnalysisCache, getOp.getResult()); |
| |
| if (!ref.hasValue()) return failure(); |
| |
| auto refPtrOp = rewriter.replaceOpWithNewOp<emitc::ApplyOp>( |
| /*op=*/getOp, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/ref.getValue()); |
| |
| returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_get_ref_retain"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), getOp.index(), |
| refPtrOp.getResult()}); |
| |
| Type elementType = getOp.getResult().getType(); |
| |
| auto elementTypePtrOp = createVmTypeDefPtr(rewriter, getOp, elementType); |
| |
| if (!elementTypePtrOp.hasValue()) { |
| return failure(); |
| } |
| |
| // Build the following expression: |
| // (ref->type != IREE_VM_REF_TYPE_NULL && |
| // (iree_vm_type_def_is_value(type_def) || ref->type != |
| // type_def->ref_type)) |
| emitc::CallOp invalidType; |
| { |
| auto refType = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/ |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "type")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{refPtrOp.getResult()}); |
| |
| auto refTypeNull = rewriter.create<emitc::ConstantOp>( |
| /*location=*/loc, |
| /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*value=*/emitc::OpaqueAttr::get(ctx, "IREE_VM_REF_TYPE_NULL")); |
| |
| auto typedefIsValue = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_type_def_is_value"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{elementTypePtrOp.getValue().getResult()}); |
| |
| auto typedefRefType = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/ |
| emitc::OpaqueType::get(ctx, "iree_vm_ref_type_t"), |
| /*callee=*/StringAttr::get(ctx, "EMITC_STRUCT_PTR_MEMBER"), |
| /*args=*/ |
| ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), |
| emitc::OpaqueAttr::get(ctx, "ref_type")}), |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{elementTypePtrOp.getValue().getResult()}); |
| |
| auto refTypeIsNotNull = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*callee=*/StringAttr::get(ctx, "EMITC_NE"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{refType.getResult(0), refTypeNull.getResult()}); |
| |
| auto refTypesDontMatch = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*callee=*/StringAttr::get(ctx, "EMITC_NE"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{refType.getResult(0), typedefRefType.getResult(0)}); |
| |
| auto invalidRefType = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*callee=*/StringAttr::get(ctx, "EMITC_OR"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{typedefIsValue.getResult(0), |
| refTypesDontMatch.getResult(0)}); |
| |
| invalidType = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/rewriter.getI1Type(), |
| /*callee=*/StringAttr::get(ctx, "EMITC_AND"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{refTypeIsNotNull.getResult(0), |
| invalidRefType.getResult(0)}); |
| } |
| |
| // Start by splitting the block into two. The part before will contain |
| // the condition, and the part after will contain the continuation |
| // point. |
| Block *condBlock = rewriter.getInsertionBlock(); |
| Block::iterator opPosition = rewriter.getInsertionPoint(); |
| Block *continuationBlock = condBlock->splitBlock(opPosition); |
| |
| // Create a new block for the target of the failure. |
| Block *failureBlock; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| Region *parentRegion = condBlock->getParent(); |
| failureBlock = rewriter.createBlock(parentRegion, parentRegion->end()); |
| |
| rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/TypeRange{}, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_ref_release"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refPtrOp.getResult()}); |
| |
| rewriter.create<mlir::BranchOp>(loc, continuationBlock); |
| } |
| |
| rewriter.setInsertionPointToEnd(condBlock); |
| auto branchOp = rewriter.create<CondBranchOp>( |
| loc, invalidType.getResult(0), failureBlock, continuationBlock); |
| |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| |
| template <typename SetOpTy> |
| class ListSetOpConversion : public OpConversionPattern<SetOpTy> { |
| using OpConversionPattern<SetOpTy>::OpConversionPattern; |
| |
| private: |
| LogicalResult matchAndRewrite( |
| SetOpTy setOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = setOp.getContext(); |
| auto loc = setOp.getLoc(); |
| |
| Optional<StringRef> valueConstructor = |
| TypeSwitch<Operation *, Optional<StringRef>>(setOp.getOperation()) |
| .Case<IREE::VM::ListSetI32Op>( |
| [&](auto op) { return StringRef("iree_vm_value_make_i32"); }) |
| .template Case<IREE::VM::ListSetI64Op>( |
| [&](auto op) { return StringRef("iree_vm_value_make_i64"); }) |
| .Default([](Operation *) { return None; }); |
| |
| if (!valueConstructor.hasValue()) { |
| return setOp.emitOpError() << " not handled"; |
| } |
| |
| auto valueOp = rewriter.create<emitc::CallOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"), |
| /*callee=*/StringAttr::get(ctx, valueConstructor.getValue()), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{setOp.value()}); |
| |
| auto valuePtrOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"), |
| /*applicableOperator=*/StringAttr::get(ctx, "&"), |
| /*operand=*/valueOp.getResult(0)); |
| |
| auto refOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), |
| /*applicableOperator=*/StringAttr::get(ctx, "*"), |
| /*operand=*/operands[0]); |
| |
| auto listDerefOp = failListNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refOp.getResult()}); |
| |
| auto callOp = returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_set_value"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), setOp.index(), |
| valuePtrOp.getResult()}); |
| |
| rewriter.eraseOp(setOp); |
| |
| return success(); |
| } |
| }; |
| |
| class ListSetRefOpConversion |
| : public OpConversionPattern<IREE::VM::ListSetRefOp> { |
| public: |
| using OpConversionPattern<IREE::VM::ListSetRefOp>::OpConversionPattern; |
| |
| ListSetRefOpConversion(MLIRContext *context, VMAnalysisCache &vmAnalysisCache) |
| : OpConversionPattern<IREE::VM::ListSetRefOp>(context), |
| vmAnalysisCache(vmAnalysisCache) {} |
| |
| private: |
| LogicalResult matchAndRewrite( |
| IREE::VM::ListSetRefOp setOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto ctx = setOp.getContext(); |
| auto loc = setOp.getLoc(); |
| |
| auto refOp = rewriter.create<emitc::ApplyOp>( |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_ref_t"), |
| /*applicableOperator=*/StringAttr::get(ctx, "*"), |
| /*operand=*/operands[0]); |
| |
| auto listDerefOp = failListNull( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), |
| /*callee=*/StringAttr::get(ctx, "iree_vm_list_deref"), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ArrayRef<Value>{refOp.getResult()}); |
| |
| auto funcOp = setOp.getOperation()->getParentOfType<mlir::FuncOp>(); |
| auto ptr = vmAnalysisCache.find(funcOp.getOperation()); |
| if (ptr == vmAnalysisCache.end()) { |
| return setOp.emitError() << "parent func op not found in cache."; |
| } |
| |
| bool move = ptr->second.isLastValueUse(setOp.value(), setOp.getOperation()); |
| |
| StringRef callee = |
| move ? "iree_vm_list_set_ref_move" : "iree_vm_list_set_ref_retain"; |
| |
| auto callOp = returnIfError( |
| /*rewriter=*/rewriter, |
| /*location=*/loc, |
| /*callee=*/StringAttr::get(ctx, callee), |
| /*args=*/ArrayAttr{}, |
| /*templateArgs=*/ArrayAttr{}, |
| /*operands=*/ |
| ArrayRef<Value>{listDerefOp.getResult(0), setOp.index(), operands[2]}); |
| |
| rewriter.eraseOp(setOp); |
| |
| return success(); |
| } |
| |
| VMAnalysisCache &vmAnalysisCache; |
| }; |
| } // namespace |
| |
| void populateVMToEmitCPatterns(MLIRContext *context, |
| IREE::VM::EmitCTypeConverter &typeConverter, |
| OwningRewritePatternList &patterns, |
| VMAnalysisCache &vmAnalysisCache) { |
| populateUtilConversionPatterns(context, typeConverter, patterns); |
| |
| // CFG |
| patterns.insert<BranchOpConversion>(context); |
| patterns.insert<CallOpConversion>(typeConverter, context, vmAnalysisCache); |
| patterns.insert<CondBranchOpConversion>(context); |
| patterns.insert<FailOpConversion>(context); |
| patterns.insert<FuncOpConversion>(typeConverter, context, vmAnalysisCache); |
| patterns.insert<ReturnOpConversion>(context); |
| |
| // Globals |
| patterns.insert< |
| GlobalLoadOpConversion<IREE::VM::GlobalLoadI32Op, IREE::VM::GlobalI32Op>>( |
| context, "vm_global_load_i32"); |
| patterns.insert<GlobalStoreOpConversion<IREE::VM::GlobalStoreI32Op, |
| IREE::VM::GlobalI32Op>>( |
| context, "vm_global_store_i32"); |
| |
| // Constants |
| patterns.insert<ConstOpConversion<IREE::VM::ConstI32Op>>(context); |
| patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(context); |
| patterns.insert<ConstRefZeroOpConversion>(context, vmAnalysisCache); |
| patterns.insert<ConstRefRodataOpConversion>(context, vmAnalysisCache); |
| |
| // List ops |
| patterns.insert<ListAllocOpConversion>(typeConverter, context, |
| vmAnalysisCache); |
| patterns.insert<ListOpConversion<IREE::VM::ListReserveOp>>( |
| context, "iree_vm_list_reserve", 0, true); |
| patterns.insert<ListOpConversion<IREE::VM::ListResizeOp>>( |
| context, "iree_vm_list_resize", 0, true); |
| patterns.insert<ListOpConversion<IREE::VM::ListSizeOp>>( |
| context, "iree_vm_list_size", 0, false); |
| patterns.insert<ListGetOpConversion<IREE::VM::ListGetI32Op>>(context); |
| patterns.insert<ListGetRefOpConversion>(context, vmAnalysisCache); |
| patterns.insert<ListSetOpConversion<IREE::VM::ListSetI32Op>>(context); |
| patterns.insert<ListSetRefOpConversion>(context, vmAnalysisCache); |
| |
| // Conditional assignment ops |
| patterns.insert<GenericOpConversion<IREE::VM::SelectI32Op>>(context, |
| "vm_select_i32"); |
| |
| // Native integer arithmetic ops |
| patterns.insert<GenericOpConversion<IREE::VM::AddI32Op>>(context, |
| "vm_add_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::SubI32Op>>(context, |
| "vm_sub_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::MulI32Op>>(context, |
| "vm_mul_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::DivI32SOp>>(context, |
| "vm_div_i32s"); |
| patterns.insert<GenericOpConversion<IREE::VM::DivI32UOp>>(context, |
| "vm_div_i32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::RemI32SOp>>(context, |
| "vm_rem_i32s"); |
| patterns.insert<GenericOpConversion<IREE::VM::RemI32UOp>>(context, |
| "vm_rem_i32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::FMAI32Op>>(context, |
| "vm_fma_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::NotI32Op>>(context, |
| "vm_not_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::AndI32Op>>(context, |
| "vm_and_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::OrI32Op>>(context, "vm_or_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::XorI32Op>>(context, |
| "vm_xor_i32"); |
| |
| // Casting and type conversion/emulation ops |
| patterns.insert<GenericOpConversion<IREE::VM::TruncI32I8Op>>( |
| context, "vm_trunc_i32i8"); |
| patterns.insert<GenericOpConversion<IREE::VM::TruncI32I16Op>>( |
| context, "vm_trunc_i32i16"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExtI8I32SOp>>(context, |
| "vm_ext_i8i32s"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExtI8I32UOp>>(context, |
| "vm_ext_i8i32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExtI16I32SOp>>( |
| context, "vm_ext_i16i32s"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExtI16I32UOp>>( |
| context, "vm_ext_i16i32u"); |
| |
| // Native bitwise shift and rotate ops |
| patterns.insert<GenericOpConversion<IREE::VM::ShlI32Op>>(context, |
| "vm_shl_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::ShrI32SOp>>(context, |
| "vm_shr_i32s"); |
| patterns.insert<GenericOpConversion<IREE::VM::ShrI32UOp>>(context, |
| "vm_shr_i32u"); |
| |
| // Comparison ops |
| patterns.insert<GenericOpConversion<IREE::VM::CmpEQI32Op>>(context, |
| "vm_cmp_eq_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNEI32Op>>(context, |
| "vm_cmp_ne_i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTI32SOp>>(context, |
| "vm_cmp_lt_i32s"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTI32UOp>>(context, |
| "vm_cmp_lt_i32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNZI32Op>>(context, |
| "vm_cmp_nz_i32"); |
| patterns.insert<CompareRefOpConversion<IREE::VM::CmpEQRefOp>>( |
| context, "vm_cmp_eq_ref", vmAnalysisCache); |
| patterns.insert<CompareRefOpConversion<IREE::VM::CmpNERefOp>>( |
| context, "vm_cmp_ne_ref", vmAnalysisCache); |
| patterns.insert<CompareRefNotZeroOpConversion>(context, vmAnalysisCache); |
| |
| // ExtF32: Globals |
| patterns.insert< |
| GlobalLoadOpConversion<IREE::VM::GlobalLoadF32Op, IREE::VM::GlobalF32Op>>( |
| context, "vm_global_load_f32"); |
| patterns.insert<GlobalStoreOpConversion<IREE::VM::GlobalStoreF32Op, |
| IREE::VM::GlobalF32Op>>( |
| context, "vm_global_store_f32"); |
| |
| // ExtF32: Native floating-point constants |
| patterns.insert<ConstOpConversion<IREE::VM::ConstF32Op>>(context); |
| patterns.insert<ConstZeroOpConversion<IREE::VM::ConstF32ZeroOp>>(context); |
| |
| // ExtF32: Conditional assignment |
| patterns.insert<GenericOpConversion<IREE::VM::SelectF32Op>>(context, |
| "vm_select_f32"); |
| |
| // ExtF32: Native floating-point arithmetic |
| patterns.insert<GenericOpConversion<IREE::VM::AddF32Op>>(context, |
| "vm_add_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::SubF32Op>>(context, |
| "vm_sub_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::MulF32Op>>(context, |
| "vm_mul_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::DivF32Op>>(context, |
| "vm_div_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::RemF32Op>>(context, |
| "vm_rem_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::FMAF32Op>>(context, |
| "vm_fma_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::AbsF32Op>>(context, |
| "vm_abs_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::NegF32Op>>(context, |
| "vm_neg_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CeilF32Op>>(context, |
| "vm_ceil_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::FloorF32Op>>(context, |
| "vm_floor_f32"); |
| |
| patterns.insert<GenericOpConversion<IREE::VM::AtanF32Op>>(context, |
| "vm_atan_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::Atan2F32Op>>(context, |
| "vm_atan2_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CosF32Op>>(context, |
| "vm_cos_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::SinF32Op>>(context, |
| "vm_sin_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExpF32Op>>(context, |
| "vm_exp_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::Exp2F32Op>>(context, |
| "vm_exp2_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExpM1F32Op>>(context, |
| "vm_expm1_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::LogF32Op>>(context, |
| "vm_log_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::Log10F32Op>>(context, |
| "vm_log10_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::Log1pF32Op>>(context, |
| "vm_log1p_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::Log2F32Op>>(context, |
| "vm_log2_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::PowF32Op>>(context, |
| "vm_pow_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::RsqrtF32Op>>(context, |
| "vm_rsqrt_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::SqrtF32Op>>(context, |
| "vm_sqrt_f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::TanhF32Op>>(context, |
| "vm_tanh_f32"); |
| |
| // ExtF32: Casting and type conversion/emulation ops |
| patterns.insert<GenericOpConversion<IREE::VM::CastSI32F32Op>>( |
| context, "vm_cast_si32f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CastUI32F32Op>>( |
| context, "vm_cast_ui32f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CastF32SI32Op>>( |
| context, "vm_cast_f32si32"); |
| patterns.insert<GenericOpConversion<IREE::VM::CastF32UI32Op>>( |
| context, "vm_cast_f32ui32"); |
| patterns.insert<GenericOpConversion<IREE::VM::BitcastI32F32Op>>( |
| context, "vm_bitcast_i32f32"); |
| patterns.insert<GenericOpConversion<IREE::VM::BitcastF32I32Op>>( |
| context, "vm_bitcast_f32i32"); |
| |
| // ExtF32: Comparison ops |
| patterns.insert<GenericOpConversion<IREE::VM::CmpEQF32OOp>>(context, |
| "vm_cmp_eq_f32o"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpEQF32UOp>>(context, |
| "vm_cmp_eq_f32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNEF32OOp>>(context, |
| "vm_cmp_ne_f32o"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNEF32UOp>>(context, |
| "vm_cmp_ne_f32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTF32OOp>>(context, |
| "vm_cmp_lt_f32o"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTF32UOp>>(context, |
| "vm_cmp_lt_f32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTEF32OOp>>( |
| context, "vm_cmp_lte_f32o"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTEF32UOp>>( |
| context, "vm_cmp_lte_f32u"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNaNF32Op>>(context, |
| "vm_cmp_nan_f32"); |
| |
| // ExtI64: Globals |
| patterns.insert< |
| GlobalLoadOpConversion<IREE::VM::GlobalLoadI64Op, IREE::VM::GlobalI64Op>>( |
| context, "vm_global_load_i64"); |
| patterns.insert<GlobalStoreOpConversion<IREE::VM::GlobalStoreI64Op, |
| IREE::VM::GlobalI64Op>>( |
| context, "vm_global_store_i64"); |
| |
| // ExtI64: Constants |
| patterns.insert<ConstOpConversion<IREE::VM::ConstI64Op>>(context); |
| patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>>(context); |
| |
| // ExtI64: List ops |
| patterns.insert<ListGetOpConversion<IREE::VM::ListGetI64Op>>(context); |
| patterns.insert<ListSetOpConversion<IREE::VM::ListSetI64Op>>(context); |
| |
| // ExtI64: Conditional assignment ops |
| patterns.insert<GenericOpConversion<IREE::VM::SelectI64Op>>(context, |
| "vm_select_i64"); |
| // ExtI64: Native integer arithmetic ops |
| patterns.insert<GenericOpConversion<IREE::VM::AddI64Op>>(context, |
| "vm_add_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::SubI64Op>>(context, |
| "vm_sub_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::MulI64Op>>(context, |
| "vm_mul_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::DivI64SOp>>(context, |
| "vm_div_i64s"); |
| patterns.insert<GenericOpConversion<IREE::VM::DivI64UOp>>(context, |
| "vm_div_i64u"); |
| patterns.insert<GenericOpConversion<IREE::VM::RemI64SOp>>(context, |
| "vm_rem_i64s"); |
| patterns.insert<GenericOpConversion<IREE::VM::RemI64UOp>>(context, |
| "vm_rem_i64u"); |
| patterns.insert<GenericOpConversion<IREE::VM::FMAI64Op>>(context, |
| "vm_fma_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::NotI64Op>>(context, |
| "vm_not_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::AndI64Op>>(context, |
| "vm_and_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::OrI64Op>>(context, "vm_or_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::XorI64Op>>(context, |
| "vm_xor_i64"); |
| |
| // ExtI64: Casting and type conversion/emulation ops |
| patterns.insert<GenericOpConversion<IREE::VM::TruncI64I32Op>>( |
| context, "vm_trunc_i64i32"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExtI32I64SOp>>( |
| context, "vm_ext_i32i64s"); |
| patterns.insert<GenericOpConversion<IREE::VM::ExtI32I64UOp>>( |
| context, "vm_ext_i32i64u"); |
| |
| // ExtI64: Native bitwise shift and rotate ops |
| patterns.insert<GenericOpConversion<IREE::VM::ShlI64Op>>(context, |
| "vm_shl_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::ShrI64SOp>>(context, |
| "vm_shr_i64s"); |
| patterns.insert<GenericOpConversion<IREE::VM::ShrI64UOp>>(context, |
| "vm_shr_i64u"); |
| |
| // ExtI64: Comparison ops |
| patterns.insert<GenericOpConversion<IREE::VM::CmpEQI64Op>>(context, |
| "vm_cmp_eq_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNEI64Op>>(context, |
| "vm_cmp_ne_i64"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTI64SOp>>(context, |
| "vm_cmp_lt_i64s"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpLTI64UOp>>(context, |
| "vm_cmp_lt_i64u"); |
| patterns.insert<GenericOpConversion<IREE::VM::CmpNZI64Op>>(context, |
| "vm_cmp_nz_i64"); |
| } |
| |
| namespace IREE { |
| namespace VM { |
| |
| namespace { |
| |
| // A pass converting IREE VM operations into the EmitC dialect. |
| // |
| // vm.func ops get converted to std.func with the calling convention used by |
| // EmitC. Each function gets three additional arguments: an |
| // `iree_vm_stack_t*` as well as two module specific struct pointers |
| // (`{module_name}_t*` and `{module_name}_state_t*`). These are followed by the |
| // original function arguments and out arguments for the vm.func results. The |
| // result type of the function is `iree_status_t`. Ref types are always passed |
| // as pointers. |
| // |
| // Examples: |
| // () -> () => (iree_vm_stack_t*, module_t*, module_state_t*) -> iree_status_t |
| // |
| // (i) -> () => (iree_vm_stack_t*, module_t*, module_state_t*, int32_t) -> |
| // iree_status_t |
| // |
| // (r) -> () => (iree_vm_stack_t*, module_t*, module_state_t*, iree_vm_ref_t*) |
| // -> iree_status_t |
| // |
| // () -> (r) => (iree_vm_stack_t*, module_t*, module_state_t*, iree_vm_ref_t*) |
| // -> iree_status_t |
| // |
| // (iir) -> (ri) => (iree_vm_stack_t*, module_t*, module_state_t*, int32_t, |
| // int32_t, iree_vm_ref_t*, iree_vm_ref_t*, int32_t*) -> |
| // iree_status_t |
| class ConvertVMToEmitCPass |
|     : public PassWrapper<ConvertVMToEmitCPass, |
|                          OperationPass<IREE::VM::ModuleOp>> { |
|   // Inserts the EmitC, builtin, standard and util dialects into `registry`. |
|   void getDependentDialects(DialectRegistry &registry) const override { |
|     registry.insert<mlir::emitc::EmitCDialect, mlir::BuiltinDialect, |
|                     mlir::StandardOpsDialect, IREE::Util::UtilDialect>(); |
|   } |
| |
|   // Returns the argument string identifying this pass. |
|   StringRef getArgument() const override { return "iree-convert-vm-to-emitc"; } |
| |
|   // Returns a short human readable description of this pass. |
|   StringRef getDescription() const override { |
|     return "Convert VM Ops to the EmitC dialect"; |
|   } |
| |
|   // Converts the vm.module this pass runs on: |
|   //   1. computes the analyses for every vm.func and converts it upfront, |
|   //   2. erases the original vm.func ops, |
|   //   3. generates the func ops implementing the C API, |
|   //   4. runs a full dialect conversion with the VM to EmitC patterns. |
|   // The pass is marked as failed as soon as any of these steps fails. |
|   void runOnOperation() override { |
|     IREE::VM::ModuleOp module = getOperation(); |
| |
|     ConversionTarget target(getContext()); |
|     EmitCTypeConverter typeConverter; |
| |
|     // Cache of the per function analyses (register allocation and value |
|     // liveness). It is filled in the loop below and handed to the patterns. |
|     VMAnalysisCache vmAnalysisCache; |
| |
|     // Convert vm.func ops to std.func with the calling convention used by |
|     // EmitC. We convert these upfront to make sure vm.call ops always |
|     // reference std.func ops with the correct calling convention during the |
|     // conversion. |
|     SmallVector<IREE::VM::FuncOp, 4> funcsToRemove; |
|     for (auto funcOp : module.getOps<IREE::VM::FuncOp>()) { |
|       Operation *op = funcOp.getOperation(); |
|       vmAnalysisCache.insert(std::make_pair( |
|           op, VMAnalysis{RegisterAllocation(op), ValueLiveness(op)})); |
| |
|       if (failed(convertFuncOp(funcOp, vmAnalysisCache))) |
|         return signalPassFailure(); |
|       // Only remember the op here; it gets erased after the loop finished. |
|       funcsToRemove.push_back(funcOp); |
|     } |
| |
|     for (auto &funcOp : funcsToRemove) funcOp.erase(); |
| |
|     // Generate func ops that implement the C API. |
|     if (failed(createAPIFunctions(module))) return signalPassFailure(); |
| |
|     OwningRewritePatternList patterns(&getContext()); |
|     populateVMToEmitCPatterns(&getContext(), typeConverter, patterns, |
|                               vmAnalysisCache); |
| |
|     target.addLegalDialect<emitc::EmitCDialect, mlir::BuiltinDialect, |
|                            mlir::StandardOpsDialect>(); |
| |
|     // A func op is legal only if the type converter accepts its signature. |
|     target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) { |
|       return typeConverter.isSignatureLegal(op.getType()); |
|     }); |
| |
|     // util.do_not_optimize is legal only if all of its result types are legal. |
|     target.addDynamicallyLegalOp<IREE::Util::DoNotOptimizeOp>( |
|         [&](IREE::Util::DoNotOptimizeOp op) { |
|           return typeConverter.isLegal(op.getResultTypes()); |
|         }); |
| |
|     // Structural ops |
|     target.addLegalOp<IREE::VM::ModuleOp>(); |
|     target.addLegalOp<IREE::VM::ModuleTerminatorOp>(); |
|     target.addLegalOp<IREE::VM::ExportOp>(); |
|     target.addLegalOp<IREE::VM::ImportOp>(); |
| |
|     // Global ops |
|     target.addLegalOp<IREE::VM::GlobalI32Op>(); |
|     target.addLegalOp<IREE::VM::GlobalI64Op>(); |
|     target.addLegalOp<IREE::VM::GlobalF32Op>(); |
|     target.addLegalOp<IREE::VM::RodataOp>(); |
| |
|     // Every op that is not marked legal above has to be rewritten by one of |
|     // the patterns, otherwise the full conversion (and the pass) fails. |
|     if (failed(applyFullConversion(module, target, std::move(patterns)))) { |
|       return signalPassFailure(); |
|     } |
|   } |
| }; |
| |
| } // namespace |
| |
| // Creates a new instance of the pass converting VM ops to the EmitC dialect. |
| std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> |
| createConvertVMToEmitCPass() { |
|   std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> emitCPass = |
|       std::make_unique<ConvertVMToEmitCPass>(); |
|   return emitCPass; |
| } |
| |
| } // namespace VM |
| } // namespace IREE |
| |
| static PassRegistration<IREE::VM::ConvertVMToEmitCPass> pass; /* file-local static registration object for ConvertVMToEmitCPass */ |
| |
| } // namespace iree_compiler |
| } // namespace mlir |