| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #ifndef IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ |
| #define IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ |
| |
| #include <optional> |
| |
| #include "./binding.h" |
| #include "./status_utils.h" |
| #include "iree/base/api.h" |
| #include "iree/vm/api.h" |
| #include "iree/vm/bytecode/module.h" |
| |
| namespace iree { |
| namespace python { |
| |
| class FunctionAbi; |
| |
| //------------------------------------------------------------------------------ |
| // Retain/release bindings |
| //------------------------------------------------------------------------------ |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_buffer_t> { |
| static void Retain(iree_vm_buffer_t* b) { iree_vm_buffer_retain(b); } |
| static void Release(iree_vm_buffer_t* b) { iree_vm_buffer_release(b); } |
| }; |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_instance_t> { |
| static void Retain(iree_vm_instance_t* b) { iree_vm_instance_retain(b); } |
| static void Release(iree_vm_instance_t* b) { iree_vm_instance_release(b); } |
| }; |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_context_t> { |
| static void Retain(iree_vm_context_t* b) { iree_vm_context_retain(b); } |
| static void Release(iree_vm_context_t* b) { iree_vm_context_release(b); } |
| }; |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_list_t> { |
| static void Retain(iree_vm_list_t* b) { iree_vm_list_retain(b); } |
| static void Release(iree_vm_list_t* b) { iree_vm_list_release(b); } |
| }; |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_module_t> { |
| static void Retain(iree_vm_module_t* b) { iree_vm_module_retain(b); } |
| static void Release(iree_vm_module_t* b) { iree_vm_module_release(b); } |
| }; |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_invocation_t> { |
| static void Retain(iree_vm_invocation_t* b) { iree_vm_invocation_retain(b); } |
| static void Release(iree_vm_invocation_t* b) { |
| iree_vm_invocation_release(b); |
| } |
| }; |
| |
| template <> |
| struct ApiPtrAdapter<iree_vm_ref_t> { |
| static void Retain(iree_vm_ref_t* b) { |
| iree_vm_ref_t out_ref; |
| std::memset(&out_ref, 0, sizeof(out_ref)); |
| iree_vm_ref_retain(b, &out_ref); |
| } |
| static void Release(iree_vm_ref_t* b) { iree_vm_ref_release(b); } |
| }; |
| |
| //------------------------------------------------------------------------------ |
| // VmBuffer |
| //------------------------------------------------------------------------------ |
| |
| class VmBuffer : public ApiRefCounted<VmBuffer, iree_vm_buffer_t> {}; |
| |
| //------------------------------------------------------------------------------ |
| // VmVariantList |
| // TODO: Rename to VmList |
| //------------------------------------------------------------------------------ |
| |
| class VmVariantList : public ApiRefCounted<VmVariantList, iree_vm_list_t> { |
| public: |
| static VmVariantList Create(iree_host_size_t capacity) { |
| iree_vm_list_t* list; |
| CheckApiStatus( |
| iree_vm_list_create(iree_vm_make_undefined_type_def(), capacity, |
| iree_allocator_system(), &list), |
| "Error allocating variant list"); |
| return VmVariantList::StealFromRawPtr(list); |
| } |
| |
| iree_host_size_t size() const { return iree_vm_list_size(raw_ptr()); } |
| |
| void AppendNullRef() { |
| iree_vm_ref_t null_ref = {0}; |
| CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &null_ref), |
| "Error appending to list"); |
| } |
| |
| std::string DebugString() const; |
| void PushFloat(double fvalue); |
| void PushInt(int64_t ivalue); |
| void PushList(VmVariantList& other); |
| void PushRef(py::handle ref_or_object); |
| py::object GetAsList(int index); |
| py::object GetAsRef(int index); |
| py::object GetAsObject(int index, py::object clazz); |
| py::object GetVariant(int index); |
| py::object GetAsSerializedTraceValue(int index); |
| }; |
| |
| //------------------------------------------------------------------------------ |
| // VmInstance |
| //------------------------------------------------------------------------------ |
| |
| class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> { |
| public: |
| static VmInstance Create(); |
| }; |
| |
| //------------------------------------------------------------------------------ |
| // VmModule |
| //------------------------------------------------------------------------------ |
| |
| class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> { |
| public: |
| static VmModule ResolveModuleDependency(VmInstance* instance, |
| const std::string& name, |
| uint32_t minimum_version); |
| |
| static VmModule WrapBuffer(VmInstance* instance, py::object buffer_obj, |
| py::object destroy_callback, bool close_buffer); |
| static VmModule MMap(VmInstance* instance, std::string filepath, |
| py::object destroy_callback); |
| static VmModule CopyBuffer(VmInstance* instance, py::object buffer_obj); |
| static VmModule FromBuffer(VmInstance* instance, py::object buffer_obj, |
| bool warn_if_copy); |
| |
| std::optional<iree_vm_function_t> LookupFunction( |
| const std::string& name, iree_vm_function_linkage_t linkage); |
| |
| std::string name() const { |
| auto name_sv = iree_vm_module_name(raw_ptr()); |
| return std::string(name_sv.data, name_sv.size); |
| } |
| |
| py::object get_stashed_flatbuffer_blob() { return stashed_flatbuffer_blob; } |
| |
| private: |
| // If the module was created from a FlatBuffer blob, we stash it here. |
| // Note that this buffer will typically be captured here at the Python |
| // instance level and within the deallocator of the native vm module. |
| // Since this child field is destroyed first (before the base class wrapped |
| // vm module), we ensure that there are no dangling references when |
| // that deallocator is called. |
| py::object stashed_flatbuffer_blob = py::none(); |
| }; |
| |
| //------------------------------------------------------------------------------ |
| // VmContext |
| //------------------------------------------------------------------------------ |
| |
| class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> { |
| public: |
| // Creates a context, optionally with modules, which will make the context |
| // static, disallowing further module registration (and may be more |
| // efficient). |
| static VmContext Create(VmInstance* instance, |
| std::optional<std::vector<VmModule*>>& modules); |
| |
| // Registers additional modules. Only valid for non static contexts (i.e. |
| // those created without modules. |
| void RegisterModules(std::vector<VmModule*> modules); |
| |
| // Unique id for this context. |
| int context_id() const { return iree_vm_context_id(raw_ptr()); } |
| |
| // Synchronously invokes the given function. |
| void Invoke(iree_vm_function_t f, VmVariantList& inputs, |
| VmVariantList& outputs); |
| }; |
| |
| class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> { |
| }; |
| |
| //------------------------------------------------------------------------------ |
| // VmRef (represents a pointer to an arbitrary reference object). |
| //------------------------------------------------------------------------------ |
| |
| class VmRef { |
| public: |
| //---------------------------------------------------------------------------- |
| // Binds the reference protocol to a VmRefObject bound class. |
| // This defines three attributes: |
| // __iree_vm_type__() |
| // Gets the type from the object. |
| // [readonly property] __iree_vm_ref__ : |
| // Gets a VmRef from the object. |
| // __iree_vm_cast__(ref) : |
| // Dereferences the VmRef to the concrete type. Returns None on cast |
| // failure. |
| // |
| // In addition, a user attribute of "ref" will be added that is an alias of |
| // __iree_vm_ref__. |
| // |
| // An __eq__ method is added which returns true if the python objects refer |
| // to the same vm object. |
| // |
| // The BindRefProtocol() helper is used on a py::class_ defined for a |
| // reference object. It takes some of the C helper functions that are defined |
| // for each type and is generic. |
| //---------------------------------------------------------------------------- |
| static const char* const kTypeAttr; |
| static const char* const kRefAttr; |
| static const char* const kCastAttr; |
| |
| template <typename PyClass, typename TypeFunctor, typename RetainRefFunctor, |
| typename DerefFunctor, typename IsaFunctor> |
| static void BindRefProtocol(PyClass& cls, TypeFunctor type, |
| RetainRefFunctor retain_ref, DerefFunctor deref, |
| IsaFunctor isa) { |
| using WrapperType = typename PyClass::Type; |
| auto ref_lambda = [=](WrapperType& self) { |
| return VmRef::Steal(retain_ref(self.raw_ptr())); |
| }; |
| cls.def_static(VmRef::kTypeAttr, [=]() { return type(); }); |
| cls.def_prop_ro(VmRef::kRefAttr, ref_lambda); |
| cls.def_prop_ro("ref", ref_lambda); |
| cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) -> py::object { |
| if (!isa(ref.ref())) { |
| return py::none(); |
| } |
| return py::cast(WrapperType::BorrowFromRawPtr(deref(ref.ref())), |
| py::rv_policy::move); |
| }); |
| cls.def("__eq__", [](WrapperType& self, WrapperType& other) { |
| return self.raw_ptr() == other.raw_ptr(); |
| }); |
| cls.def("__eq__", |
| [](WrapperType& self, py::object& other) { return false; }); |
| } |
| |
| // Initializes a null ref. |
| VmRef() { std::memset(&ref_, 0, sizeof(ref_)); } |
| VmRef(VmRef&& other) : ref_(other.ref_) { |
| std::memset(&other.ref_, 0, sizeof(other.ref_)); |
| } |
| VmRef(const VmRef&) = delete; |
| VmRef& operator=(const VmRef&) = delete; |
| ~VmRef() { |
| if (ref_.ptr) { |
| iree_vm_ref_release(&ref_); |
| } |
| } |
| |
| // Creates a VmRef from an owned ref, taking the reference count. |
| static VmRef Steal(iree_vm_ref_t ref) { return VmRef(ref); } |
| |
| iree_vm_ref_t& ref() { return ref_; } |
| |
| py::object Deref(py::object ref_object_class, bool optional); |
| bool IsInstance(py::object ref_object_class); |
| |
| std::string ToString(); |
| |
| private: |
| // Initializes with an owned ref. |
| VmRef(iree_vm_ref_t ref) : ref_(ref) {} |
| iree_vm_ref_t ref_; |
| }; |
| |
| void SetupVmBindings(nanobind::module_ m); |
| |
| } // namespace python |
| } // namespace iree |
| |
| #endif // IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ |