blob: a969159c8322a9e52757c5a03284e0121e9a17ff [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/python/pyiree/vm.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "bindings/python/pyiree/function_abi.h"
#include "bindings/python/pyiree/status_utils.h"
#include "iree/base/api.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/invocation.h"
#include "iree/vm/module.h"
namespace iree {
namespace python {
namespace {
VmModule CreateHalModule(HalDevice* device) {
iree_vm_module_t* module;
CheckApiStatus(
iree_hal_module_create(device->raw_ptr(), IREE_ALLOCATOR_SYSTEM, &module),
"Error creating hal module");
return VmModule::CreateRetained(module);
}
} // namespace
//------------------------------------------------------------------------------
// VmInstance
//------------------------------------------------------------------------------
VmInstance VmInstance::Create() {
iree_vm_instance_t* instance;
auto status = iree_vm_instance_create(IREE_ALLOCATOR_SYSTEM, &instance);
CheckApiStatus(status, "Error creating instance");
return VmInstance::CreateRetained(instance);
}
//------------------------------------------------------------------------------
// VmContext
//------------------------------------------------------------------------------
VmContext VmContext::Create(VmInstance* instance,
absl::optional<std::vector<VmModule*>> modules) {
iree_vm_context_t* context;
if (!modules) {
// Simple create with open allowed modules.
auto status = iree_vm_context_create(instance->raw_ptr(),
IREE_ALLOCATOR_SYSTEM, &context);
CheckApiStatus(status, "Error creating vm context");
} else {
// Closed set of modules.
absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
module_handles.resize(modules->size());
for (size_t i = 0, e = module_handles.size(); i < e; ++i) {
module_handles[i] = (*modules)[i]->raw_ptr();
}
auto status = iree_vm_context_create_with_modules(
instance->raw_ptr(), module_handles.data(), module_handles.size(),
IREE_ALLOCATOR_SYSTEM, &context);
CheckApiStatus(status, "Error creating vm context with modules");
}
CHECK(context);
return VmContext::CreateRetained(context);
}
void VmContext::RegisterModules(std::vector<VmModule*> modules) {
absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
module_handles.resize(modules.size());
for (size_t i = 0, e = module_handles.size(); i < e; ++i) {
module_handles[i] = modules[i]->raw_ptr();
}
auto status = iree_vm_context_register_modules(raw_ptr(), &module_handles[0],
module_handles.size());
CheckApiStatus(status, "Error registering modules");
}
std::unique_ptr<FunctionAbi> VmContext::CreateFunctionAbi(
HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
iree_vm_function_t f) {
// Resolve attrs.
absl::InlinedVector<std::pair<iree_string_view_t, iree_string_view_t>, 4>
attrs;
for (int i = 0;; ++i) {
attrs.push_back({});
auto status = iree_vm_get_function_reflection_attr(
f, i, &attrs.back().first, &attrs.back().second);
if (iree_status_is_not_found(status)) {
attrs.pop_back();
break;
}
CheckApiStatus(status, "Error getting reflection attr");
}
auto attr_lookup =
[&attrs](absl::string_view key) -> absl::optional<absl::string_view> {
for (const auto& attr : attrs) {
absl::string_view found_key(attr.first.data, attr.first.size);
absl::string_view found_value(attr.second.data, attr.second.size);
if (found_key == key) return found_value;
}
return absl::nullopt;
};
return FunctionAbi::Create(device, std::move(host_type_factory), attr_lookup);
}
void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs,
VmVariantList& outputs) {
CheckApiStatus(iree_vm_invoke(raw_ptr(), f, nullptr, inputs.raw_ptr(),
outputs.raw_ptr(), IREE_ALLOCATOR_SYSTEM),
"Error invoking function");
}
//------------------------------------------------------------------------------
// VmModule
//------------------------------------------------------------------------------
VmModule VmModule::FromFlatbufferBlob(
std::shared_ptr<OpaqueBlob> flatbuffer_blob) {
iree_vm_module_t* module;
auto deallocator = OpaqueBlob::CreateDeallocator(flatbuffer_blob);
auto status = iree_vm_bytecode_module_create(
{static_cast<const uint8_t*>(flatbuffer_blob->data()),
flatbuffer_blob->size()},
deallocator, IREE_ALLOCATOR_SYSTEM, &module);
if (!iree_status_is_ok(status)) {
deallocator.free(deallocator.self, nullptr);
}
CheckApiStatus(status, "Error creating vm module from flatbuffer");
return VmModule::CreateRetained(module);
}
absl::optional<iree_vm_function_t> VmModule::LookupFunction(
const std::string& name, iree_vm_function_linkage_t linkage) {
iree_vm_function_t f;
auto status = iree_vm_module_lookup_function_by_name(
raw_ptr(), linkage, {name.data(), name.size()}, &f);
if (iree_status_is_not_found(status)) {
return absl::nullopt;
}
CheckApiStatus(status, "Error looking up function");
return f;
}
//------------------------------------------------------------------------------
// VmVariantList
//------------------------------------------------------------------------------
std::string VmVariantList::DebugString() const {
// The variant list API requires mutability, so we const cast to it internally
// so we can maintain a const DebugString() for callers.
auto mutable_this = const_cast<VmVariantList*>(this);
std::string s;
absl::StrAppend(&s, "<VmVariantList(", size(), "): [");
for (iree_host_size_t i = 0, e = size(); i < e; ++i) {
iree_vm_variant_t* variant =
iree_vm_variant_list_get(mutable_this->raw_ptr(), i);
if (i > 0) absl::StrAppend(&s, ", ");
if (IREE_VM_VARIANT_IS_VALUE(variant)) {
absl::StrAppend(&s, variant->i32);
} else if (IREE_VM_VARIANT_IS_REF(variant)) {
// Pretty print a subset of ABI impacting known types.
if (iree_hal_buffer_isa(&variant->ref)) {
auto* hal_buffer = iree_hal_buffer_deref(&variant->ref);
assert(hal_buffer);
absl::StrAppend(&s, "HalBuffer(",
iree_hal_buffer_byte_length(hal_buffer), ")");
} else {
absl::StrAppend(&s, "Unknown(", variant->ref_type, ")");
}
} else {
absl::StrAppend(&s, "None");
}
}
absl::StrAppend(&s, "]>");
return s;
}
void SetupVmBindings(pybind11::module m) {
IREE_CHECK_OK(iree_vm_register_builtin_types());
IREE_CHECK_OK(iree_hal_module_register_types());
// Built-in module creation.
m.def("create_hal_module", &CreateHalModule);
py::enum_<iree_vm_function_linkage_t>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
.value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
.export_values();
// Mutation and inspection of the variant list is mostly opaque to python.
py::class_<VmVariantList>(m, "VmVariantList")
.def(py::init(&VmVariantList::Create))
.def_property_readonly("size", &VmVariantList::size)
.def("__repr__", &VmVariantList::DebugString);
py::class_<iree_vm_function_t>(m, "VmFunction")
.def_readonly("linkage", &iree_vm_function_t::linkage)
.def_readonly("ordinal", &iree_vm_function_t::ordinal);
py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));
py::class_<VmContext>(m, "VmContext")
.def(py::init(&VmContext::Create), py::arg("instance"),
py::arg("modules") = absl::optional<std::vector<VmModule*>>())
.def("register_modules", &VmContext::RegisterModules)
.def_property_readonly("context_id", &VmContext::context_id)
.def("create_function_abi", &VmContext::CreateFunctionAbi,
py::arg("device"), py::arg("host_type_factory"), py::arg("f"))
.def("invoke", &VmContext::Invoke);
py::class_<VmModule>(m, "VmModule")
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
.def_property_readonly("name", &VmModule::name)
.def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT);
}
} // namespace python
} // namespace iree