| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #ifndef IREE_VM_MODULE_ABI_CC_H_ |
| #define IREE_VM_MODULE_ABI_CC_H_ |
| |
| #include <cstring> |
| #include <memory> |
| |
| #include "absl/strings/str_cat.h" |
| #include "absl/types/span.h" |
| #include "iree/base/api.h" |
| #include "iree/base/api_util.h" |
| #include "iree/base/status.h" |
| #include "iree/vm/module.h" |
| #include "iree/vm/module_abi_packing.h" |
| #include "iree/vm/stack.h" |
| |
| #ifndef __cplusplus |
| #error "This header is meant for use with C++ module implementations." |
| #endif // __cplusplus |
| |
| namespace iree { |
| namespace vm { |
| |
| // A native module as exported to the VM dynamic module linking API. |
| // This allows easy wrapping of C++ module implementations and removes a |
| // majority of the boilerplate required with marshaling args/results out/in of |
| // the VM via the ABI. |
| // |
| // Functions are defined on the State type as member functions returning either |
| // Status or StatusOr. Arguments are passed as primitive types (int32_t), |
| // wrapped ref objects (vm::ref<my_type_t>&), or some nesting of std::array, |
| // std::tuple, and absl::Span to match fixed-length arrays of the same type, |
| // tuples of mixed types, or dynamic arrays (variadic arguments). Results may be |
| // returned as either their type or an std::tuple/std::array of types. |
| // |
| // Usage: |
| // // Per-context module state that must only be thread-compatible. |
| // // Define each exported function as a member function of the state type. |
| // struct MyState final { |
| // StatusOr<std::tuple<int32_t, int32_t>> MyMethod1(vm::ref<my_type_t> t); |
| // }; |
| // |
| // // Table of functions mapped to their name in the IR. |
| // static const vm::NativeFunction<MyState> kMyFunctions[] = { |
| // vm::MakeNativeFunction("my_method_1", &MyState::MyMethod1), |
| // }; |
| // |
| // // The outer module wrapper shared across contexts. |
| // // Must be thread-safe. |
| // struct MyModule : public NativeModule<MyState> { |
| // StatusOr<std::unique_ptr<MyState>> CreateState(iree_allocator_t) { |
| // // You could pass in thread-safe shared resources to MyState. |
| // return std::make_unique<MyState>(); |
| // } |
| // }; |
| // |
| // // Creates the module and exposes it as a C interface. |
| // // Ownership transfers to the caller. |
| // iree_vm_module_t* create_my_module(iree_allocator_t allocator) { |
| // return std::make_unique<MyModule>("my_module", allocator, |
| // absl::MakeConstSpan(kMyFunctions)).release()->interface(); |
| // } |
| template <typename State> |
| class NativeModule { |
| public: |
| NativeModule(const char* name, iree_allocator_t allocator, |
| absl::Span<const NativeFunction<State>> dispatch_table) |
| : name_(name), allocator_(allocator), dispatch_table_(dispatch_table) { |
| CHECK_OK(FromApiStatus(iree_vm_module_init(&interface_, this), IREE_LOC)); |
| interface_.destroy = NativeModule::ModuleDestroy; |
| interface_.name = NativeModule::ModuleName; |
| interface_.signature = NativeModule::ModuleSignature; |
| interface_.get_function = NativeModule::ModuleGetFunction; |
| interface_.lookup_function = NativeModule::ModuleLookupFunction; |
| interface_.alloc_state = NativeModule::ModuleAllocState; |
| interface_.free_state = NativeModule::ModuleFreeState; |
| interface_.resolve_import = NativeModule::ModuleResolveImport; |
| interface_.execute = NativeModule::ModuleExecute; |
| } |
| |
| virtual ~NativeModule() = default; |
| |
| // C API module interface bound to this NativeModule instance. |
| iree_vm_module_t* interface() { return &interface_; } |
| |
| protected: |
| // Creates a new per-context module State holder. |
| virtual StatusOr<std::unique_ptr<State>> CreateState( |
| iree_allocator_t allocator) = 0; |
| |
| private: |
| static NativeModule* FromModulePointer(void* self) { |
| return reinterpret_cast<NativeModule*>(self); |
| } |
| static State* FromStatePointer(void* self) { |
| return reinterpret_cast<State*>(self); |
| } |
| |
| static iree_status_t ModuleDestroy(void* self) { |
| delete FromModulePointer(self); |
| return IREE_STATUS_OK; |
| } |
| |
| static iree_string_view_t ModuleName(void* self) { |
| auto* module = FromModulePointer(self); |
| return iree_make_cstring_view(module->name_); |
| } |
| |
| static iree_vm_module_signature_t ModuleSignature(void* self) { |
| auto* module = FromModulePointer(self); |
| iree_vm_module_signature_t signature = {0}; |
| signature.import_function_count = 0; |
| signature.export_function_count = module->dispatch_table_.size(); |
| signature.internal_function_count = 0; |
| return signature; |
| } |
| |
| static iree_status_t ModuleGetFunction( |
| void* self, iree_vm_function_linkage_t linkage, int32_t ordinal, |
| iree_vm_function_t* out_function, iree_string_view_t* out_name, |
| iree_vm_function_signature_t* out_signature) { |
| if (out_function) { |
| std::memset(out_function, 0, sizeof(*out_function)); |
| } |
| if (out_name) { |
| out_name->data = nullptr; |
| out_name->size = 0; |
| } |
| if (out_signature) { |
| std::memset(out_signature, 0, sizeof(*out_signature)); |
| } |
| auto* module = FromModulePointer(self); |
| if (ordinal < 0 || ordinal > module->dispatch_table_.size()) { |
| return IREE_STATUS_INVALID_ARGUMENT; |
| } |
| if (out_function) { |
| out_function->module = module->interface(); |
| out_function->linkage = IREE_VM_FUNCTION_LINKAGE_EXPORT; |
| out_function->ordinal = ordinal; |
| } |
| if (out_name) { |
| const auto& dispatch_function = module->dispatch_table_[ordinal]; |
| *out_name = iree_make_cstring_view(dispatch_function.name); |
| } |
| return IREE_STATUS_OK; |
| } |
| |
| static iree_status_t ModuleLookupFunction(void* self, |
| iree_vm_function_linkage_t linkage, |
| iree_string_view_t name, |
| iree_vm_function_t* out_function) { |
| if (!out_function) return IREE_STATUS_INVALID_ARGUMENT; |
| std::memset(out_function, 0, sizeof(*out_function)); |
| if (!name.data || !name.size) return IREE_STATUS_INVALID_ARGUMENT; |
| |
| auto* module = FromModulePointer(self); |
| out_function->module = module->interface(); |
| out_function->linkage = IREE_VM_FUNCTION_LINKAGE_EXPORT; |
| for (int i = 0; i < module->dispatch_table_.size(); ++i) { |
| if (iree_string_view_compare( |
| name, iree_make_cstring_view(module->dispatch_table_[i].name)) == |
| 0) { |
| out_function->ordinal = i; |
| return IREE_STATUS_OK; |
| } |
| } |
| return IREE_STATUS_NOT_FOUND; |
| } |
| |
| static iree_status_t ModuleAllocState( |
| void* self, iree_allocator_t allocator, |
| iree_vm_module_state_t** out_module_state) { |
| if (!out_module_state) return IREE_STATUS_INVALID_ARGUMENT; |
| *out_module_state = nullptr; |
| |
| auto* module = FromModulePointer(self); |
| auto module_state_or = module->CreateState(allocator); |
| if (!module_state_or.ok()) { |
| return ToApiStatus(module_state_or.status()); |
| } |
| auto module_state = std::move(module_state_or).ValueOrDie(); |
| |
| *out_module_state = |
| reinterpret_cast<iree_vm_module_state_t*>(module_state.release()); |
| return IREE_STATUS_OK; |
| } |
| |
| static iree_status_t ModuleFreeState(void* self, |
| iree_vm_module_state_t* module_state) { |
| if (!module_state) return IREE_STATUS_INVALID_ARGUMENT; |
| delete FromStatePointer(module_state); |
| return IREE_STATUS_OK; |
| } |
| |
| static iree_status_t ModuleResolveImport(void* self, |
| iree_vm_module_state_t* module_state, |
| int32_t ordinal, |
| iree_vm_function_t function) { |
| // C++ API does not yet support imports. |
| return IREE_STATUS_FAILED_PRECONDITION; |
| } |
| |
| static iree_status_t ModuleExecute(void* self, iree_vm_stack_t* stack, |
| iree_vm_stack_frame_t* frame, |
| iree_vm_execution_result_t* out_result) { |
| if (!out_result) return IREE_STATUS_INVALID_ARGUMENT; |
| std::memset(out_result, 0, sizeof(*out_result)); |
| if (!stack || !frame) return IREE_STATUS_INVALID_ARGUMENT; |
| int32_t ordinal = frame->function.ordinal; |
| auto* module = FromModulePointer(self); |
| if (ordinal < 0 || ordinal > module->dispatch_table_.size()) { |
| return IREE_STATUS_INVALID_ARGUMENT; |
| } |
| const auto& info = module->dispatch_table_[ordinal]; |
| auto* state = FromStatePointer(frame->module_state); |
| auto status = info.call(info.ptr, state, stack, frame, out_result); |
| if (!status.ok()) { |
| status = iree::Annotate( |
| status, |
| absl::StrCat("while executing ", module->name_, ".", info.name)); |
| return ToApiStatus(status); |
| } |
| return IREE_STATUS_OK; |
| } |
| |
| const char* name_; |
| const iree_allocator_t allocator_; |
| iree_vm_module_t interface_; |
| |
| const absl::Span<const NativeFunction<State>> dispatch_table_; |
| }; |
| |
| } // namespace vm |
| } // namespace iree |
| |
| #endif // IREE_VM_MODULE_ABI_CC_H_ |