| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/VM/Target/C/CModuleTarget.h" |
| |
| #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/Util/Transforms/Passes.h" |
| #include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h" |
| #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.h" |
| #include "iree/compiler/Dialect/VM/Conversion/VMToEmitC/DropExcludedExports.h" |
| #include "iree/compiler/Dialect/VM/Target/C/CppEmitter.h" |
| #include "iree/compiler/Dialect/VM/Transforms/Passes.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace VM { |
| |
| static void printCompilerConfigurationBlock(llvm::raw_ostream &output) { |
| output << "//" << std::string(77, '=') << "\n" |
| << "// compiler configuration\n" |
| << "//" << std::string(77, '=') << "\n\n"; |
| |
| output << "#if defined(IREE_COMPILER_MSVC)\n"; |
| output << "#pragma warning(disable:4102)\n"; |
| output << "#endif // IREE_COMPILER_MSVC\n"; |
| } |
| |
| static void printModuleComment(IREE::VM::ModuleOp &moduleOp, |
| llvm::raw_ostream &output) { |
| output << "//" << std::string(77, '=') << "\n" |
| << "// module \"" << moduleOp.getName() |
| << "\"\n" |
| "//" |
| << std::string(77, '=') << "\n"; |
| } |
| |
| static LogicalResult printRodataBuffers(IREE::VM::ModuleOp &moduleOp, |
| mlir::emitc::CppEmitter &emitter) { |
| llvm::raw_ostream &output = emitter.ostream(); |
| std::string moduleName = moduleOp.getName().str(); |
| |
| for (auto rodataOp : moduleOp.getOps<IREE::VM::RodataOp>()) { |
| auto value = |
| rodataOp.value().dyn_cast<IREE::Util::SerializableAttrInterface>(); |
| assert(value && "expected a serializable rodata value"); |
| SmallVector<char> byteBuffer; |
| if (failed(value.serializeToVector(llvm::support::endianness::little, |
| byteBuffer))) { |
| return rodataOp.emitError() << "error during serialization"; |
| } |
| |
| constexpr size_t kDefaultRodataAlignment = 16; |
| size_t alignment = |
| rodataOp.alignment() |
| ? static_cast<size_t>(rodataOp.alignment().getValue()) |
| : 0; |
| if (alignment == 0) alignment = kDefaultRodataAlignment; |
| |
| std::string bufferName = |
| moduleOp.getName().str() + "_" + rodataOp.getName().str(); |
| |
| output << "iree_alignas(" << alignment << ") static const uint8_t " |
| << bufferName << "[] = {"; |
| llvm::interleaveComma(byteBuffer, output, [&](char value) { |
| output << static_cast<unsigned int>(value); |
| }); |
| output << "};\n"; |
| } |
| |
| output << "\n"; |
| |
| return success(); |
| } |
| |
| static LogicalResult printStructDefinitions(IREE::VM::ModuleOp &moduleOp, |
| mlir::emitc::CppEmitter &emitter) { |
| llvm::raw_ostream &output = emitter.ostream(); |
| std::string moduleName = moduleOp.getName().str(); |
| |
| output << "struct " << moduleName << "_t {\n"; |
| output << "iree_allocator_t allocator;\n"; |
| output << "};\n"; |
| |
| output << "struct " << moduleName << "_state_t {\n"; |
| |
| // Returns |count| or 1 if |count| == 0. |
| // Some compilers (MSVC) don't support zero-length struct fields on the |
| // interior of structs (just VLA at the tail). |
| auto countOrEmpty = [](uint32_t count) { return count ? count : 1; }; |
| |
| auto ordinalCounts = moduleOp.ordinal_counts().getValue(); |
| output << "iree_allocator_t allocator;\n"; |
| output << "uint8_t rwdata[" << countOrEmpty(ordinalCounts.global_bytes()) |
| << "];\n"; |
| output << "iree_vm_ref_t refs[" << countOrEmpty(ordinalCounts.global_refs()) |
| << "];\n"; |
| output << "iree_vm_buffer_t rodata_buffers[" |
| << countOrEmpty(ordinalCounts.rodatas()) << "];\n"; |
| output << "iree_vm_function_t imports[" |
| << countOrEmpty(ordinalCounts.import_funcs()) << "];\n"; |
| output << "};\n"; |
| |
| output << "typedef struct " << moduleName << "_t " << moduleName << "_t;\n"; |
| output << "typedef struct " << moduleName << "_state_t " << moduleName |
| << "_state_t;\n"; |
| |
| output << "\n"; |
| |
| return success(); |
| } |
| |
| static LogicalResult printShim(mlir::FuncOp &funcOp, |
| llvm::raw_ostream &output) { |
| StringAttr callingConvention = funcOp.getOperation() |
| ->getAttr("vm.calling_convention") |
| .cast<StringAttr>(); |
| if (!callingConvention) { |
| return funcOp.emitError("Couldn't find calling convention attribute"); |
| } |
| output << "call_" << callingConvention.getValue() << "_shim"; |
| return success(); |
| } |
| |
| static LogicalResult buildModuleDescriptors(IREE::VM::ModuleOp &moduleOp, |
| mlir::emitc::CppEmitter &emitter) { |
| SymbolTable symbolTable(moduleOp); |
| std::string moduleName = moduleOp.getName().str(); |
| llvm::raw_ostream &output = emitter.ostream(); |
| |
| auto printStringView = [](StringRef s) -> std::string { |
| // We can't use iree_make_string_view because function calls are not allowed |
| // for constant expressions in C. |
| return ("{\"" + s + "\", " + std::to_string(s.size()) + "}").str(); |
| }; |
| |
| // exports |
| SmallVector<IREE::VM::ExportOp, 4> exportOps( |
| moduleOp.getOps<IREE::VM::ExportOp>()); |
| std::string exportName = moduleName + "_exports_"; |
| output << "static const size_t " << exportName |
| << "count_ = " << exportOps.size() << ";\n"; |
| output << "static const iree_vm_native_export_descriptor_t " << exportName |
| << "[] = {\n"; |
| if (exportOps.empty()) { |
| // Empty list placeholder. |
| output << " {0},\n"; |
| } else { |
| // sort export ops |
| llvm::sort(exportOps, [](auto &lhs, auto &rhs) { |
| return lhs.export_name().compare(rhs.export_name()) < 0; |
| }); |
| for (auto exportOp : exportOps) { |
| StringRef funcName = exportOp.function_ref(); |
| auto funcOp = symbolTable.lookup<mlir::FuncOp>(funcName); |
| if (!funcOp) { |
| return exportOp.emitError("Couldn't find referenced FuncOp"); |
| } |
| StringAttr callingConvention = funcOp.getOperation() |
| ->getAttr("vm.calling_convention") |
| .cast<StringAttr>(); |
| if (!callingConvention) { |
| return exportOp.emitError("Couldn't find calling convention attribute"); |
| } |
| |
| // TODO(simon-camp): support function-level reflection attributes |
| output << "{" << printStringView(exportOp.export_name()) << ", " |
| << printStringView(callingConvention.getValue()) |
| << ", 0, NULL},\n"; |
| } |
| } |
| output << "};\n"; |
| output << "\n"; |
| |
| // imports |
| SmallVector<IREE::VM::ImportOp, 4> importOps( |
| moduleOp.getOps<IREE::VM::ImportOp>()); |
| std::string importName = moduleName + "_imports_"; |
| output << "static const size_t " << importName |
| << "count_ = " << importOps.size() << ";\n"; |
| output << "static const iree_vm_native_import_descriptor_t " << importName |
| << "[] = {\n"; |
| if (importOps.empty()) { |
| // Empty list placeholder. |
| output << " {0},\n"; |
| } else { |
| // sort import ops |
| llvm::sort(importOps, [](auto &lhs, auto &rhs) { |
| return lhs.getName().compare(rhs.getName()) < 0; |
| }); |
| for (auto importOp : importOps) { |
| output << "{" << printStringView(importOp.getName()) << "},\n"; |
| } |
| } |
| output << "};\n"; |
| output << "\n"; |
| |
| // functions |
| std::string functionName = moduleName + "_funcs_"; |
| output << "static const size_t " << functionName |
| << "count_ = " << exportOps.size() << ";\n"; |
| output << "static const iree_vm_native_function_ptr_t " << functionName |
| << "[] = {\n"; |
| if (exportOps.empty()) { |
| // Empty list placeholder. |
| output << " {0},\n"; |
| } else { |
| // We only add exported functions to the table, as calls to internal |
| // functions are directly mapped to C function calls of the generated |
| // implementation. |
| for (auto exportOp : exportOps) { |
| StringRef funcName = exportOp.function_ref(); |
| auto funcOp = symbolTable.lookup<mlir::FuncOp>(funcName); |
| if (!funcOp) { |
| return exportOp.emitError("Couldn't find referenced FuncOp"); |
| } |
| output << "{" |
| << "(iree_vm_native_function_shim_t)"; |
| |
| if (failed(printShim(funcOp, output))) { |
| return funcOp.emitError("Error generating shim"); |
| } |
| output << ", " |
| << "(iree_vm_native_function_target_t)" << funcName << "},\n"; |
| } |
| } |
| output << "};\n"; |
| output << "\n"; |
| |
| // module descriptor |
| // TODO(simon-camp): support module-level reflection attributes |
| std::string descriptorName = moduleName + "_descriptor_"; |
| output << "static const iree_vm_native_module_descriptor_t " << descriptorName |
| << " = {\n" |
| << printStringView(moduleName) << ",\n" |
| << importName << "count_,\n" |
| << importName << ",\n" |
| << exportName << "count_,\n" |
| << exportName << ",\n" |
| << functionName << "count_,\n" |
| << functionName << ",\n" |
| << "0,\n" |
| << "NULL,\n" |
| << "};\n"; |
| |
| output << "\n"; |
| return success(); |
| } |
| |
| /// Adapted from BytecodeModuleTarget and extended by C specific passes |
| static LogicalResult canonicalizeModule( |
| IREE::VM::ModuleOp moduleOp, IREE::VM::CTargetOptions targetOptions) { |
| OwningRewritePatternList patterns(moduleOp.getContext()); |
| ConversionTarget target(*moduleOp.getContext()); |
| target.addLegalDialect<IREE::VM::VMDialect>(); |
| target.addLegalOp<IREE::Util::DoNotOptimizeOp>(); |
| |
| // Add all VM canonicalization patterns and mark pseudo-ops illegal. |
| auto *context = moduleOp.getContext(); |
| for (auto *op : context->getRegisteredOperations()) { |
| // Non-serializable ops must be removed prior to serialization. |
| if (op->hasTrait<OpTrait::IREE::VM::PseudoOp>()) { |
| op->getCanonicalizationPatterns(patterns, context); |
| target.setOpAction(OperationName(op->name, context), |
| ConversionTarget::LegalizationAction::Illegal); |
| } |
| |
| // Debug ops must not be present when stripping. |
| // TODO(benvanik): add RemoveDisabledDebugOp pattern. |
| if (op->hasTrait<OpTrait::IREE::VM::DebugOnly>() && |
| targetOptions.stripDebugOps) { |
| target.setOpAction(OperationName(op->name, context), |
| ConversionTarget::LegalizationAction::Illegal); |
| } |
| } |
| |
| if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { |
| return moduleOp.emitError() << "unable to fully apply conversion to module"; |
| } |
| |
| PassManager passManager(context); |
| mlir::applyPassManagerCLOptions(passManager); |
| mlir::applyDefaultTimingPassManagerCLOptions(passManager); |
| auto &modulePasses = passManager.nest<IREE::VM::ModuleOp>(); |
| |
| if (targetOptions.optimize) { |
| // TODO(benvanik): does this run until it quiesces? |
| modulePasses.addPass(mlir::createInlinerPass()); |
| modulePasses.addPass(mlir::createCSEPass()); |
| modulePasses.addPass(mlir::createCanonicalizerPass()); |
| } |
| |
| // C target specific pass |
| |
| // Erase exports annotated with 'emitc.exclude'. This makes testing |
| // of partially supported ops easier. For the DCE pass to remove the |
| // referenced function it must be unused and marked private. |
| modulePasses.addPass(createDropExcludedExportsPass()); |
| modulePasses.addPass(mlir::createSymbolDCEPass()); |
| |
| // In the the Bytecode module the order is: |
| // * `createDropCompilerHintsPass()` |
| // * `IREE::VM::createOrdinalAllocationPass()` |
| // Here, we have to reverse the order and run |
| // `createConvertVMToEmitCPass()` inbetween to test the EmitC pass. |
| // Otherwise, the constants get folded by the canonicalizer. |
| |
| // Mark up the module with ordinals for each top-level op (func, etc). |
| // This will make it easier to correlate the MLIR textual output to the |
| // binary output. |
| // We don't want any more modifications after this point as they could |
| // invalidate the ordinals. |
| modulePasses.addPass(IREE::VM::createOrdinalAllocationPass()); |
| |
| // C target specific pass |
| modulePasses.addPass(createConvertVMToEmitCPass()); |
| |
| modulePasses.addPass(IREE::Util::createDropCompilerHintsPass()); |
| |
| if (failed(passManager.run(moduleOp->getParentOfType<mlir::ModuleOp>()))) { |
| return moduleOp.emitError() << "failed during transform passes"; |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult translateModuleToC(IREE::VM::ModuleOp moduleOp, |
| CTargetOptions targetOptions, |
| llvm::raw_ostream &output) { |
| moduleOp.getContext()->getOrLoadDialect<IREE::Util::UtilDialect>(); |
| |
| if (failed(canonicalizeModule(moduleOp, targetOptions))) { |
| return moduleOp.emitError() |
| << "failed to canonicalize vm.module to a serializable form"; |
| } |
| |
| if (targetOptions.outputFormat == COutputFormat::kMlirText) { |
| // Use the standard MLIR text printer. |
| moduleOp.getOperation()->print(output); |
| output << "\n"; |
| return success(); |
| } |
| |
| auto printInclude = [&output](std::string include) { |
| output << "#include \"" << include << "\"\n"; |
| }; |
| |
| printInclude("iree/vm/api.h"); |
| printInclude("iree/vm/ops.h"); |
| printInclude("iree/vm/ops_emitc.h"); |
| printInclude("iree/vm/shims_emitc.h"); |
| output << "\n"; |
| |
| printCompilerConfigurationBlock(output); |
| output << "\n"; |
| |
| printModuleComment(moduleOp, output); |
| output << "\n"; |
| |
| mlir::emitc::CppEmitter emitter(output, /*declareVariablesAtTop=*/true); |
| mlir::emitc::CppEmitter::Scope scope(emitter); |
| |
| if (failed(printRodataBuffers(moduleOp, emitter))) { |
| return failure(); |
| } |
| |
| // build struct definitions |
| if (failed(printStructDefinitions(moduleOp, emitter))) { |
| return failure(); |
| } |
| |
| // translate functions |
| output << "// DECLARE FUNCTIONS\n"; |
| |
| for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) { |
| Operation *op = funcOp.getOperation(); |
| if (op->hasAttr("emitc.static")) output << "static "; |
| |
| if (failed( |
| emitter.emitTypes(funcOp.getLoc(), funcOp.getType().getResults()))) |
| return failure(); |
| output << " " << funcOp.getName(); |
| |
| output << "("; |
| |
| bool error = false; |
| llvm::interleaveComma( |
| funcOp.getArguments(), output, [&](BlockArgument arg) { |
| if (failed(emitter.emitType(funcOp.getLoc(), arg.getType()))) |
| error = true; |
| }); |
| if (error) return failure(); |
| output << ");\n"; |
| } |
| |
| output << "// DEFINE FUNCTIONS\n"; |
| |
| // Emit code for functions skipping those marked with `vm.emit_at_end`. |
| for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) { |
| Operation *op = funcOp.getOperation(); |
| if (op->hasAttr("vm.emit_at_end")) continue; |
| if (op->hasAttr("emitc.static")) output << "static "; |
| if (failed(emitter.emitOperation(*funcOp.getOperation(), |
| /*trailingSemicolon=*/false))) |
| return failure(); |
| } |
| |
| output << "\n"; |
| |
| // generate module descriptors |
| if (failed(buildModuleDescriptors(moduleOp, emitter))) { |
| return failure(); |
| } |
| |
| // Emit code for functions marked with `vm.emit_at_end`. |
| for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) { |
| Operation *op = funcOp.getOperation(); |
| if (!op->hasAttr("vm.emit_at_end")) continue; |
| if (op->hasAttr("emitc.static")) output << "static "; |
| if (failed(emitter.emitOperation(*funcOp.getOperation(), |
| /*trailingSemicolon=*/false))) |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult translateModuleToC(mlir::ModuleOp outerModuleOp, |
| CTargetOptions targetOptions, |
| llvm::raw_ostream &output) { |
| auto moduleOps = outerModuleOp.getOps<IREE::VM::ModuleOp>(); |
| if (moduleOps.empty()) { |
| return outerModuleOp.emitError() |
| << "outer module does not contain a vm.module op"; |
| } |
| return translateModuleToC(*moduleOps.begin(), targetOptions, output); |
| } |
| |
| } // namespace VM |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |