| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h" |
| |
| #include <algorithm> |
| |
| #include "flatbuffers/flatbuffers.h" |
| #include "flatbuffers/minireflect.h" |
| #include "iree/compiler/Dialect/IREE/IR/IREEOps.h" |
| #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" |
| #include "iree/compiler/Dialect/IREE/Transforms/Passes.h" |
| #include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h" |
| #include "iree/compiler/Dialect/VM/Analysis/ValueLiveness.h" |
| #include "iree/compiler/Dialect/VM/IR/VMDialect.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h" |
| #include "iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h" |
| #include "iree/compiler/Dialect/VM/Transforms/Passes.h" |
| #include "iree/schemas/bytecode_module_def_generated.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/Visitors.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/Passes.h" |
| #include "mlir/Translation.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace VM { |
| |
| namespace { |
| |
| using flatbuffers::FlatBufferBuilder; |
| using flatbuffers::Offset; |
| using flatbuffers::Vector; |
| |
| struct ModuleCounts { |
| int importFuncs = 0; |
| int exportFuncs = 0; |
| int internalFuncs = 0; |
| size_t globalBytes = 0; |
| int globalRefs = 0; |
| int rodatas = 0; |
| int rwdatas = 0; |
| }; |
| |
| struct TypeDef { |
| Type type; |
| std::string full_name; |
| }; |
| |
| } // namespace |
| |
| // Computes symbol counts within the given |moduleOp|. |
| // These counts, including the global byte reservation count, are expected to |
| // match the actual values during serialization. |
| // |
| // Preconditions: |
| // - OrdinalAllocationPass has run on the module |
| // - All ordinals start from 0 and are contiguous |
| static ModuleCounts computeModuleSymbolCounts(IREE::VM::ModuleOp moduleOp) { |
| ModuleCounts counts; |
| for (auto &op : moduleOp.getBlock().getOperations()) { |
| if (auto funcOp = dyn_cast<IREE::VM::FuncOp>(op)) { |
| ++counts.internalFuncs; |
| } else if (isa<IREE::VM::ExportOp>(op)) { |
| ++counts.exportFuncs; |
| } else if (isa<IREE::VM::ImportOp>(op)) { |
| ++counts.importFuncs; |
| } else if (isa<IREE::VM::RodataOp>(op)) { |
| ++counts.rodatas; |
| } else if (isa<IREE::VM::GlobalRefOp>(op)) { |
| ++counts.globalRefs; |
| } else if (auto globalOp = dyn_cast<VMGlobalOp>(op)) { |
| counts.globalBytes = |
| std::max(counts.globalBytes, |
| globalOp.getOrdinal() + globalOp.getStorageSize()); |
| } |
| } |
| return counts; |
| } |
| |
| // Finds all types in the module and builds a type table mapping the index in |
| // the vector to the type represented by the type ordinal. |
| static std::vector<TypeDef> buildTypeTable(IREE::VM::ModuleOp moduleOp) { |
| llvm::DenseMap<Type, std::string> typeMap; |
| std::function<void(Type)> tryInsertType; |
| tryInsertType = [&](Type type) { |
| if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) { |
| type = refPtrType.getObjectType(); |
| } |
| if (typeMap.count(type)) return; |
| std::string str; |
| llvm::raw_string_ostream sstream(str); |
| type.print(sstream); |
| sstream.flush(); |
| typeMap.try_emplace(type, str); |
| if (auto listType = type.dyn_cast<IREE::VM::ListType>()) { |
| if (listType.getElementType()) { |
| tryInsertType(listType.getElementType()); |
| } |
| } |
| }; |
| for (auto funcOp : moduleOp.getBlock().getOps<IREE::VM::FuncOp>()) { |
| funcOp.walk([&](Operation *op) { |
| for (auto type : op->getOperandTypes()) tryInsertType(type); |
| for (auto type : op->getResultTypes()) tryInsertType(type); |
| }); |
| } |
| |
| std::vector<TypeDef> table; |
| table.reserve(typeMap.size()); |
| for (const auto &typeString : typeMap) { |
| table.push_back(TypeDef{typeString.first, typeString.second}); |
| } |
| llvm::sort( |
| table, +[](const TypeDef &lhs, const TypeDef &rhs) { |
| // Always sort builtins above custom types. |
| if (lhs.full_name[0] != '!' && rhs.full_name[0] == '!') { |
| return true; |
| } else if (lhs.full_name[0] == '!' && rhs.full_name[0] != '!') { |
| return false; |
| } |
| return lhs.full_name.compare(rhs.full_name) < 0; |
| }); |
| return table; |
| } |
| |
| // Canonicalizes the module to its final form prior to emission. |
| // This verifies that we only have ops we can serialize and performs any of the |
| // required transformations (such as debug op stripping). |
| static LogicalResult canonicalizeModule(BytecodeTargetOptions targetOptions, |
| IREE::VM::ModuleOp moduleOp) { |
| OwningRewritePatternList patterns; |
| ConversionTarget target(*moduleOp.getContext()); |
| target.addLegalDialect<IREE::VM::VMDialect>(); |
| target.addLegalOp<IREE::DoNotOptimizeOp>(); |
| |
| // Add all VM canonicalization patterns and mark pseudo-ops illegal. |
| auto *context = moduleOp.getContext(); |
| for (auto *op : context->getRegisteredOperations()) { |
| // Non-serializable ops must be removed prior to serialization. |
| if (op->hasTrait<OpTrait::IREE::VM::PseudoOp>()) { |
| op->getCanonicalizationPatterns(patterns, context); |
| target.setOpAction(OperationName(op->name, context), |
| ConversionTarget::LegalizationAction::Illegal); |
| } |
| |
| // Debug ops must not be present when stripping. |
| // TODO(benvanik): add RemoveDisabledDebugOp pattern. |
| if (op->hasTrait<OpTrait::IREE::VM::DebugOnly>() && |
| targetOptions.stripDebugOps) { |
| target.setOpAction(OperationName(op->name, context), |
| ConversionTarget::LegalizationAction::Illegal); |
| } |
| } |
| |
| if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { |
| return moduleOp.emitError() << "unable to fully apply conversion to module"; |
| } |
| |
| PassManager passManager(context); |
| mlir::applyPassManagerCLOptions(passManager); |
| auto &modulePasses = passManager.nest<IREE::VM::ModuleOp>(); |
| |
| if (targetOptions.optimize) { |
| // TODO(benvanik): does this run until it quiesces? |
| modulePasses.addPass(mlir::createInlinerPass()); |
| modulePasses.addPass(mlir::createCSEPass()); |
| modulePasses.addPass(mlir::createCanonicalizerPass()); |
| } |
| |
| modulePasses.addPass(createDropCompilerHintsPass()); |
| |
| // Mark up the module with ordinals for each top-level op (func, etc). |
| // This will make it easier to correlate the MLIR textual output to the |
| // binary output. |
| // We don't want any more modifications after this point as they could |
| // invalidate the ordinals. |
| modulePasses.addPass(IREE::VM::createOrdinalAllocationPass()); |
| |
| if (failed(passManager.run(moduleOp.getParentOfType<mlir::ModuleOp>()))) { |
| return moduleOp.emitError() << "failed during transform passes"; |
| } |
| |
| return success(); |
| } |
| |
| // Returns a vector of tables of type T or None if |contents| is empty. |
| template <typename T> |
| static Optional<Offset<Vector<Offset<T>>>> createOptionalVector( |
| const std::vector<Offset<T>> &contents, FlatBufferBuilder &fbb) { |
| if (contents.empty()) return llvm::None; |
| return fbb.CreateVector(contents); |
| } |
| template <typename T> |
| static Optional<Offset<Vector<T>>> createOptionalVector( |
| const std::vector<T> &contents, FlatBufferBuilder &fbb) { |
| if (contents.empty()) return llvm::None; |
| return fbb.CreateVector(contents); |
| } |
| |
| // Encodes a type (or a tuple of nested types) to a calling convention string. |
| // |
| // Examples: |
| // i32 -> i |
| // !vm.ref<...> -> r |
| // tuple<i32, i64> -> iI |
| static LogicalResult encodeCallingConventionType(Operation *op, Type type, |
| SmallVectorImpl<char> &s) { |
| if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) { |
| s.push_back('r'); |
| return success(); |
| } else if (auto intType = type.dyn_cast<IntegerType>()) { |
| switch (intType.getIntOrFloatBitWidth()) { |
| default: |
| case 32: |
| s.push_back('i'); |
| return success(); |
| case 64: |
| s.push_back('I'); |
| return success(); |
| } |
| } else if (auto tupleType = type.dyn_cast<TupleType>()) { |
| // Flatten tuple (so tuple<i32, i64> -> `...iI...`). |
| SmallVector<Type, 4> flattenedTypes; |
| tupleType.getFlattenedTypes(flattenedTypes); |
| for (auto elementType : flattenedTypes) { |
| if (failed(encodeCallingConventionType(op, elementType, s))) { |
| return op->emitError() |
| << "unsupported external calling convention tuple element type " |
| << elementType; |
| } |
| } |
| return success(); |
| } |
| return op->emitError() << "unsupported external calling convention type " |
| << type; |
| } |
| |
| static LogicalResult encodeVariadicCallingConventionType( |
| Operation *op, Type type, SmallVectorImpl<char> &s) { |
| s.push_back('['); |
| auto result = encodeCallingConventionType(op, type, s); |
| s.push_back(']'); |
| return result; |
| } |
| |
| // Generates a string encoding the function type for defining the |
| // FunctionSignatureDef::calling_convention field for import functions. |
| // |
| // This differs from makeCallingConventionString in that it supports variadic |
| // arguments. Ideally we'd combine the two, but we only have this additional |
| // metadata on IREE::VM::ImportOp. |
| static Optional<std::string> makeImportCallingConventionString( |
| IREE::VM::ImportOp importOp) { |
| auto functionType = importOp.getType(); |
| if (functionType.getNumInputs() == 0 && functionType.getNumResults() == 0) { |
| return std::string{}; // Valid but empty. |
| } |
| |
| SmallVector<char, 8> s = {'0'}; |
| for (int i = 0; i < functionType.getNumInputs(); ++i) { |
| if (importOp.isFuncArgumentVariadic(i)) { |
| if (failed(encodeVariadicCallingConventionType( |
| importOp, functionType.getInput(i), s))) { |
| return None; |
| } |
| } else { |
| if (failed(encodeCallingConventionType(importOp, functionType.getInput(i), |
| s))) { |
| return None; |
| } |
| } |
| } |
| if (functionType.getNumResults() > 0) { |
| s.push_back('.'); |
| for (int i = 0; i < functionType.getNumResults(); ++i) { |
| if (failed(encodeCallingConventionType(importOp, |
| functionType.getResult(i), s))) { |
| return None; |
| } |
| } |
| } |
| return std::string(s.data(), s.size()); |
| } |
| |
| // Generates a string encoding the function type for defining the |
| // FunctionSignatureDef::calling_convention field for internal/export functions. |
| static Optional<std::string> makeCallingConventionString( |
| IREE::VM::FuncOp funcOp) { |
| auto functionType = funcOp.getType(); |
| if (functionType.getNumInputs() == 0 && functionType.getNumResults() == 0) { |
| return std::string{}; // Valid but empty. |
| } |
| |
| SmallVector<char, 8> s = {'0'}; |
| for (int i = 0; i < functionType.getNumInputs(); ++i) { |
| if (failed( |
| encodeCallingConventionType(funcOp, functionType.getInput(i), s))) { |
| return None; |
| } |
| } |
| if (functionType.getNumResults() > 0) { |
| s.push_back('.'); |
| for (int i = 0; i < functionType.getNumResults(); ++i) { |
| if (failed(encodeCallingConventionType(funcOp, functionType.getResult(i), |
| s))) { |
| return None; |
| } |
| } |
| } |
| return std::string(s.data(), s.size()); |
| } |
| |
| // Populates common fields for FunctionSignatureDefs of all function types. |
| static void populateFunctionSignatureDef(FunctionType functionType, |
| llvm::DenseMap<Type, int> &typeTable, |
| iree::vm::FunctionSignatureDefT &fsd) { |
| for (auto type : functionType.getInputs()) { |
| if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) { |
| type = refPtrType.getObjectType(); |
| } |
| fsd.argument_types.push_back(typeTable.lookup(type)); |
| } |
| for (auto type : functionType.getResults()) { |
| if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) { |
| type = refPtrType.getObjectType(); |
| } |
| fsd.result_types.push_back(typeTable.lookup(type)); |
| } |
| } |
| |
| // Returns a serialized function signature. |
| static Offset<iree::vm::FunctionSignatureDef> makeImportFunctionSignatureDef( |
| IREE::VM::ImportOp importOp, llvm::DenseMap<Type, int> &typeTable, |
| FlatBufferBuilder &fbb) { |
| // Common attributes. |
| iree::vm::FunctionSignatureDefT fsd; |
| populateFunctionSignatureDef(importOp.getType(), typeTable, fsd); |
| |
| // Generate the signature calling convention string based on types. |
| auto cconv = makeImportCallingConventionString(importOp); |
| if (!cconv.hasValue()) return {}; |
| fsd.calling_convention = cconv.getValue(); |
| |
| return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd); |
| } |
| |
| // Returns a serialized function signature. |
| static Offset<iree::vm::FunctionSignatureDef> makeExportFunctionSignatureDef( |
| IREE::VM::ExportOp exportOp, IREE::VM::FuncOp funcOp, |
| llvm::DenseMap<Type, int> &typeTable, FlatBufferBuilder &fbb) { |
| // Common attributes. |
| iree::vm::FunctionSignatureDefT fsd; |
| populateFunctionSignatureDef(funcOp.getType(), typeTable, fsd); |
| |
| // Generate the signature calling convention string based on types. |
| auto cconv = makeCallingConventionString(funcOp); |
| if (!cconv.hasValue()) return {}; |
| fsd.calling_convention = cconv.getValue(); |
| |
| return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd); |
| } |
| |
| // Returns a serialized function signature. |
| static Offset<iree::vm::FunctionSignatureDef> makeInternalFunctionSignatureDef( |
| IREE::VM::FuncOp funcOp, llvm::DenseMap<Type, int> &typeTable, |
| FlatBufferBuilder &fbb) { |
| // Common attributes. |
| iree::vm::FunctionSignatureDefT fsd; |
| populateFunctionSignatureDef(funcOp.getType(), typeTable, fsd); |
| |
| // Generate the signature calling convention string based on types. |
| // TODO(benvanik): only do this on exports. The runtime currently looks on |
| // internal functions, though, so we have to have it here. |
| auto cconv = makeCallingConventionString(funcOp); |
| if (!cconv.hasValue()) return {}; |
| fsd.calling_convention = cconv.getValue(); |
| |
| // Reflection attributes. |
| // TODO(benvanik): move these to exports (or remove entirely). |
| if (auto reflectionAttrs = |
| funcOp.getAttrOfType<DictionaryAttr>("iree.reflection")) { |
| llvm::SmallVector<Offset<iree::vm::ReflectionAttrDef>, 4> |
| reflectionAttrItems; |
| for (auto reflectionAttr : reflectionAttrs) { |
| auto key = reflectionAttr.first.strref(); |
| auto value = reflectionAttr.second.dyn_cast<StringAttr>(); |
| if (!value || key.empty()) continue; |
| auto rattr = std::make_unique<iree::vm::ReflectionAttrDefT>(); |
| rattr->key = key.str(); |
| rattr->value = value.getValue().str(); |
| fsd.reflection_attrs.push_back(std::move(rattr)); |
| } |
| } |
| |
| return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd); |
| } |
| |
| // Builds a complete BytecodeModuleDef FlatBuffer object in |fbb|. |
| // The order of the encoding is ordered to ensure that all metadata is at the |
| // front of the resulting buffer. Large read-only data and bytecode blobs always |
| // fill the end of the file meaning that when memory-mapping the file most will |
| // not need to be paged in to do the initial module preparation. |
| // |
| // To keep the actual BytecodeModuleDef and resulting parsing code simple a lot |
| // has been packed into the top-level table. This results in a messier function |
| // here during serialization but a much more trivial (and cache-friendly) |
| // representation at runtime. |
| static Offset<iree::vm::BytecodeModuleDef> buildFlatBufferModule( |
| BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp, |
| FlatBufferBuilder &fbb) { |
| SymbolTable symbolTable(moduleOp); |
| auto symbolCounts = computeModuleSymbolCounts(moduleOp); |
| |
| // Find all structural ops in the module. |
| std::vector<IREE::VM::ImportOp> importFuncOps; |
| std::vector<IREE::VM::ExportOp> exportFuncOps; |
| std::vector<IREE::VM::FuncOp> internalFuncOps; |
| std::vector<IREE::VM::RodataOp> rodataOps; |
| importFuncOps.resize(symbolCounts.importFuncs); |
| exportFuncOps.resize(symbolCounts.exportFuncs); |
| internalFuncOps.resize(symbolCounts.internalFuncs); |
| rodataOps.resize(symbolCounts.rodatas); |
| for (auto &op : moduleOp.getBlock().getOperations()) { |
| if (auto funcOp = dyn_cast<IREE::VM::FuncOp>(op)) { |
| internalFuncOps[funcOp.ordinal().getValue().getLimitedValue()] = funcOp; |
| } else if (auto exportOp = dyn_cast<IREE::VM::ExportOp>(op)) { |
| exportFuncOps[exportOp.ordinal().getValue().getLimitedValue()] = exportOp; |
| } else if (auto importOp = dyn_cast<IREE::VM::ImportOp>(op)) { |
| importFuncOps[importOp.ordinal().getValue().getLimitedValue()] = importOp; |
| } else if (auto rodataOp = dyn_cast<IREE::VM::RodataOp>(op)) { |
| rodataOps[rodataOp.ordinal().getValue().getLimitedValue()] = rodataOp; |
| } |
| } |
| |
| // Serialize read-only data first so that it ends up at the end of the file. |
| // This is where large things like parameters live and we don't want that to |
| // get paged in until it is needed. |
| std::vector<Offset<Vector<uint8_t>>> rodataContentOffsets; |
| rodataContentOffsets.reserve(rodataOps.size()); |
| for (auto rodataOp : rodataOps) { |
| auto dataOffset = |
| serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb); |
| if (dataOffset.IsNull()) { |
| rodataOp.emitOpError() << "failed to encode"; |
| return {}; |
| } |
| rodataContentOffsets.push_back(dataOffset); |
| } |
| |
| // Find all types in the module to build the type table. |
| // Note that we don't emit it yet as we want to keep it near the top of the |
| // file (which, in FlatBuffers, is written last). |
| auto typeTable = buildTypeTable(moduleOp); |
| llvm::DenseMap<Type, int> typeOrdinalMap; |
| for (auto typeDef : llvm::enumerate(typeTable)) { |
| typeOrdinalMap[typeDef.value().type] = typeDef.index(); |
| } |
| |
| // Serialize function bytecode one at a time and then merge at the end. |
| std::vector<std::vector<uint8_t>> bytecodeDataParts; |
| std::vector<iree::vm::FunctionDescriptor> functionDescriptors; |
| bytecodeDataParts.reserve(internalFuncOps.size()); |
| functionDescriptors.reserve(internalFuncOps.size()); |
| size_t totalBytecodeLength = 0; |
| for (auto funcOp : internalFuncOps) { |
| auto encodedFunction = |
| BytecodeEncoder::encodeFunction(funcOp, typeOrdinalMap, symbolTable); |
| if (!encodedFunction) { |
| funcOp.emitError() << "failed to encode function bytecode"; |
| return {}; |
| } |
| functionDescriptors.push_back(iree::vm::FunctionDescriptor( |
| totalBytecodeLength, encodedFunction->bytecodeData.size(), |
| encodedFunction->i32RegisterCount, encodedFunction->refRegisterCount)); |
| totalBytecodeLength += encodedFunction->bytecodeData.size(); |
| bytecodeDataParts.push_back(std::move(encodedFunction->bytecodeData)); |
| } |
| // TODO(benvanik): compression? deduping? |
| uint8_t *bytecodeDataPtr = nullptr; |
| auto bytecodeDataOffset = fbb.CreateUninitializedVector<uint8_t>( |
| totalBytecodeLength, &bytecodeDataPtr); |
| size_t currentBytecodeOffset = 0; |
| for (const auto &it : llvm::enumerate(internalFuncOps)) { |
| int ordinal = it.index(); |
| auto data = std::move(bytecodeDataParts[ordinal]); |
| std::memcpy(bytecodeDataPtr + currentBytecodeOffset, data.data(), |
| data.size()); |
| currentBytecodeOffset += data.size(); |
| } |
| |
| // Serialize metadata that should be near the front of the file. |
| std::vector<Offset<iree::vm::RodataSegmentDef>> rodataSegmentOffsets; |
| rodataSegmentOffsets.reserve(rodataOps.size()); |
| for (auto rodataContentOffset : rodataContentOffsets) { |
| iree::vm::RodataSegmentDefBuilder rsd(fbb); |
| rsd.add_data(rodataContentOffset); |
| rodataSegmentOffsets.push_back(rsd.Finish()); |
| } |
| std::vector<Offset<iree::vm::RwdataSegmentDef>> rwdataSegmentOffsets; |
| std::vector<Offset<iree::vm::TypeDef>> typeOffsets; |
| typeOffsets.reserve(typeTable.size()); |
| for (auto &typeDef : typeTable) { |
| auto nameOffset = fbb.CreateString(typeDef.full_name); |
| iree::vm::TypeDefBuilder tdb(fbb); |
| tdb.add_full_name(nameOffset); |
| typeOffsets.push_back(tdb.Finish()); |
| } |
| std::vector<Offset<iree::vm::ImportFunctionDef>> importFuncOffsets; |
| importFuncOffsets.reserve(importFuncOps.size()); |
| for (auto importOp : importFuncOps) { |
| auto nameOffset = fbb.CreateString(importOp.getName().str()); |
| auto signatureOffset = |
| makeImportFunctionSignatureDef(importOp, typeOrdinalMap, fbb); |
| iree::vm::ImportFunctionDefBuilder ifd(fbb); |
| ifd.add_full_name(nameOffset); |
| ifd.add_signature(signatureOffset); |
| importFuncOffsets.push_back(ifd.Finish()); |
| } |
| std::vector<Offset<iree::vm::ExportFunctionDef>> exportFuncOffsets; |
| exportFuncOffsets.reserve(exportFuncOps.size()); |
| for (auto exportOp : exportFuncOps) { |
| auto nameOffset = fbb.CreateString(exportOp.export_name().str()); |
| auto funcOp = symbolTable.lookup<IREE::VM::FuncOp>(exportOp.function_ref()); |
| auto signatureOffset = |
| makeExportFunctionSignatureDef(exportOp, funcOp, typeOrdinalMap, fbb); |
| iree::vm::ExportFunctionDefBuilder efd(fbb); |
| efd.add_local_name(nameOffset); |
| efd.add_signature(signatureOffset); |
| efd.add_internal_ordinal(funcOp.ordinal().getValue().getLimitedValue()); |
| exportFuncOffsets.push_back(efd.Finish()); |
| } |
| std::vector<Offset<iree::vm::InternalFunctionDef>> internalFuncOffsets; |
| if (!targetOptions.stripSymbols) { |
| internalFuncOffsets.reserve(internalFuncOps.size()); |
| for (auto funcOp : internalFuncOps) { |
| auto nameOffset = fbb.CreateString(funcOp.getName().str()); |
| auto signatureOffset = |
| makeInternalFunctionSignatureDef(funcOp, typeOrdinalMap, fbb); |
| iree::vm::InternalFunctionDefBuilder ifd(fbb); |
| ifd.add_local_name(nameOffset); |
| ifd.add_signature(signatureOffset); |
| internalFuncOffsets.push_back(ifd.Finish()); |
| } |
| } |
| |
| auto functionDescriptorsOffset = |
| fbb.CreateVectorOfStructs(functionDescriptors); |
| auto rodataSegmentsOffset = createOptionalVector(rodataSegmentOffsets, fbb); |
| auto rwdataSegmentsOffset = createOptionalVector(rwdataSegmentOffsets, fbb); |
| auto internalFuncsOffset = fbb.CreateVector(internalFuncOffsets); |
| auto exportFuncsOffset = fbb.CreateVector(exportFuncOffsets); |
| auto importFuncsOffset = createOptionalVector(importFuncOffsets, fbb); |
| auto typesOffset = fbb.CreateVector(typeOffsets); |
| |
| Optional<Offset<iree::vm::ModuleStateDef>> moduleStateDef; |
| if (symbolCounts.globalBytes || symbolCounts.globalRefs) { |
| iree::vm::ModuleStateDefBuilder msd(fbb); |
| msd.add_global_bytes_capacity(symbolCounts.globalBytes); |
| msd.add_global_ref_count(symbolCounts.globalRefs); |
| moduleStateDef = msd.Finish(); |
| } |
| |
| auto nameOffset = fbb.CreateString( |
| moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name().str()); |
| |
| iree::vm::BytecodeModuleDefBuilder bmd(fbb); |
| bmd.add_name(nameOffset); |
| bmd.add_types(typesOffset); |
| if (importFuncsOffset) { |
| bmd.add_imported_functions(importFuncsOffset.getValue()); |
| } |
| bmd.add_exported_functions(exportFuncsOffset); |
| bmd.add_internal_functions(internalFuncsOffset); |
| if (moduleStateDef) { |
| bmd.add_module_state(moduleStateDef.getValue()); |
| } |
| if (rwdataSegmentsOffset) { |
| bmd.add_rwdata_segments(rwdataSegmentsOffset.getValue()); |
| } |
| if (rodataSegmentsOffset) { |
| bmd.add_rodata_segments(rodataSegmentsOffset.getValue()); |
| } |
| bmd.add_function_descriptors(functionDescriptorsOffset); |
| bmd.add_bytecode_data(bytecodeDataOffset); |
| return bmd.Finish(); |
| } |
| |
| LogicalResult translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, |
| BytecodeTargetOptions targetOptions, |
| llvm::raw_ostream &output) { |
| if (failed(canonicalizeModule(targetOptions, moduleOp))) { |
| return moduleOp.emitError() |
| << "failed to canonicalize vm.module to a serializable form"; |
| } |
| |
| if (targetOptions.outputFormat == BytecodeOutputFormat::kAnnotatedMlirText) { |
| // Run register allocation now and put the info in the IR so it's printed. |
| for (auto funcOp : moduleOp.getBlock().getOps<IREE::VM::FuncOp>()) { |
| if (!funcOp.empty()) { |
| if (failed(ValueLiveness::annotateIR(funcOp))) { |
| return funcOp.emitError() << "liveness analysis failed"; |
| } else if (failed(RegisterAllocation::annotateIR(funcOp))) { |
| return funcOp.emitError() << "register allocation failed"; |
| } |
| } |
| } |
| } |
| |
| if (targetOptions.outputFormat == BytecodeOutputFormat::kMlirText || |
| targetOptions.outputFormat == BytecodeOutputFormat::kAnnotatedMlirText) { |
| // Use the standard MLIR text printer. |
| moduleOp.getOperation()->print(output); |
| output << "\n"; |
| return success(); |
| } |
| |
| // NOTE: we order things so that all of the metadata is close to the start of |
| // the module header in memory. This ensures that when we map the file only |
| // the first few pages need to be accessed to get the metadata and the rest |
| // can be large bulk data. |
| FlatBufferBuilder fbb; |
| auto moduleDef = buildFlatBufferModule(targetOptions, moduleOp, fbb); |
| if (moduleDef.IsNull()) { |
| return moduleOp.emitError() |
| << "failed to build FlatBuffer BytecodeModuleDef"; |
| } |
| |
| iree::vm::FinishBytecodeModuleDefBuffer(fbb, moduleDef); |
| const uint8_t *flatbufferBytes = fbb.GetBufferPointer(); |
| size_t flatbufferByteSize = fbb.GetSize(); |
| |
| switch (targetOptions.outputFormat) { |
| case BytecodeOutputFormat::kFlatBufferBinary: |
| output.write(reinterpret_cast<const char *>(flatbufferBytes), |
| flatbufferByteSize); |
| break; |
| case BytecodeOutputFormat::kFlatBufferText: { |
| flatbuffers::ToStringVisitor toStringVisitor("\n", false, " ", false); |
| flatbuffers::IterateFlatBuffer(flatbufferBytes, |
| iree::vm::BytecodeModuleDefTypeTable(), |
| &toStringVisitor); |
| output << toStringVisitor.s << "\n"; |
| break; |
| } |
| default: |
| llvm_unreachable("unimplemented output format"); |
| } |
| |
| output.flush(); |
| return success(); |
| } |
| |
| LogicalResult translateModuleToBytecode(mlir::ModuleOp outerModuleOp, |
| BytecodeTargetOptions targetOptions, |
| llvm::raw_ostream &output) { |
| auto moduleOps = outerModuleOp.getOps<IREE::VM::ModuleOp>(); |
| if (moduleOps.empty()) { |
| return outerModuleOp.emitError() |
| << "outer module does not contain a vm.module op"; |
| } |
| return translateModuleToBytecode(*moduleOps.begin(), targetOptions, output); |
| } |
| |
| } // namespace VM |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |