blob: 0ecbcceda6315187ad9f23627a0c1eb0f53e1b49 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_
#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_
#include "absl/types/optional.h"
#include "bindings/python/pyiree/binding.h"
#include "bindings/python/pyiree/host_types.h"
#include "iree/base/api.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/variant_list.h"
namespace iree {
namespace python {
class FunctionAbi;
//------------------------------------------------------------------------------
// Retain/release bindings
//------------------------------------------------------------------------------
// Retain/release hooks for iree_vm_instance_t, forwarding to the IREE C API
// (presumably consumed by ApiRefCounted from binding.h).
template <>
struct ApiPtrAdapter<iree_vm_instance_t> {
  static void Retain(iree_vm_instance_t* instance) {
    iree_vm_instance_retain(instance);
  }
  static void Release(iree_vm_instance_t* instance) {
    iree_vm_instance_release(instance);
  }
};
// Retain/release hooks for iree_vm_context_t, forwarding to the IREE C API
// (presumably consumed by ApiRefCounted from binding.h).
template <>
struct ApiPtrAdapter<iree_vm_context_t> {
  static void Retain(iree_vm_context_t* context) {
    iree_vm_context_retain(context);
  }
  static void Release(iree_vm_context_t* context) {
    iree_vm_context_release(context);
  }
};
// Retain/release hooks for iree_vm_module_t, forwarding to the IREE C API
// (presumably consumed by ApiRefCounted from binding.h).
template <>
struct ApiPtrAdapter<iree_vm_module_t> {
  static void Retain(iree_vm_module_t* module) {
    iree_vm_module_retain(module);
  }
  static void Release(iree_vm_module_t* module) {
    iree_vm_module_release(module);
  }
};
// Retain/release hooks for iree_vm_invocation_t, forwarding to the IREE C API
// (presumably consumed by ApiRefCounted from binding.h).
template <>
struct ApiPtrAdapter<iree_vm_invocation_t> {
  static void Retain(iree_vm_invocation_t* invocation) {
    iree_vm_invocation_retain(invocation);
  }
  static void Release(iree_vm_invocation_t* invocation) {
    iree_vm_invocation_release(invocation);
  }
};
//------------------------------------------------------------------------------
// VmVariantList
//------------------------------------------------------------------------------
class VmVariantList {
public:
VmVariantList() : list_(nullptr) {}
~VmVariantList() {
if (list_) {
CheckApiStatus(iree_vm_variant_list_free(list_), "Error freeing list");
}
}
VmVariantList(VmVariantList&& other) {
list_ = other.list_;
other.list_ = nullptr;
}
VmVariantList& operator=(const VmVariantList&) = delete;
VmVariantList(const VmVariantList&) = delete;
static VmVariantList Create(iree_host_size_t capacity) {
iree_vm_variant_list_t* list;
CheckApiStatus(
iree_vm_variant_list_alloc(capacity, IREE_ALLOCATOR_SYSTEM, &list),
"Error allocating variant list");
return VmVariantList(list);
}
iree_host_size_t size() const { return iree_vm_variant_list_size(list_); }
iree_vm_variant_list_t* raw_ptr() { return list_; }
const iree_vm_variant_list_t* raw_ptr() const { return list_; }
void AppendNullRef() {
CheckApiStatus(iree_vm_variant_list_append_null_ref(raw_ptr()),
"Error appending to list");
}
std::string DebugString() const;
private:
VmVariantList(iree_vm_variant_list_t* list) : list_(list) {}
iree_vm_variant_list_t* list_;
};
//------------------------------------------------------------------------------
// ApiRefCounted types
//------------------------------------------------------------------------------
// Wrapper over iree_vm_instance_t built on ApiRefCounted.
class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> {
public:
// Creates a new VM instance (defined out of line).
static VmInstance Create();
};
// Wrapper over iree_vm_module_t built on ApiRefCounted.
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
 public:
  // Returns a copy of the module's name as an owned std::string.
  std::string name() const {
    const auto module_name = iree_vm_module_name(raw_ptr());
    return std::string(module_name.data, module_name.size);
  }

  // Creates a module from the serialized flatbuffer held in
  // |flatbuffer_blob| (defined out of line).
  static VmModule FromFlatbufferBlob(
      std::shared_ptr<OpaqueBlob> flatbuffer_blob);

  // Looks up the function named |name| with the requested |linkage|.
  // Presumably returns an empty optional when no such function exists —
  // confirm against the out-of-line definition.
  absl::optional<iree_vm_function_t> LookupFunction(
      const std::string& name, iree_vm_function_linkage_t linkage);
};
// Wrapper over iree_vm_context_t built on ApiRefCounted.
class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
 public:
  // Creates a context within |instance|. If |modules| is provided, the
  // context is created with those modules and becomes static: further module
  // registration is disallowed (and the context may be more efficient).
  static VmContext Create(VmInstance* instance,
                          absl::optional<std::vector<VmModule*>> modules);

  // Registers additional modules. Only valid for non-static contexts (i.e.,
  // those created without modules).
  void RegisterModules(std::vector<VmModule*> modules);

  // Unique id for this context.
  // NOTE(review): the value from iree_vm_context_id() is converted to int
  // here; confirm the underlying id type cannot be truncated by this.
  int context_id() const { return iree_vm_context_id(raw_ptr()); }

  // Synchronously invokes the function |f| with |inputs| and |outputs|
  // (presumably arguments and a destination for results — see definition).
  void Invoke(iree_vm_function_t f, VmVariantList& inputs,
              VmVariantList& outputs);

  // Creates a function ABI suitable for marshalling the inputs/results of
  // function |f| on |device|, using |host_type_factory| for host-side types.
  std::unique_ptr<FunctionAbi> CreateFunctionAbi(
      HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
      iree_vm_function_t f);
};
// Wrapper over iree_vm_invocation_t built on ApiRefCounted; it currently
// declares no members of its own.
class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> {
};
// Sets up the VM-related Python bindings on pybind11 module |m| (defined out
// of line; presumably exposes the wrapper classes declared in this header).
void SetupVmBindings(pybind11::module m);
} // namespace python
} // namespace iree
#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_