First step of splitting python bindings.
* As discussed on mailing list, introduces pyiree.{compiler,rt,tf.support}
* Separate module shared libraries for compiler and rt
* Does not yet fully separate the TensorFlow compiler from the generic compiler (follow-on)
* Fixes namespace fallout
* This will almost certainly require patching into various platforms and verifying. Please pre-review.
PiperOrigin-RevId: 288361204
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
new file mode 100644
index 0000000..56d2c1d
--- /dev/null
+++ b/bindings/python/pyiree/rt/BUILD
@@ -0,0 +1,291 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//bindings/python/pyiree:build_defs.bzl",
+ "NUMPY_DEPS",
+ "PLATFORM_VULKAN_DEPS",
+ "PYBIND_COPTS",
+ "PYBIND_EXTENSION_COPTS",
+ "PYBIND_FEATURES",
+ "iree_py_extension",
+ "pybind_cc_library",
+)
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# HAL driver modules linked into the :binding extension below, on top of the
+# platform Vulkan deps provided by build_defs.bzl.
+DRIVER_DEPS = PLATFORM_VULKAN_DEPS + [
+    "//iree/hal/interpreter:interpreter_driver_module",
+    "//iree/hal/vulkan:vulkan_driver_module",
+]
+
+# Public `pyiree.rt` python package. system_api.py is owned by :system_api and
+# pulled in as a dependency rather than listed here a second time, so every
+# source file belongs to exactly one py_library.
+py_library(
+    name = "rt",
+    srcs = ["__init__.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":binding",
+        ":system_api",
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+    ],
+)
+
+# Native CPython extension for the runtime, imported by __init__.py via
+# `from . import binding`. On Windows, export.def exports its PyInit_binding
+# entry point.
+iree_py_extension(
+    name = "binding",
+    srcs = [
+        "initialize_module.cc",
+    ],
+    copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
+    features = PYBIND_FEATURES,
+    linkstatic = 1,
+    # TODO(b/145815906) Get this running in OSS CI.
+    tags = ["nokokoro"],
+    win_def_file = "export.def",
+    deps = DRIVER_DEPS + [
+        ":rt_library",
+        "//bindings/python/pyiree/common",
+        "//iree/base:initializer",
+        "//iree/base:tracing",
+        "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+    ],
+)
+
+# C++ implementation of the runtime bindings (FunctionAbi, HAL, host types
+# and VM), linked into the :binding extension.
+pybind_cc_library(
+    name = "rt_library",
+    srcs = [
+        "function_abi.cc",
+        "hal.cc",
+        "host_types.cc",
+        "vm.cc",
+    ],
+    hdrs = [
+        "function_abi.h",
+        "hal.h",
+        "host_types.h",
+        "vm.h",
+    ],
+    deps = [
+        "//bindings/python/pyiree/common",
+        "//iree/base:api",
+        "//iree/base:signature_mangle",
+        "//iree/hal:api",
+        "//iree/modules/hal",
+        "//iree/vm",
+        "//iree/vm:bytecode_module",
+        "//iree/vm:invocation",
+        "//iree/vm:module",
+        "//iree/vm:ref",
+        "//iree/vm:variant_list",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+# Pure-python system API layered on :binding; its public names are
+# re-exported from pyiree.rt via `from .system_api import *`.
+py_library(
+    name = "system_api",
+    srcs = ["system_api.py"],
+    srcs_version = "PY3",
+    # TODO(b/145815906) Get this running in OSS CI.
+    tags = ["nokokoro"],
+    deps = [
+        ":binding",
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+    ],
+)
+
+py_test(
+    name = "function_abi_test",
+    srcs = ["function_abi_test.py"],
+    python_version = "PY3",
+    # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+    # TODO(b/145815906) Get this running in OSS CI.
+    tags = [
+        "noga",
+        "nokokoro",
+        "notap",
+    ],
+    deps = NUMPY_DEPS + [
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+        "@absl_py//absl/testing:absltest",
+        "//bindings/python/pyiree/rt",
+    ],
+)
+
+# Test target for hal_test.py (presumably exercising the HAL bindings from
+# hal.cc; verify against the test source).
+py_test(
+    name = "hal_test",
+    srcs = ["hal_test.py"],
+    python_version = "PY3",
+    # TODO(b/145815906) Get this running in OSS CI.
+    tags = ["nokokoro"],
+    deps = NUMPY_DEPS + [
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+        "@absl_py//absl/testing:absltest",
+        "//bindings/python/pyiree/rt",
+    ],
+)
+
+# Test target for system_api_test.py; it depends on both the compiler and
+# runtime packages.
+py_test(
+    name = "system_api_test",
+    srcs = ["system_api_test.py"],
+    python_version = "PY3",
+    tags = [
+        # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+        "notap",
+        # TODO(b/145815906) Get this running in OSS CI.
+        "noga",
+        "nokokoro",
+    ],
+    deps = NUMPY_DEPS + [
+        ":system_api",
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+        "@absl_py//absl/testing:absltest",
+        "//bindings/python/pyiree/compiler",
+        "//bindings/python/pyiree/rt",
+    ],
+)
+
+py_test(
+    name = "vm_test",
+    srcs = ["vm_test.py"],
+    python_version = "PY3",
+    # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+    # TODO(b/145815906) Get this running in OSS CI.
+    tags = [
+        "noga",
+        "nokokoro",
+        "notap",
+    ],
+    deps = NUMPY_DEPS + [
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+        "@absl_py//absl/testing:absltest",
+        "//bindings/python/pyiree/compiler",
+        "//bindings/python/pyiree/rt",
+    ],
+)
+
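+# NOTE(review): The commented-out targets below appear to be the pre-split,
+# monolithic pyiree definitions kept for reference; delete them once the
+# split into pyiree.{compiler,rt} is complete.
+# py_library(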
+# name = "pyiree",
+# srcs = [
+# "__init__.py",
+# ],
+# srcs_version = "PY3",
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# deps = [
+# ":binding",
+# ":compiler",
+# ":system_api",
+# "//bindings/python:pathsetup", # build_cleaner: keep
+# ] + select({
+# "//iree:enable_tensorflow": [
+# "//bindings/python/pyiree/tf_interop:test_utils",
+# "//bindings/python/pyiree/tf_interop:tf_test_driver",
+# ],
+# "//conditions:default": [
+# ],
+# }),
+# )
+#
+#
+# cc_library(
+# name = "base",
+# srcs = [
+# "compiler.cc",
+# "function_abi.cc",
+# "hal.cc",
+# "host_types.cc",
+# "status_utils.cc",
+# "vm.cc",
+# ],
+# hdrs = [
+# "binding.h",
+# "compiler.h",
+# "function_abi.h",
+# "hal.h",
+# "host_types.h",
+# "status_utils.h",
+# "vm.h",
+# ],
+# copts = DEFAULT_COPTS,
+# features = DEFAULT_FEATURES,
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# deps = [
+# "@com_google_absl//absl/container:inlined_vector",
+# "@com_google_absl//absl/memory",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/types:optional",
+# "@com_google_absl//absl/types:span",
+# "@com_google_absl//absl/types:variant",
+# "//iree/compiler/Utils",
+# "//iree/modules/hal",
+# "//iree/vm",
+# "//iree/vm:bytecode_module",
+# "//iree/vm:invocation",
+# "//iree/vm:module",
+# "//iree/vm:ref",
+# "//iree/vm:variant_list",
+# "//third_party/llvm/llvm/projects/google_mlir:IR",
+# "//iree/base:api",
+# "//iree/base:status",
+# "//iree/base:signature_mangle",
+# "//iree/hal:api",
+# "//iree/vm:api",
+# "@llvm-project//llvm:support",
+# "//third_party/llvm/llvm/projects/google_mlir:Parser",
+# "//third_party/llvm/llvm/projects/google_mlir:Pass",
+# "@iree_pybind11//:pybind11",
+# ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS,
+# )
+#
+# iree_py_extension(
+# name = "binding",
+# srcs = [
+# "initialize_module.cc",
+# ],
+# copts = DEFAULT_COPTS,
+# features = DEFAULT_FEATURES,
+# linkstatic = 1,
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# win_def_file = "export.def",
+# deps = [
+# ":base",
+# "//bindings/python/pyiree/tf_interop",
+# "//iree/base:initializer",
+# "//iree/base:tracing",
+# "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+# ],
+# )
+#
+# py_test(
+# name = "compiler_test",
+# srcs = ["compiler_test.py"],
+# python_version = "PY3",
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# deps = [
+# "//bindings/python:pathsetup", # build_cleaner: keep
+# "@absl_py//absl/testing:absltest",
+# "//bindings/python/pyiree",
+# ],
+# )
+#
diff --git a/bindings/python/pyiree/rt/__init__.py b/bindings/python/pyiree/rt/__init__.py
new file mode 100644
index 0000000..b3c5533
--- /dev/null
+++ b/bindings/python/pyiree/rt/__init__.py
@@ -0,0 +1,33 @@
+"""Module init for the python bindings."""
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-multiple-import
+# pylint: disable=g-bad-import-order
+# pylint: disable=wildcard-import
+
+from . import binding
+
+# Pull some of the native symbols into the public API.
+# FunctionAbi imports
+from .binding import FunctionAbi
+# Hal imports
+from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, MemoryAccess, MemoryType, Shape
+# HostTypeFactory imports
+from .binding import HostTypeFactory
+# Vm imports
+from .binding import create_hal_module, Linkage, VmVariantList, VmFunction, VmInstance, VmContext, VmModule
+# SystemApi
+from .system_api import *
diff --git a/bindings/python/pyiree/rt/export.def b/bindings/python/pyiree/rt/export.def
new file mode 100644
index 0000000..85fd7ca
--- /dev/null
+++ b/bindings/python/pyiree/rt/export.def
@@ -0,0 +1,17 @@
+;; Copyright 2019 Google LLC
+;;
+;; Licensed under the Apache License, Version 2.0 (the "License");
+;; you may not use this file except in compliance with the License.
+;; You may obtain a copy of the License at
+;;
+;; https://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+
+;; CPython imports this extension as module `binding` and calls its
+;; PyInit_binding initialization function, so that symbol must be exported.
+LIBRARY BINDING
+EXPORTS
+ PyInit_binding @1
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
new file mode 100644
index 0000000..1a3efd0
--- /dev/null
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -0,0 +1,418 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/rt/function_abi.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "bindings/python/pyiree/rt/vm.h"
+#include "iree/base/api.h"
+#include "iree/base/signature_mangle.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/hal_module.h"
+#include "iree/vm/ref.h"
+#include "iree/vm/variant_list.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+// Builds a FunctionAbi from a Python-supplied list of (key, value) attribute
+// pairs. Meant mostly for tests: real callers normally introspect a
+// function's attributes directly in C++. Keys are resolved by a linear scan
+// that returns the first matching value; this is not optimized for speed.
+std::unique_ptr<FunctionAbi> PyCreateAbi(
+    HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
+    std::vector<std::pair<std::string, std::string>> attrs) {
+  auto find_attr = [&attrs](absl::string_view wanted)
+      -> absl::optional<absl::string_view> {
+    for (const std::pair<std::string, std::string>& attr : attrs) {
+      if (attr.first != wanted) continue;
+      return absl::string_view(attr.second);
+    }
+    return absl::nullopt;
+  };
+  return FunctionAbi::Create(device, std::move(host_type_factory),
+                             find_attr);
+}
+
+// Packs the Python sequence `py_args` into a new VmVariantList according to
+// `descs` (see FunctionAbi::RawPack). Raises ValueError when the sequence
+// length differs from the number of descriptions.
+VmVariantList PyRawPack(FunctionAbi* self,
+                        absl::Span<const FunctionAbi::Description> descs,
+                        py::sequence py_args, bool writable) {
+  const auto arg_count = py_args.size();
+  if (arg_count != descs.size()) throw RaiseValueError("Mismatched pack arity");
+
+  VmVariantList packed = VmVariantList::Create(arg_count);
+  // RawPack() takes a mutable span of handles, so snapshot them locally.
+  absl::InlinedVector<py::handle, 8> handles(py_args.begin(), py_args.end());
+  self->RawPack(descs, absl::MakeSpan(handles), packed, writable);
+  return packed;
+}
+
+// Creates the result list for an invocation, sized via
+// VmVariantList::Create(raw_result_arity()). When `static_alloc` is set, the
+// result buffers are also allocated ahead of time from the raw result
+// descriptions (see FunctionAbi::AllocateResults); otherwise no entries are
+// added here.
+VmVariantList PyAllocateResults(FunctionAbi* self, VmVariantList& f_args,
+                                bool static_alloc) {
+  VmVariantList results = VmVariantList::Create(self->raw_result_arity());
+  if (!static_alloc) return results;
+  // Static dispatch: try to allocate everything and infer shapes up front.
+  self->AllocateResults(absl::MakeConstSpan(self->raw_config().results),
+                        f_args, results);
+  return results;
+}
+
+// Converts the raw function results in `f_results` into a Python tuple with
+// one element per raw result description (see FunctionAbi::RawUnpack).
+py::object PyRawUnpackResults(FunctionAbi* self, VmVariantList& f_results) {
+  const size_t count = f_results.size();
+  absl::InlinedVector<py::object, 4> unpacked(count);
+  self->RawUnpack(absl::MakeConstSpan(self->raw_config().results), f_results,
+                  absl::MakeSpan(unpacked));
+  py::tuple as_tuple(count);
+  for (size_t idx = 0; idx < count; ++idx) {
+    as_tuple[idx] = std::move(unpacked[idx]);
+  }
+  return as_tuple;
+}
+
+// Scope guard that calls PyBuffer_Release() on a Py_buffer when it goes out
+// of scope. It is non-copyable: a copy would release the same view twice.
+class PyBufferReleaser {
+ public:
+  explicit PyBufferReleaser(Py_buffer& b) : b_(b) {}
+  ~PyBufferReleaser() { PyBuffer_Release(&b_); }
+
+  PyBufferReleaser(const PyBufferReleaser&) = delete;
+  PyBufferReleaser& operator=(const PyBufferReleaser&) = delete;
+
+ private:
+  Py_buffer& b_;
+};
+
+// Builds the ValueError reported for a buffer/argument mismatch. `message` is
+// suffixed with str() of the offending object and the expected signature
+// description. The error object is returned; callers `throw` it.
+pybind11::error_already_set RaiseBufferMismatchError(
+    std::string message, py::handle obj,
+    const RawSignatureParser::Description& desc) {
+  message += "For argument = ";
+  const std::string received = static_cast<std::string>(py::str(obj));
+  message += received;
+  message += " (expected ";
+  desc.ToString(message);
+  message += ")";
+  return RaiseValueError(message.c_str());
+}
+
+// Checks that the Python buffer view `py_view` is compatible with the
+// signature description `desc`. The checks are: same rank, same item size,
+// the same struct format (advisory, only when IREE has a matching Python
+// format), and equal extents for every static dim. The extent of each
+// dynamic (negative) dim in `desc` is appended to `dynamic_dims` in order.
+// Nothing is returned: the first incompatibility is thrown as the ValueError
+// built by RaiseBufferMismatchError.
+void MapBufferAttrs(Py_buffer& py_view,
+                    const RawSignatureParser::Description& desc,
+                    absl::InlinedVector<int, 2>& dynamic_dims) {
+  // Verify that rank matches. The cast mirrors the implicit conversion the
+  // comparison already performed, without a signed/unsigned warning.
+  const size_t rank = static_cast<size_t>(py_view.ndim);
+  if (rank != desc.dims.size()) {
+    throw RaiseBufferMismatchError(
+        absl::StrCat("Mismatched buffer rank (received: ", py_view.ndim,
+                     ", expected: ", desc.dims.size(), "): "),
+        py::handle(py_view.obj), desc);
+  }
+
+  // Verify that the item size matches.
+  const size_t f_item_size =
+      AbiConstants::kScalarTypeSize[static_cast<int>(desc.buffer.scalar_type)];
+  if (f_item_size != static_cast<size_t>(py_view.itemsize)) {
+    throw RaiseBufferMismatchError(
+        absl::StrCat("Mismatched buffer item size (received: ",
+                     py_view.itemsize, ", expected: ", f_item_size, "): "),
+        py::handle(py_view.obj), desc);
+  }
+
+  // Note: The python buffer format does not map precisely to IREE's type
+  // system, so the below is only advisory for where they do match. Otherwise,
+  // it is basically a bitcast.
+  // The buffer protocol defines a NULL format as "B" (unsigned bytes); map it
+  // explicitly so strcmp()/StrCat() never receive a null pointer.
+  const char* py_format = py_view.format != nullptr ? py_view.format : "B";
+  const char* f_expected_format =
+      kScalarTypePyFormat[static_cast<int>(desc.buffer.scalar_type)];
+  if (f_expected_format != nullptr &&
+      strcmp(f_expected_format, py_format) != 0) {
+    throw RaiseBufferMismatchError(
+        absl::StrCat("Mismatched buffer format (received: ", py_format,
+                     ", expected: ", f_expected_format, "): "),
+        py::handle(py_view.obj), desc);
+  }
+
+  // Verify shape, populating dynamic_dims while looping.
+  for (size_t i = 0; i < rank; ++i) {
+    const Py_ssize_t py_dim = py_view.shape[i];
+    const auto f_dim = desc.dims[i];
+    if (f_dim < 0) {
+      // Dynamic: record the extent supplied by the caller.
+      dynamic_dims.push_back(static_cast<int>(py_dim));
+    } else if (py_dim != f_dim) {
+      // Mismatch.
+      throw RaiseBufferMismatchError(
+          absl::StrCat("Mismatched buffer dim (received: ", py_dim,
+                       ", expected: ", f_dim, "): "),
+          py::handle(py_view.obj), desc);
+    }
+  }
+}
+
+} // namespace
+
+//------------------------------------------------------------------------------
+// FunctionAbi
+//------------------------------------------------------------------------------
+
+// Returns "<FunctionAbi SIG>", where SIG is the human-readable rendering of
+// the stored raw signature, or "<FunctionAbi NO_DEBUG_INFO>" when the
+// signature cannot be rendered.
+std::string FunctionAbi::DebugString() const {
+  RawSignatureParser parser;
+  auto pretty = parser.FunctionSignatureToString(raw_config_.signature);
+  return pretty ? absl::StrCat("<FunctionAbi ", *pretty, ">")
+                : std::string("<FunctionAbi NO_DEBUG_INFO>");
+}
+
+// Creates a FunctionAbi bound to `device` from reflection attributes fetched
+// through `lookup`. It requires the raw signature attribute "f" and the raw
+// ABI version attribute "fv" equal to "1". The signature is stored and parsed
+// into the raw input and result descriptions. Throws ValueError when either
+// attribute is missing or unsupported, or when the signature fails to parse.
+std::unique_ptr<FunctionAbi> FunctionAbi::Create(
+    HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
+    AttributeLookup lookup) {
+  auto abi =
+      absl::make_unique<FunctionAbi>(device, std::move(host_type_factory));
+
+  // Key attributes for the raw ABI.
+  const auto raw_version = lookup("fv");
+  const auto raw_fsig_str = lookup("f");
+  if (!raw_fsig_str) {
+    throw RaiseValueError("No raw abi reflection metadata for function");
+  }
+  if (!raw_version || *raw_version != "1") {
+    throw RaiseValueError("Unsupported raw function ABI version");
+  }
+
+  // Record and parse the signature into input/result descriptions.
+  RawConfig& config = abi->raw_config();
+  config.signature = std::string(*raw_fsig_str);
+  RawSignatureParser parser;
+  parser.VisitInputs(*raw_fsig_str, [&config](const Description& d) {
+    config.inputs.push_back(d);
+  });
+  parser.VisitResults(*raw_fsig_str, [&config](const Description& d) {
+    config.results.push_back(d);
+  });
+  if (parser.GetError()) {
+    const std::string message =
+        absl::StrCat("Error parsing raw ABI signature: ", *parser.GetError(),
+                     " ('", *raw_fsig_str, "')");
+    throw RaiseValueError(message.c_str());
+  }
+
+  // TODO(laurenzo): Detect sip ABI and add a translation layer.
+  return abi;
+}
+
+// Packs each Python argument into `f_args` according to its description.
+// Only buffer descriptions are supported: ref objects and any other kind raise
+// NotImplementedError. Raises ValueError if the spans differ in length.
+void FunctionAbi::RawPack(absl::Span<const Description> descs,
+                          absl::Span<py::handle> py_args,
+                          VmVariantList& f_args, bool writable) {
+  if (descs.size() != py_args.size()) {
+    throw RaiseValueError("Mismatched RawPack() input arity");
+  }
+
+  size_t index = 0;
+  for (const Description& desc : descs) {
+    py::handle arg = py_args[index++];
+    if (desc.type == RawSignatureParser::Type::kBuffer) {
+      PackBuffer(desc, arg, f_args, writable);
+      continue;
+    }
+    if (desc.type == RawSignatureParser::Type::kRefObject) {
+      throw RaisePyError(PyExc_NotImplementedError,
+                         "Ref objects not yet supported");
+    }
+    throw RaisePyError(PyExc_NotImplementedError, "Unsupported argument type");
+  }
+}
+
+// Converts each raw result in `f_results` into a Python object placed in the
+// matching slot of `py_results`. Buffer results become host ndarrays through
+// the HostTypeFactory; the HAL buffer is retained rather than taken out of
+// `f_results`. Raises ValueError on an arity mismatch or when a result is not
+// a dereferenceable buffer, and NotImplementedError for non-buffer types.
+void FunctionAbi::RawUnpack(absl::Span<const Description> descs,
+                            VmVariantList& f_results,
+                            absl::Span<py::object> py_results) {
+  const size_t count = descs.size();
+  if (count != f_results.size() || count != py_results.size()) {
+    throw RaiseValueError("Mismatched RawUnpack() result arity");
+  }
+  for (size_t i = 0; i < count; ++i) {
+    const Description& desc = descs[i];
+    if (desc.type == RawSignatureParser::Type::kRefObject) {
+      throw RaisePyError(PyExc_NotImplementedError,
+                         "Ref objects not yet supported");
+    }
+    if (desc.type != RawSignatureParser::Type::kBuffer) {
+      throw RaisePyError(PyExc_NotImplementedError,
+                         "Unsupported argument type");
+    }
+
+    iree_vm_variant_t* f_result =
+        iree_vm_variant_list_get(f_results.raw_ptr(), i);
+    iree_hal_buffer* raw_buffer = iree_hal_buffer_deref(&f_result->ref);
+    if (raw_buffer == nullptr) {
+      throw RaiseValueError("Could not deref result buffer (wrong type?)");
+    }
+    HalBuffer buffer = HalBuffer::RetainAndCreate(raw_buffer);
+    // TODO(laurenzo): With dynamic dims, the full dims will need to be
+    // spliced together from the known static dims and the dynamic dims
+    // delivered by a subsequent result.
+    absl::Span<const int> dims = absl::MakeSpan(desc.dims);
+    py_results[i] = host_type_factory_->CreateImmediateNdarray(
+        desc.buffer.scalar_type, dims, std::move(buffer));
+  }
+}
+
+// Pre-allocates results for an invocation, appending exactly one entry to
+// `f_results` per description in `descs`:
+//  - a buffer whose dims are all static gets a newly allocated
+//    DEVICE_LOCAL | HOST_VISIBLE HAL buffer of item_size * product(dims)
+//    bytes;
+//  - a buffer with any dynamic (negative) dim gets a null ref, leaving the
+//    function to allocate that result itself.
+// Raises ValueError if `f_args` does not match the raw input arity, and
+// NotImplementedError for ref-object or unknown result types.
+void FunctionAbi::AllocateResults(absl::Span<const Description> descs,
+                                  VmVariantList& f_args,
+                                  VmVariantList& f_results) {
+  if (f_args.size() != raw_config().inputs.size()) {
+    throw RaiseValueError("Mismatched AllocateResults() input arity");
+  }
+
+  for (const Description& desc : descs) {
+    switch (desc.type) {
+      case RawSignatureParser::Type::kBuffer: {
+        iree_device_size_t alloc_size =
+            AbiConstants::kScalarTypeSize[static_cast<int>(
+                desc.buffer.scalar_type)];
+        bool has_dynamic_dim = false;
+        for (auto dim : desc.dims) {
+          if (dim < 0) {
+            has_dynamic_dim = true;
+            break;
+          }
+          alloc_size *= dim;
+        }
+        if (has_dynamic_dim) {
+          // If there is a dynamic dim, fall back to a completely
+          // function-allocated result: append a single null ref and do NOT
+          // also allocate a buffer (the size would be meaningless). This is
+          // the worst case because it will force a pipeline stall.
+          // TODO(laurenzo): Invoke shape resolution function if available
+          // to allocate full result.
+          f_results.AppendNullRef();
+          break;
+        }
+
+        // Static cases are easy.
+        iree_hal_buffer_t* raw_buffer;
+        CheckApiStatus(iree_hal_allocator_allocate_buffer(
+                           device_.allocator(),
+                           static_cast<iree_hal_memory_type_t>(
+                               IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                               IREE_HAL_MEMORY_TYPE_HOST_VISIBLE),
+                           IREE_HAL_BUFFER_USAGE_ALL, alloc_size, &raw_buffer),
+                       "Error allocating host visible buffer");
+        iree_vm_ref_t buffer_ref = iree_hal_buffer_move_ref(raw_buffer);
+        CheckApiStatus(iree_vm_variant_list_append_ref_move(f_results.raw_ptr(),
+                                                            &buffer_ref),
+                       "Error moving buffer");
+        break;
+      }
+      case RawSignatureParser::Type::kRefObject:
+        throw RaisePyError(PyExc_NotImplementedError,
+                           "Ref objects not yet supported");
+      default:
+        throw RaisePyError(PyExc_NotImplementedError,
+                           "Unsupported argument type");
+    }
+  }
+}
+
+// Packs one Python buffer-protocol argument into `f_args`. The argument is
+// validated against `desc`, then copied into a newly allocated
+// HOST_LOCAL | DEVICE_VISIBLE HAL buffer. The data is currently always copied
+// (depends_on_pyobject stays false), so the Python object is not kept alive
+// by the packed argument.
+// Raises the pending Python error if the object does not export a suitable
+// buffer, a ValueError (from MapBufferAttrs) on a shape/type mismatch, and
+// NotImplementedError for dynamic dimensions.
+void FunctionAbi::PackBuffer(const RawSignatureParser::Description& desc,
+                             py::handle py_arg, VmVariantList& f_args,
+                             bool writable) {
+  // Request a view of the buffer (use the raw python C API to avoid some
+  // allocation and copying at the pybind level).
+  Py_buffer py_view;
+  // Note that only C-Contiguous ND-arrays are presently supported, so
+  // only request that via PyBUF_ND. Long term, we should consult an
+  // "oracle" in the runtime to determine the precise required format and
+  // set flags accordingly (and fallback/copy on failure).
+  int flags = PyBUF_FORMAT | PyBUF_ND;
+  if (writable) {
+    flags |= PyBUF_WRITABLE;
+  }
+
+  // Acquire the backing buffer and setup RAII release.
+  if (PyObject_GetBuffer(py_arg.ptr(), &py_view, flags) != 0) {
+    // The GetBuffer call is required to set an appropriate error.
+    throw py::error_already_set();
+  }
+  // From here on, the view is released on every exit path (including throws).
+  PyBufferReleaser py_view_releaser(py_view);
+
+  // Whether the py object needs to be retained with the argument.
+  // Should be set to true if directly mapping, false if copied.
+  bool depends_on_pyobject = false;
+
+  // Verify compatibility.
+  absl::InlinedVector<int, 2> dynamic_dims;
+  MapBufferAttrs(py_view, desc, dynamic_dims);
+  if (!dynamic_dims.empty()) {
+    throw RaisePyError(PyExc_NotImplementedError,
+                       "Dynamic argument dimensions not implemented");
+  }
+
+  // Allocate a HalBuffer.
+  // This is hard-coded to C-contiguous right now.
+  // TODO(laurenzo): Expand to other layouts as needed.
+  // TODO(laurenzo): Wrap and retain original buffer (depends_on_pyobject=true).
+  iree_hal_buffer_t* raw_buffer;
+  CheckApiStatus(iree_hal_allocator_allocate_buffer(
+                     device_.allocator(),
+                     static_cast<iree_hal_memory_type_t>(
+                         IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                         IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
+                     IREE_HAL_BUFFER_USAGE_ALL, py_view.len, &raw_buffer),
+                 "Failed to allocate device visible buffer");
+  // Copy the whole contiguous view (py_view.len bytes) into the new buffer.
+  CheckApiStatus(
+      iree_hal_buffer_write_data(raw_buffer, 0, py_view.buf, py_view.len),
+      "Error writing to input buffer");
+  // Transfer ownership of raw_buffer into the argument list.
+  iree_vm_ref_t buffer_ref = iree_hal_buffer_move_ref(raw_buffer);
+  CheckApiStatus(
+      iree_vm_variant_list_append_ref_move(f_args.raw_ptr(), &buffer_ref),
+      "Error moving buffer");
+
+  // Only capture the reference to the exporting object (incrementing it)
+  // once guaranteed successful.
+  if (depends_on_pyobject) {
+    // Note for future implementation: there needs to be a place to stash
+    // references to be kept alive which back a buffer. This is likely an
+    // additional bag of refs returned from this function, which can then
+    // be attached to an invocation.
+    throw RaisePyError(PyExc_NotImplementedError,
+                       "Dependent buffer arguments not implemented");
+  }
+}
+
+// Registers the Python `FunctionAbi` class on module `m`:
+//   FunctionAbi(device, host_type_factory, attrs)  -> PyCreateAbi
+//   __repr__                                      -> DebugString()
+//   raw_input_arity / raw_result_arity            (read-only properties)
+//   raw_pack_inputs(py_args)                      -> PyRawPack, writable=false
+//   allocate_results(<packed args>, static_alloc=True) -> PyAllocateResults
+//   raw_unpack_results(<results list>)            -> PyRawUnpackResults
+// NOTE(review): allocate_results' first keyword is exposed as "f_results",
+// but it binds PyAllocateResults' `f_args` (the packed inputs). Renaming it
+// would change the Python keyword API, so it is left as-is.
+void SetupFunctionAbiBindings(pybind11::module m) {
+  py::class_<FunctionAbi, std::unique_ptr<FunctionAbi>>(m, "FunctionAbi")
+      .def(py::init(&PyCreateAbi))
+      .def("__repr__", &FunctionAbi::DebugString)
+      .def_property_readonly("raw_input_arity", &FunctionAbi::raw_input_arity)
+      .def_property_readonly("raw_result_arity", &FunctionAbi::raw_result_arity)
+      .def("raw_pack_inputs",
+           [](FunctionAbi* self, py::sequence py_args) {
+             return PyRawPack(self,
+                              absl::MakeConstSpan(self->raw_config().inputs),
+                              py_args, false /* writable */);
+           })
+      .def("allocate_results", &PyAllocateResults, py::arg("f_results"),
+           py::arg("static_alloc") = true)
+      .def("raw_unpack_results", &PyRawUnpackResults);
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/rt/function_abi.h b/bindings/python/pyiree/rt/function_abi.h
new file mode 100644
index 0000000..62b3d0a
--- /dev/null
+++ b/bindings/python/pyiree/rt/function_abi.h
@@ -0,0 +1,117 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_FUNCTION_ABI_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_FUNCTION_ABI_H_
+
+#include <utility>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "bindings/python/pyiree/rt/host_types.h"
+#include "bindings/python/pyiree/rt/vm.h"
+#include "iree/base/signature_mangle.h"
+
+namespace iree {
+namespace python {
+
+// Forward declarations (redundant with the hal.h include above, but harmless).
+class HalDevice;
+
+// Instantiated with function attributes in order to process inputs/outputs:
+// packs Python arguments into VM variant lists, pre-allocates results and
+// unpacks results back into Python objects.
+class FunctionAbi {
+ public:
+  // Maps an attribute key to its value, or nullopt if the key is absent.
+  using AttributeLookup =
+      std::function<absl::optional<absl::string_view>(absl::string_view)>;
+  // Keeps its own retained reference to `device`.
+  FunctionAbi(HalDevice& device,
+              std::shared_ptr<HostTypeFactory> host_type_factory)
+      : device_(HalDevice::RetainAndCreate(device.raw_ptr())),
+        host_type_factory_(std::move(host_type_factory)) {}
+  virtual ~FunctionAbi() = default;
+
+  using Description = RawSignatureParser::Description;
+  using InputDescriptionVector = absl::InlinedVector<Description, 4>;
+  using ResultDescriptionVector = absl::InlinedVector<Description, 1>;
+
+  // Raw-ABI configuration parsed from the function's signature attribute.
+  struct RawConfig {
+    InputDescriptionVector inputs;
+    ResultDescriptionVector results;
+
+    // The following are retained to aid debugging (e.g. DebugString()) but
+    // may be empty if disabled.
+    std::string signature;
+  };
+
+  // Creates an instance based on the function attributes. The raw signature
+  // ("f") and raw ABI version ("fv") attributes are read through `lookup`.
+  static std::unique_ptr<FunctionAbi> Create(
+      HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
+      AttributeLookup lookup);
+
+  RawConfig& raw_config() { return raw_config_; }
+  int raw_input_arity() const { return raw_config_.inputs.size(); }
+  int raw_result_arity() const { return raw_config_.results.size(); }
+
+  // Raw packing. These always operate on the linear span of raw inputs and
+  // results. Some ABIs perform a higher level of mapping on top of this,
+  // which can be accessed via the non-prefixed Pack/Unpack methods.
+  // Given a span of descriptions, packs the given py_args into the span
+  // of function args. All spans must be of the same size.
+  void RawPack(absl::Span<const Description> descs,
+               absl::Span<py::handle> py_args, VmVariantList& args,
+               bool writable);
+
+  // Raw unpacks f_results into py_results. All spans must be of the same size.
+  // NOTE(review): the current implementation retains each result buffer
+  // rather than consuming it, so entries in f_results are left in place.
+  // Ordinarily, this will be invoked along with AllocateResults() but it
+  // is broken out for testing.
+  void RawUnpack(absl::Span<const Description> descs, VmVariantList& f_results,
+                 absl::Span<py::object> py_results);
+
+  // Given bound function arguments (from RawPack or equiv) and signature
+  // descriptors, allocates results for the function invocation. For fully
+  // specified result types, this can be done purely by matching up
+  // reflection metadata and an oracle for determining layout. For dynamically
+  // shaped or data-dependent shaped results, the metadata about the function
+  // arguments may be required to generate additional allocation function calls.
+  // Finally, in truly data-dependent cases, some results may not be resolvable
+  // ahead of time, resulting in a nullptr in f_results. In such cases, the
+  // invocation must ensure proper barriers are in place to fully execute the
+  // function prior to delivering results to the user layer.
+  void AllocateResults(absl::Span<const Description> descs,
+                       VmVariantList& f_args, VmVariantList& f_results);
+
+  // Gets the string representation.
+  std::string DebugString() const;
+
+ private:
+  // Packs a single buffer-protocol argument into f_args.
+  void PackBuffer(const RawSignatureParser::Description& desc,
+                  py::handle py_arg, VmVariantList& f_args, bool writable);
+
+  // Note: member order matters; it is the constructor's initialization order.
+  HalDevice device_;
+  std::shared_ptr<HostTypeFactory> host_type_factory_;
+  RawConfig raw_config_;
+};
+
+// Registers the Python `FunctionAbi` class on module `m`.
+void SetupFunctionAbiBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_FUNCTION_ABI_H_
diff --git a/bindings/python/pyiree/rt/function_abi_test.py b/bindings/python/pyiree/rt/function_abi_test.py
new file mode 100644
index 0000000..204c281
--- /dev/null
+++ b/bindings/python/pyiree/rt/function_abi_test.py
@@ -0,0 +1,158 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the function abi."""
+
+import re
+
+from absl.testing import absltest
+
+import numpy as np
+from pyiree import rt
+
+ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1 = (
+ ("fv", "1"),
+ # Equiv to:
+ # (Buffer<float32[10x128x64]>) -> (Buffer<sint32[32x8x64]>)
+ ("f", "I15!B11!d10d128d64R15!B11!t6d32d8d64"),
+)
+
+ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1 = (
+ ("fv", "1"),
+ # Equiv to:
+ # (Buffer<float32[?x128x64]>) -> (Buffer<sint32[?x8x64]>)
+ ("f", "I15!B11!d-1d128d64R15!B11!t6d-1d8d64"),
+)
+
+
+class HostTypeFactory(absltest.TestCase):
+
+ def test_baseclass(self):
+ htf = rt.HostTypeFactory()
+ print(htf)
+
+
+class FunctionAbiTest(absltest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ driver_names = rt.HalDriver.query()
+ print("DRIVER_NAMES =", driver_names)
+ cls.driver = rt.HalDriver.create("vulkan")
+ cls.device = cls.driver.create_default_device()
+
+ def setUp(self):
+ super().setUp()
+ self.htf = rt.HostTypeFactory.get_numpy()
+
+ def test_static_arg_success(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ print(fabi)
+ self.assertEqual(
+ "<FunctionAbi (Buffer<float32[10x128x64]>) -> "
+ "(Buffer<sint32[32x8x64]>)>", repr(fabi))
+ self.assertEqual(1, fabi.raw_input_arity)
+ self.assertEqual(1, fabi.raw_result_arity)
+
+ arg = np.zeros((10, 128, 64), dtype=np.float32)
+ packed = fabi.raw_pack_inputs([arg])
+ print(packed)
+ self.assertEqual("<VmVariantList(1): [HalBuffer(327680)]>", repr(packed))
+
+ def test_static_result_success(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ arg = np.zeros((10, 128, 64), dtype=np.float32)
+ f_args = fabi.raw_pack_inputs([arg])
+ f_results = fabi.allocate_results(f_args)
+ print(f_results)
+ self.assertEqual("<VmVariantList(1): [HalBuffer(65536)]>", repr(f_results))
+ py_result, = fabi.raw_unpack_results(f_results)
+ self.assertEqual(np.int32, py_result.dtype)
+ self.assertEqual((32, 8, 64), py_result.shape)
+
+ def test_dynamic_alloc_result_success(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ arg = np.zeros((10, 128, 64), dtype=np.float32)
+ f_args = fabi.raw_pack_inputs([arg])
+ f_results = fabi.allocate_results(f_args, static_alloc=False)
+ print(f_results)
+ self.assertEqual("<VmVariantList(0): []>", repr(f_results))
+
+ def test_dynamic_arg_success(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
+ print(fabi)
+ self.assertEqual(
+ "<FunctionAbi (Buffer<float32[?x128x64]>) -> "
+ "(Buffer<sint32[?x8x64]>)>", repr(fabi))
+ self.assertEqual(1, fabi.raw_input_arity)
+ self.assertEqual(1, fabi.raw_result_arity)
+
+ arg = np.zeros((10, 128, 64), dtype=np.float32)
+ with self.assertRaisesRegex(NotImplementedError,
+ "Dynamic argument dimensions not implemented"):
+ unused_packed = fabi.raw_pack_inputs([arg])
+ # TODO(laurenzo): Re-enable the following once implemented.
+ # print(packed)
+ # self.assertEqual(
+ # "<VmVariantList(1): [HalBuffer(327680, dynamic_dims=[10])]>",
+ # repr(packed))
+
+ def test_static_arg_rank_mismatch(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ print(fabi)
+ arg = np.zeros((10,), dtype=np.float32)
+ with self.assertRaisesRegex(
+ ValueError,
+ re.escape("Mismatched buffer rank (received: 1, expected: 3)")):
+ fabi.raw_pack_inputs([arg])
+
+ def test_static_arg_eltsize_mismatch(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ print(fabi)
+ arg = np.zeros((10, 128, 64), dtype=np.float64)
+ with self.assertRaisesRegex(
+ ValueError,
+ re.escape("Mismatched buffer item size (received: 8, expected: 4)")):
+ fabi.raw_pack_inputs([arg])
+
+ def test_static_arg_dtype_mismatch(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ print(fabi)
+ arg = np.zeros((10, 128, 64), dtype=np.int32)
+ with self.assertRaisesRegex(
+ ValueError,
+ re.escape("Mismatched buffer format (received: i, expected: f)")):
+ fabi.raw_pack_inputs([arg])
+
+ def test_static_arg_static_dim_mismatch(self):
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ print(fabi)
+ arg = np.zeros((10, 32, 64), dtype=np.float32)
+ with self.assertRaisesRegex(
+ ValueError,
+ re.escape("Mismatched buffer dim (received: 32, expected: 128)")):
+ fabi.raw_pack_inputs([arg])
+
+
+# Runs the tests above when this file is executed directly.
+if __name__ == "__main__":
+  absltest.main()
diff --git a/bindings/python/pyiree/rt/hal.cc b/bindings/python/pyiree/rt/hal.cc
new file mode 100644
index 0000000..49d5f58
--- /dev/null
+++ b/bindings/python/pyiree/rt/hal.cc
@@ -0,0 +1,177 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/rt/hal.h"
+
+#include "absl/container/inlined_vector.h"
+#include "iree/hal/api.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+// Owns a read-only host mapping of the buffer behind a buffer view and
+// exposes it to Python through the buffer protocol. A reference on the view
+// is held for as long as the mapping exists.
+class HalMappedMemory {
+ public:
+  // Adopts an existing mapping of bv's buffer and retains bv.
+  HalMappedMemory(iree_hal_mapped_memory_t mapped_memory,
+                  iree_hal_buffer_view_t* bv)
+      : mapped_memory_(mapped_memory), bv_(bv) {
+    iree_hal_buffer_view_retain(bv_);
+  }
+  // Unmaps first, then drops the view reference taken in the constructor.
+  // A moved-from instance (bv_ == nullptr) does nothing. An unmap failure
+  // aborts the process via CHECK_EQ.
+  ~HalMappedMemory() {
+    if (bv_) {
+      iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_);
+      CHECK_EQ(iree_hal_buffer_unmap(buffer, &mapped_memory_), IREE_STATUS_OK);
+      iree_hal_buffer_view_release(bv_);
+    }
+  }
+  // Transfers the mapping; the source is left with bv_ == nullptr so its
+  // destructor becomes a no-op.
+  HalMappedMemory(HalMappedMemory&& other)
+      : mapped_memory_(other.mapped_memory_), bv_(other.bv_) {
+    other.bv_ = nullptr;
+  }
+
+  // Maps the whole buffer backing `bv` for reading. Mapping failures are
+  // reported through CheckApiStatus.
+  static HalMappedMemory Create(HalBufferView& bv) {
+    iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr());
+    iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
+    iree_hal_mapped_memory_t mapped_memory;
+    CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ,
+                                       0 /* element_offset */, byte_length,
+                                       &mapped_memory),
+                   "Could not map memory");
+    return HalMappedMemory(mapped_memory, bv.raw_ptr());
+  }
+
+  // Describes the mapping as a dense row-major (C-contiguous) buffer built
+  // from the view's shape and element size.
+  // NOTE(review): the format is always float, so views of other element
+  // types are misinterpreted on the Python side (see TODO below).
+  py::buffer_info ToBufferInfo() {
+    iree_shape_t shape;
+    CheckApiStatus(iree_hal_buffer_view_shape(bv_, &shape),
+                   "Error getting buffer view shape");
+    int8_t element_size = iree_hal_buffer_view_element_size(bv_);
+    absl::InlinedVector<py::ssize_t, IREE_SHAPE_MAX_RANK> dims;
+    dims.resize(shape.rank);
+    for (int i = 0; i < shape.rank; ++i) {
+      dims[i] = shape.dims[i];
+    }
+    // Innermost stride is one element; each outer stride is the next inner
+    // stride times the next inner dimension.
+    absl::InlinedVector<py::ssize_t, IREE_SHAPE_MAX_RANK> strides;
+    strides.resize(shape.rank);
+    if (!strides.empty()) {
+      strides[shape.rank - 1] = element_size;
+      for (int i = shape.rank - 2; i >= 0; --i) {
+        strides[i] = strides[i + 1] * shape.dims[i + 1];
+      }
+    }
+
+    // TODO(laurenzo): We need to figure out how to propagate dtype in the
+    // buffer view.
+    return py::buffer_info(
+        mapped_memory_.contents.data, element_size,
+        py::format_descriptor<float>::format(),  // TODO(laurenzo): DTYPE!
+        shape.rank, dims, strides);
+  }
+
+ private:
+  iree_hal_mapped_memory_t mapped_memory_;
+  // Retained view whose buffer is mapped; nullptr once moved from.
+  iree_hal_buffer_view_t* bv_;
+};
+
+} // namespace
+
+//------------------------------------------------------------------------------
+// HalDriver
+//------------------------------------------------------------------------------
+
+// Returns the names of all drivers reported by the HAL driver registry.
+std::vector<std::string> HalDriver::Query() {
+  iree_string_view_t* driver_names;
+  iree_host_size_t driver_count;
+  CheckApiStatus(iree_hal_driver_registry_query_available_drivers(
+                     IREE_ALLOCATOR_SYSTEM, &driver_names, &driver_count),
+                 "Error querying drivers");
+
+  // The registry hands back a heap array that this function releases with
+  // free(). Release it on every path, including when copying the names
+  // throws (e.g. std::bad_alloc), so it is never leaked.
+  std::vector<std::string> drivers;
+  try {
+    drivers.reserve(driver_count);
+    for (iree_host_size_t i = 0; i < driver_count; ++i) {
+      drivers.emplace_back(driver_names[i].data, driver_names[i].size);
+    }
+  } catch (...) {
+    free(driver_names);
+    throw;
+  }
+  free(driver_names);
+  return drivers;
+}
+
+// Instantiates the registered driver named `driver_name`.
+HalDriver HalDriver::Create(const std::string& driver_name) {
+  iree_string_view_t name{driver_name.data(), driver_name.size()};
+  iree_hal_driver_t* driver;
+  CheckApiStatus(
+      iree_hal_driver_registry_create_driver(name, IREE_ALLOCATOR_SYSTEM,
+                                             &driver),
+      "Error creating driver");
+  return HalDriver::CreateRetained(driver);
+}
+
+// Creates this driver's default device.
+HalDevice HalDriver::CreateDefaultDevice() {
+  iree_hal_device_t* device;
+  auto status = iree_hal_driver_create_default_device(
+      raw_ptr(), IREE_ALLOCATOR_SYSTEM, &device);
+  CheckApiStatus(status, "Error creating default device");
+  return HalDevice::CreateRetained(device);
+}
+
+// Registers the HAL enums and the driver/device/buffer wrapper classes on
+// module `m`.
+void SetupHalBindings(pybind11::module m) {
+  // Enums.
+  py::enum_<iree_hal_memory_type_t>(m, "MemoryType")
+      .value("NONE", IREE_HAL_MEMORY_TYPE_NONE)
+      .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT)
+      .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)
+      .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT)
+      .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED)
+      .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL)
+      // Previously bound to HOST_VISIBLE by mistake.
+      .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)
+      .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)
+      .export_values();
+  py::enum_<iree_hal_buffer_usage_t>(m, "BufferUsage")
+      .value("NONE", IREE_HAL_BUFFER_USAGE_NONE)
+      .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT)
+      .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER)
+      .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING)
+      .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH)
+      .value("ALL", IREE_HAL_BUFFER_USAGE_ALL)
+      .export_values();
+  py::enum_<iree_hal_memory_access_t>(m, "MemoryAccess")
+      .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE)
+      .value("READ", IREE_HAL_MEMORY_ACCESS_READ)
+      .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE)
+      .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD)
+      .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE)
+      .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL)
+      .export_values();
+
+  py::class_<HalDevice>(m, "HalDevice");
+  py::class_<HalDriver>(m, "HalDriver")
+      .def_static("query", &HalDriver::Query)
+      .def_static("create", &HalDriver::Create, py::arg("driver_name"))
+      .def("create_default_device", &HalDriver::CreateDefaultDevice);
+
+  py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
+  py::class_<HalBufferView>(m, "BufferView")
+      .def("map", HalMappedMemory::Create);
+  py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol())
+      .def_buffer(&HalMappedMemory::ToBufferInfo);
+  py::class_<HalBuffer>(m, "HalBuffer")
+      .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer,
+                  py::arg("memory_type"), py::arg("usage"),
+                  py::arg("allocation_size"))
+      .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
+           py::arg("byte_length"))
+      .def("create_view", &HalBuffer::CreateView, py::arg("shape"),
+           py::arg("element_size"));
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/rt/hal.h b/bindings/python/pyiree/rt/hal.h
new file mode 100644
index 0000000..918193a
--- /dev/null
+++ b/bindings/python/pyiree/rt/hal.h
@@ -0,0 +1,136 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_HAL_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_HAL_H_
+
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "iree/hal/api.h"
+
+namespace iree {
+namespace python {
+
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+// Retain/release hooks for each HAL handle type, forwarding to the matching
+// IREE HAL C API reference-counting functions.
+template <>
+struct ApiPtrAdapter<iree_hal_driver_t> {
+  static void Retain(iree_hal_driver_t* driver) {
+    iree_hal_driver_retain(driver);
+  }
+  static void Release(iree_hal_driver_t* driver) {
+    iree_hal_driver_release(driver);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_device_t> {
+  static void Retain(iree_hal_device_t* device) {
+    iree_hal_device_retain(device);
+  }
+  static void Release(iree_hal_device_t* device) {
+    iree_hal_device_release(device);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_buffer_t> {
+  static void Retain(iree_hal_buffer_t* buffer) {
+    iree_hal_buffer_retain(buffer);
+  }
+  static void Release(iree_hal_buffer_t* buffer) {
+    iree_hal_buffer_release(buffer);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_buffer_view_t> {
+  static void Retain(iree_hal_buffer_view_t* view) {
+    iree_hal_buffer_view_retain(view);
+  }
+  static void Release(iree_hal_buffer_view_t* view) {
+    iree_hal_buffer_view_release(view);
+  }
+};
+
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+// Reference-counted wrapper around iree_hal_device_t.
+class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
+ public:
+  // Returns the device's allocator. No reference is taken on it here.
+  iree_hal_allocator_t* allocator() {
+    iree_hal_device_t* const device = raw_ptr();
+    return iree_hal_device_allocator(device);
+  }
+};
+
+// Reference-counted wrapper around iree_hal_driver_t.
+class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
+ public:
+  // Lists the driver names known to the HAL driver registry.
+  static std::vector<std::string> Query();
+  // Instantiates the driver registered under `driver_name`.
+  static HalDriver Create(const std::string& driver_name);
+
+  // Creates this driver's default device.
+  HalDevice CreateDefaultDevice();
+};
+
+// Value wrapper over iree_shape_t, constructible from a list of dims.
+struct HalShape {
+ public:
+  // Builds a shape with one dimension per entry of `indices`. Dimensions
+  // beyond the rank are zeroed rather than left uninitialized. Throws
+  // RaiseValueError (a Python ValueError) if the rank exceeds
+  // IREE_SHAPE_MAX_RANK.
+  static HalShape FromIntVector(const std::vector<int32_t>& indices) {
+    if (indices.size() > IREE_SHAPE_MAX_RANK) {
+      throw RaiseValueError("Shape exceeded maximum rank");
+    }
+    HalShape s{};  // Value-initialize so unused dims are 0.
+    s.s.rank = indices.size();
+    for (size_t i = 0, e = indices.size(); i < e; ++i) {
+      s.s.dims[i] = indices[i];
+    }
+    return s;
+  }
+
+  iree_shape_t s;
+};
+
+// Reference-counted wrapper around iree_hal_buffer_view_t. It currently adds
+// nothing beyond what ApiRefCounted provides.
+class HalBufferView
+    : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> {
+ public:
+};
+
+// Reference-counted wrapper around iree_hal_buffer_t.
+class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
+ public:
+  // Allocates a heap buffer of `allocation_size` bytes. `memory_type` and
+  // `usage` are raw integer values of iree_hal_memory_type_t and
+  // iree_hal_buffer_usage_t.
+  static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage,
+                                      iree_host_size_t allocation_size) {
+    const auto type_bits = static_cast<iree_hal_memory_type_t>(memory_type);
+    const auto usage_bits = static_cast<iree_hal_buffer_usage_t>(usage);
+    iree_hal_buffer_t* buffer = nullptr;
+    CheckApiStatus(iree_hal_heap_buffer_allocate(
+                       type_bits, usage_bits, allocation_size,
+                       IREE_ALLOCATOR_SYSTEM, IREE_ALLOCATOR_SYSTEM, &buffer),
+                   "Error allocating heap buffer");
+    return HalBuffer::CreateRetained(buffer);
+  }
+
+  // Total size of the buffer in bytes.
+  iree_device_size_t byte_length() const {
+    iree_hal_buffer_t* const buffer = raw_ptr();
+    return iree_hal_buffer_byte_length(buffer);
+  }
+
+  // Zeroes `byte_length` bytes starting at `byte_offset`.
+  void FillZero(iree_device_size_t byte_offset,
+                iree_device_size_t byte_length) {
+    auto status = iree_hal_buffer_zero(raw_ptr(), byte_offset, byte_length);
+    CheckApiStatus(status, "Error zero filling buffer");
+  }
+
+  // Creates a view of this buffer with the given shape and element size.
+  HalBufferView CreateView(HalShape& shape, size_t element_size) {
+    iree_hal_buffer_view_t* view;
+    auto status = iree_hal_buffer_view_create(raw_ptr(), shape.s, element_size,
+                                              IREE_ALLOCATOR_SYSTEM, &view);
+    CheckApiStatus(status, "Error creating buffer view");
+    return HalBufferView::CreateRetained(view);
+  }
+};
+
+void SetupHalBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_HAL_H_
diff --git a/bindings/python/pyiree/rt/hal_test.py b/bindings/python/pyiree/rt/hal_test.py
new file mode 100644
index 0000000..b7ab59b
--- /dev/null
+++ b/bindings/python/pyiree/rt/hal_test.py
@@ -0,0 +1,56 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+
+import numpy as np
+from pyiree import rt
+
+
+class HalTest(absltest.TestCase):
+
+ def testEnums(self):
+ print("MemoryType =", rt.MemoryType)
+ print("HOST_VISIBLE =", int(rt.MemoryType.HOST_VISIBLE))
+
+ def testAllocateHeap(self):
+ b = rt.HalBuffer.allocate_heap(
+ memory_type=int(rt.MemoryType.HOST_LOCAL),
+ usage=int(rt.BufferUsage.ALL),
+ allocation_size=4096)
+ self.assertIsNot(b, None)
+ b.fill_zero(0, 4096)
+ shape = rt.Shape([1, 1024])
+ unused_bv = b.create_view(shape, 4)
+
+ def testStrideCalculation(self):
+ b = rt.HalBuffer.allocate_heap(
+ memory_type=int(rt.MemoryType.HOST_LOCAL),
+ usage=int(rt.BufferUsage.ALL),
+ allocation_size=4096)
+ self.assertIsNot(b, None)
+ b.fill_zero(0, 4096)
+ shape = rt.Shape([16, 1, 8, 4, 2])
+ bv = b.create_view(shape, 4)
+ self.assertEqual(
+ np.array(bv.map()).strides,
+ (1 * 8 * 4 * 2 * 4, 8 * 4 * 2 * 4, 4 * 2 * 4, 2 * 4, 4))
+
+
+# Runs the tests above when this file is executed directly.
+if __name__ == "__main__":
+  absltest.main()
diff --git a/bindings/python/pyiree/rt/host_types.cc b/bindings/python/pyiree/rt/host_types.cc
new file mode 100644
index 0000000..8be929c
--- /dev/null
+++ b/bindings/python/pyiree/rt/host_types.cc
@@ -0,0 +1,181 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/rt/host_types.h"
+
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "iree/base/signature_mangle.h"
+#include "pybind11/numpy.h"
+
+namespace iree {
+namespace python {
+
+// Python buffer-protocol (struct-module) format characters, indexed by
+// AbiConstants::ScalarType. nullptr marks scalar types with no mapping yet;
+// PyMappedMemory::Description::ForNdarray rejects those with
+// NotImplementedError.
+const std::array<const char*, static_cast<unsigned>(
+                                  AbiConstants::ScalarType::kMaxScalarType) +
+                                  1>
+    kScalarTypePyFormat = {
+        "f",      // kIeeeFloat32 = 0,
+        nullptr,  // kIeeeFloat16 = 1,
+        "d",      // kIeeeFloat64 = 2,
+        nullptr,  // kGoogleBfloat16 = 3,
+        "b",      // kSint8 = 4,
+        "h",      // kSint16 = 5,
+        "i",      // kSint32 = 6,
+        "q",      // kSint64 = 7,
+        // "B" is unsigned char. The previous "c" is the 1-byte char/bytes
+        // format, which numpy maps to dtype 'S1' rather than uint8.
+        "B",      // kUint8 = 8,
+        "H",      // kUint16 = 9,
+        "I",      // kUint32 = 10,
+        "Q",      // kUint64 = 11,
+};
+static_assert(kScalarTypePyFormat.size() ==
+                  AbiConstants::kScalarTypeSize.size(),
+              "Mismatch kScalarTypePyFormat");
+
+namespace {
+
+// Owns a read-only host mapping of a HalBuffer, plus the ndarray-style
+// layout (element size, format, dims, strides) used to expose it through the
+// Python buffer protocol.
+class PyMappedMemory {
+ public:
+  // Layout of the mapped memory as seen from Python.
+  struct Description {
+    size_t element_size;
+    const char* format;
+    absl::InlinedVector<py::ssize_t, 4> dims;
+    absl::InlinedVector<py::ssize_t, 4> strides;
+
+    // Builds a dense row-major (C-contiguous) description for `dims` of
+    // `scalar_type`. Throws ValueError for an out-of-range scalar type and
+    // NotImplementedError for a scalar type with no Python format.
+    static Description ForNdarray(AbiConstants::ScalarType scalar_type,
+                                  absl::Span<const int> dims) {
+      unsigned scalar_type_i = static_cast<unsigned>(scalar_type);
+      if (scalar_type_i >
+          static_cast<unsigned>(AbiConstants::ScalarType::kMaxScalarType)) {
+        throw RaiseValueError("Illegal ScalarType");
+      }
+
+      Description d;
+      d.element_size = AbiConstants::kScalarTypeSize[scalar_type_i];
+      d.format = kScalarTypePyFormat[scalar_type_i];
+      if (!d.format) {
+        throw RaisePyError(PyExc_NotImplementedError,
+                           "Unimplemented ScalarType");
+      }
+      if (!dims.empty()) {
+        d.dims.resize(dims.size());
+        d.strides.resize(dims.size());
+
+        for (size_t i = 0, e = dims.size(); i < e; ++i) {
+          d.dims[i] = dims[i];
+        }
+        // Innermost stride is one element; each outer stride spans the next
+        // inner dimension.
+        d.strides[dims.size() - 1] = d.element_size;
+        for (int i = dims.size() - 2; i >= 0; --i) {
+          d.strides[i] = d.strides[i + 1] * dims[i + 1];
+        }
+      }
+      return d;
+    }
+  };
+
+  PyMappedMemory(Description desc, iree_hal_mapped_memory_t mapped_memory,
+                 HalBuffer buffer)
+      : desc_(std::move(desc)),
+        mapped_memory_(mapped_memory),
+        buf_(std::move(buffer)) {}
+  // Unmaps the buffer unless this instance was moved from (assumes a
+  // moved-from HalBuffer tests false).
+  // NOTE(review): if CheckApiStatus throws here, the exception leaves an
+  // implicitly noexcept destructor and terminates the process. Consider
+  // logging the failure instead.
+  ~PyMappedMemory() {
+    if (buf_) {
+      CheckApiStatus(iree_hal_buffer_unmap(buf_.raw_ptr(), &mapped_memory_),
+                     "Error unmapping memory");
+    }
+  }
+  // Transfers the layout description, the mapping and the buffer reference.
+  // desc_ is moved too: the previous version left it default-initialized, so
+  // element_size/format were garbage in the moved-to object.
+  PyMappedMemory(PyMappedMemory&& other)
+      : desc_(std::move(other.desc_)),
+        mapped_memory_(other.mapped_memory_),
+        buf_(std::move(other.buf_)) {}
+
+  const Description& desc() const { return desc_; }
+
+  // Maps the whole of `buffer` for reading and wraps it with `desc`.
+  // Mapping failures are reported through CheckApiStatus.
+  static std::unique_ptr<PyMappedMemory> Read(Description desc,
+                                              HalBuffer buffer) {
+    iree_device_size_t byte_length =
+        iree_hal_buffer_byte_length(buffer.raw_ptr());
+    iree_hal_mapped_memory_t mapped_memory;
+    CheckApiStatus(iree_hal_buffer_map(
+                       buffer.raw_ptr(), IREE_HAL_MEMORY_ACCESS_READ,
+                       0 /* element_offset */, byte_length, &mapped_memory),
+                   "Could not map memory");
+    return absl::make_unique<PyMappedMemory>(std::move(desc), mapped_memory,
+                                             std::move(buffer));
+  }
+
+  py::buffer_info ToBufferInfo() {
+    // TODO(laurenzo): py::buffer_info is a heavy-weight way to get the
+    // buffer. See about implementing the lower level buffer protocol.
+    // Unfortunately, this part of the pybind C++ API is all defined in terms
+    // of std::vector, making it less efficient than necessary.
+    return py::buffer_info(mapped_memory_.contents.data, desc_.element_size,
+                           desc_.format, desc_.dims.size(), desc_.dims,
+                           desc_.strides);
+  }
+
+ private:
+  Description desc_;
+  iree_hal_mapped_memory_t mapped_memory_;
+  HalBuffer buf_;
+};
+
+// HostTypeFactory that materializes results as numpy ndarrays.
+class NumpyHostTypeFactory : public HostTypeFactory {
+  py::object CreateImmediateNdarray(AbiConstants::ScalarType element_type,
+                                    absl::Span<const int> dims,
+                                    HalBuffer buffer) override {
+    auto desc = PyMappedMemory::Description::ForNdarray(element_type, dims);
+    auto mapping = PyMappedMemory::Read(std::move(desc), std::move(buffer));
+    // An immediate ndarray needs no lazily-mapping proxy. Wrap the mapped
+    // memory in a native ndarray whose base object owns the mapping.
+    py::buffer_info info = mapping->ToBufferInfo();
+    py::object base = py::cast(mapping.release(),
+                               py::return_value_policy::take_ownership);
+    return py::array(py::dtype(info), info.shape, info.strides, info.ptr,
+                     std::move(base));
+  }
+};
+
+} // namespace
+
+//------------------------------------------------------------------------------
+// HostTypeFactory
+//------------------------------------------------------------------------------
+
+// Returns the process-wide numpy factory, created on first use.
+std::shared_ptr<HostTypeFactory> HostTypeFactory::GetNumpyFactory() {
+  static const std::shared_ptr<HostTypeFactory> instance =
+      std::make_shared<NumpyHostTypeFactory>();
+  return instance;
+}
+
+// Base implementation. A plain HostTypeFactory (e.g. rt.HostTypeFactory()
+// from Python) cannot build ndarrays; subclasses override this.
+py::object HostTypeFactory::CreateImmediateNdarray(
+    AbiConstants::ScalarType element_type, absl::Span<const int> dims,
+    HalBuffer buffer) {
+  auto error = RaisePyError(PyExc_NotImplementedError,
+                            "CreateImmediateNdarray not implemented");
+  throw error;
+}
+
+// Registers HostTypeFactory and PyMappedMemory on module `m`.
+void SetupHostTypesBindings(pybind11::module m) {
+  using HostTypeFactoryClass =
+      py::class_<HostTypeFactory, std::shared_ptr<HostTypeFactory>>;
+  HostTypeFactoryClass factory_class(m, "HostTypeFactory");
+  factory_class.def(py::init<>());
+  factory_class.def_static("get_numpy", &HostTypeFactory::GetNumpyFactory);
+
+  py::class_<PyMappedMemory, std::unique_ptr<PyMappedMemory>> mapped_class(
+      m, "PyMappedMemory", py::buffer_protocol());
+  mapped_class.def_buffer(&PyMappedMemory::ToBufferInfo);
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/rt/host_types.h b/bindings/python/pyiree/rt/host_types.h
new file mode 100644
index 0000000..00a9c42
--- /dev/null
+++ b/bindings/python/pyiree/rt/host_types.h
@@ -0,0 +1,56 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_HOST_TYPES_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_HOST_TYPES_H_
+
+#include <array>
+
+#include "absl/types/span.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "iree/base/signature_mangle.h"
+
+namespace iree {
+namespace python {
+
+extern const std::array<
+ const char*,
+ static_cast<unsigned>(AbiConstants::ScalarType::kMaxScalarType) + 1>
+ kScalarTypePyFormat;
+
+// Creates Python-side host objects (such as ndarrays) backed by HAL buffers.
+// The base implementation of CreateImmediateNdarray raises
+// NotImplementedError; GetNumpyFactory() returns a working subclass.
+class HostTypeFactory {
+ public:
+  virtual ~HostTypeFactory() = default;
+
+  // Returns the shared factory that interoperates with numpy.
+  static std::shared_ptr<HostTypeFactory> GetNumpyFactory();
+
+  // Returns a C-contiguous ndarray with the given element_type/dims, backed
+  // by `buffer`. No synchronization is applied; the array is usable
+  // immediately.
+  virtual py::object CreateImmediateNdarray(
+      AbiConstants::ScalarType element_type, absl::Span<const int> dims,
+      HalBuffer buffer);
+
+  // TODO(laurenzo): Add a CreateDelayedNdarray() which is conditioned on
+  // a semaphore. This is actually what should be used for async results.
+};
+
+void SetupHostTypesBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_HOST_TYPES_H_
diff --git a/bindings/python/pyiree/rt/initialize_module.cc b/bindings/python/pyiree/rt/initialize_module.cc
new file mode 100644
index 0000000..a5dea72
--- /dev/null
+++ b/bindings/python/pyiree/rt/initialize_module.cc
@@ -0,0 +1,118 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <mutex> // NOLINT
+
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/function_abi.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "bindings/python/pyiree/rt/host_types.h"
+#include "bindings/python/pyiree/rt/vm.h"
+#include "iree/base/initializer.h"
+#include "iree/base/tracing.h"
+#include "wtf/event.h"
+#include "wtf/macros.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+// Wrapper around wtf::ScopedEvent to make it usable as a python context
+// object.
+// Exposes a wtf::ScopedEvent as a Python context manager
+// (`with ScopedEvent("name"): ...`).
+class PyScopedEvent {
+ public:
+  PyScopedEvent(std::string name_spec)
+      : scoped_event_(InternEvent(std::move(name_spec))) {}
+
+  // __enter__: returns whether an event is actually recorded.
+  bool Enter() {
+    if (!scoped_event_) return false;
+    scoped_event_->Enter();
+    return true;
+  }
+
+  // __exit__: the exception info in `args` is ignored, and returning None
+  // lets any exception propagate.
+  void Exit(py::args args) {
+    if (scoped_event_ != nullptr) {
+      scoped_event_->Leave();
+    }
+  }
+
+ private:
+  // Returns the shared event for `name_spec`, creating it on first use. Returns
+  // nullptr when ::wtf::kMasterEnable is false.
+  static ::wtf::ScopedEvent<>* InternEvent(std::string name_spec) {
+    if (!::wtf::kMasterEnable) return nullptr;
+    std::lock_guard<std::mutex> lock(mu_);
+    auto found = scoped_event_intern_.find(name_spec);
+    if (found != scoped_event_intern_.end()) return found->second;
+
+    // The event keeps a raw pointer to its name, so both the name string and
+    // the event are intentionally leaked for the life of the process.
+    const std::string* owned_name = new std::string(std::move(name_spec));
+    auto* event = new ::wtf::ScopedEvent<>(owned_name->c_str());
+    scoped_event_intern_.insert(std::make_pair(*owned_name, event));
+    return event;
+  }
+
+  // Guards scoped_event_intern_.
+  static std::mutex mu_;
+  static std::unordered_map<std::string, ::wtf::ScopedEvent<>*>
+      scoped_event_intern_;
+  // Interned event, or nullptr when tracing is compiled out.
+  ::wtf::ScopedEvent<>* scoped_event_;
+};
+
+std::mutex PyScopedEvent::mu_;
+std::unordered_map<std::string, ::wtf::ScopedEvent<>*>
+    PyScopedEvent::scoped_event_intern_;
+
+// Registers the tracing helpers and the ScopedEvent context manager on `m`.
+void SetupTracingBindings(pybind11::module m) {
+  m.def("enable_thread", []() { WTF_AUTO_THREAD_ENABLE(); });
+  m.def("is_available", []() { return IsTracingAvailable(); });
+  // Flushes the trace, optionally to an explicit path. Python None arrives
+  // as an empty optional. (The previous version built an unused
+  // optional<string_view> copy and declared a default whose type did not
+  // match the parameter.)
+  m.def(
+      "flush",
+      [](absl::optional<std::string> explicit_trace_path) {
+        FlushTrace(explicit_trace_path);
+      },
+      py::arg("explicit_trace_path") = py::none());
+  m.def(
+      "autoflush",
+      [](float period) { StartTracingAutoFlush(absl::Seconds(period)); },
+      py::arg("period") = 5.0f);
+  m.def("stop", []() { StopTracing(); });
+
+  py::class_<PyScopedEvent>(m, "ScopedEvent")
+      .def(py::init<std::string>())
+      .def("__enter__", &PyScopedEvent::Enter)
+      .def("__exit__", &PyScopedEvent::Exit);
+}
+
+} // namespace
+
+// Entry point of the `binding` extension module (imported by system_api.py
+// as `from . import binding`). Module initializers run before any bindings
+// are registered.
+PYBIND11_MODULE(binding, m) {
+  IREE_RUN_MODULE_INITIALIZERS();
+
+  m.doc() = "IREE Binding Backend Helpers";
+  SetupFunctionAbiBindings(m);
+  SetupHostTypesBindings(m);
+  SetupHalBindings(m);
+  SetupVmBindings(m);
+
+  // Tracing helpers live in the `binding.tracing` submodule.
+  auto tracing_m = m.def_submodule("tracing", "IREE tracing api");
+  SetupTracingBindings(tracing_m);
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
new file mode 100644
index 0000000..e4089a9
--- /dev/null
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -0,0 +1,262 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Top-level python system API.
+
+This facility layers on top of the underlying binding native facilities and
+exposes them in a way that allows general operation against contexts, modules
+and functions.
+"""
+
+# pylint: disable=protected-access
+# pylint: disable=unused-argument
+# pylint: disable=g-explicit-length-test
+
# Public API of this module (the names exported by `from ... import *`).
__all__ = ["load_module", "load_modules", "Config", "SystemContext"]
+
+import os
+import sys
+
+from typing import Optional, Sequence, Tuple
+
+from . import binding as _binding
+
# Typing aliases (largely used for documentation).
AnyModule = _binding.VmModule

# Environment variable holding a comma-delimited list of driver names to try,
# in order of preference.
PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"

# Driver preference list used when PREFERRED_DRIVER_ENV_KEY is not set.
DEFAULT_IREE_DRIVER_VALUE = "vulkan,interpreter"
+
+
def _create_default_iree_driver(
    driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
  """Returns the first driver that can be created from a preference list.

  Args:
    driver_names: Driver names to try, in order. If None, the comma-delimited
      list in the PREFERRED_DRIVER_ENV_KEY environment variable is used, or
      DEFAULT_IREE_DRIVER_VALUE when that variable is unset. Surrounding
      whitespace in each name is ignored.

  Returns:
    The first driver that is both registered and successfully created.

  Raises:
    RuntimeError: If none of the requested drivers could be created.
  """
  # TODO(laurenzo): Ideally this should take a module and join any explicitly
  # provided driver list with environmental constraints and what the module
  # was compiled for.
  if driver_names is None:
    # Read from environment.
    driver_names = os.environ.get(PREFERRED_DRIVER_ENV_KEY)
    if driver_names is None:
      driver_names = DEFAULT_IREE_DRIVER_VALUE
    driver_names = driver_names.split(",")
  available_driver_names = _binding.HalDriver.query()
  driver_exceptions = {}
  for driver_name in driver_names:
    # Tolerate lists written as "vulkan, interpreter".
    driver_name = driver_name.strip()
    # TODO(laurenzo): Remove these prints to stderr (for now, more information
    # is better and there is no better way to report it yet).
    if driver_name not in available_driver_names:
      print(
          "Could not create driver %s (not registered)" % driver_name,
          file=sys.stderr)
      continue
    try:
      driver = _binding.HalDriver.create(driver_name)
    except Exception as ex:  # pylint: disable=broad-except
      print(
          "Could not create default driver %s: %r" % (driver_name, ex),
          file=sys.stderr)
      driver_exceptions[driver_name] = ex
      # Try the next preference; falling through would return an unbound (or
      # stale) `driver`.
      continue
    print("Created IREE driver %s: %r" % (driver_name, driver), file=sys.stderr)
    return driver

  # All failed.
  raise RuntimeError("Could not create any requested driver "
                     "%r (available=%r) : %r" %
                     (driver_names, available_driver_names, driver_exceptions))
+
+
class Config:
  """System configuration shared by SystemContexts.

  Bundles the VM instance, a HAL driver and its default device, the host type
  factory used to marshal function inputs/results (numpy), and the modules
  every SystemContext registers before any user modules (currently only the
  HAL module).
  """

  driver: _binding.HalDriver
  device: _binding.HalDevice
  vm_instance: _binding.VmInstance
  host_type_factory: _binding.HostTypeFactory
  # Modules registered into every context ahead of user-provided modules.
  default_modules: Tuple[AnyModule, ...]

  def __init__(self, driver_name: Optional[str] = None) -> None:
    """Creates the configuration.

    Args:
      driver_name: Optional comma-delimited driver preference list, e.g.
        "vulkan,interpreter". If None, the environment variable named by
        PREFERRED_DRIVER_ENV_KEY or DEFAULT_IREE_DRIVER_VALUE is used.

    Raises:
      RuntimeError: If none of the requested drivers can be created.
    """
    self.vm_instance = _binding.VmInstance()
    self.driver = _create_default_iree_driver(
        driver_name.split(",") if driver_name is not None else None)
    self.device = self.driver.create_default_device()
    hal_module = _binding.create_hal_module(self.device)
    self.host_type_factory = _binding.HostTypeFactory.get_numpy()
    self.default_modules = (hal_module,)
+
+
# Process-wide default Config, built lazily on first request.
_global_config = None


def _get_global_config():
  """Returns the shared default Config, constructing it on first use."""
  global _global_config
  if _global_config is not None:
    return _global_config
  _global_config = Config()
  return _global_config
+
+
class BoundFunction:
  """A Python callable for one VmFunction within a SystemContext.

  Calling the object packs its positional arguments through the function's
  ABI, invokes the function synchronously on the context's VM context and
  unpacks the results: None for no results, the bare value for exactly one,
  and the unpacked sequence otherwise.
  """

  def __init__(self, context: "SystemContext",
               vm_function: _binding.VmFunction):
    self._context = context
    self._vm_function = vm_function
    # The ABI is resolved once here and reused by every call.
    self._abi = context.create_function_abi(vm_function)

  def __call__(self, *args):
    # NOTE: Dispatch is synchronous for now. In the future this should default
    # to async, potentially with a policy flag that allows overriding it.
    abi = self._abi
    packed_args = abi.raw_pack_inputs(args)
    result_list = abi.allocate_results(packed_args, static_alloc=False)
    self._context._vm_context.invoke(self._vm_function, packed_args,
                                     result_list)
    values = abi.raw_unpack_results(result_list)
    # TODO(laurenzo): When switching from 'raw' to structured pack/unpack,
    # the ABI should take care of this one-arg special case.
    count = len(values)
    if count == 0:
      return None
    if count == 1:
      return values[0]
    return values

  def __repr__(self):
    return "<BoundFunction %r (%r)>" % (self._abi, self._vm_function)
+
+
class BoundModule:
  """Wraps a VmModule with its context and provides nice python accessors.

  Functions are resolved lazily by name through item access (`m["foo"]`) or
  attribute access (`m.foo`). Each resolved BoundFunction is cached by name.
  """

  # Instance attributes assigned in __init__. On an instance whose __init__
  # has not run (e.g. one created via __new__ by copy/pickle), looking these up
  # must not fall through to function resolution, which itself reads them and
  # would recurse without bound.
  _INTERNAL_ATTRS = frozenset(("_context", "_vm_module", "_lazy_functions"))

  def __init__(self, context: "SystemContext", vm_module: "AnyModule"):
    self._context = context
    self._vm_module = vm_module
    self._lazy_functions = dict()

  @property
  def name(self):
    return self._vm_module.name

  def __getattr__(self, name):
    # Only reached when normal attribute lookup fails. Special (dunder) names
    # are probed by Python protocols (copy, pickle, hasattr checks) and are
    # never resolved as module functions here; use item access for those.
    if name in self._INTERNAL_ATTRS or (name.startswith("__") and
                                        name.endswith("__")):
      raise AttributeError(name)
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name) from None

  def __getitem__(self, name):
    """Returns the BoundFunction for `name`.

    Raises:
      KeyError: If the module has no function with that name.
    """
    bound_function = self._lazy_functions.get(name)
    if bound_function is not None:
      return bound_function

    vm_function = self._vm_module.lookup_function(name)
    if vm_function is None:
      raise KeyError("Function '%s' not found in module '%s'" %
                     (name, self.name))
    bound_function = BoundFunction(self._context, vm_function)
    self._lazy_functions[name] = bound_function
    return bound_function

  def __repr__(self):
    return "<BoundModule %r>" % (self._vm_module,)
+
+
class Modules(dict):
  """A dict of module name -> BoundModule that also allows attribute access.

  `modules.foo` is equivalent to `modules["foo"]`. A missing name raises
  AttributeError, so `hasattr` and `getattr(..., default)` behave normally.
  """

  def __getattr__(self, name):
    try:
      return self[name]
    except KeyError:
      # The chained KeyError adds traceback noise without extra information.
      raise AttributeError(name) from None
+
+
class SystemContext:
  """A VM context together with the modules bound into it.

  Without `modules`, the context is dynamic. It starts with the config's
  default modules registered and accepts more via add_module(s). With
  `modules`, the context is static. The config's default modules followed
  by `modules` are fixed at creation.
  """

  def __init__(self, modules=None, config: Optional[Config] = None):
    self._config = config if config is not None else _get_global_config()
    print("SystemContext driver=%r" % self._config.driver, file=sys.stderr)
    self._is_dynamic = modules is None
    if not self._is_dynamic:
      init_modules = self._config.default_modules + tuple(modules)
    else:
      init_modules = None

    self._vm_context = _binding.VmContext(
        instance=self._config.vm_instance, modules=init_modules)

    if self._is_dynamic:
      self._vm_context.register_modules(self._config.default_modules)
      self._modules = Modules([
          (m.name, BoundModule(self, m)) for m in self._config.default_modules
      ])
    else:
      self._modules = Modules([
          (m.name, BoundModule(self, m)) for m in init_modules
      ])

  @property
  def is_dynamic(self) -> bool:
    return self._is_dynamic

  @property
  def config(self) -> Config:
    return self._config

  @property
  def instance(self) -> _binding.VmInstance:
    # The VM context is created from the config's instance; there is no
    # separate per-context instance (the previous `self._instance` was never
    # assigned, so this property always raised AttributeError).
    return self._config.vm_instance

  @property
  def modules(self) -> Modules:
    return self._modules

  def create_function_abi(self, f: _binding.VmFunction) -> _binding.FunctionAbi:
    return self._vm_context.create_function_abi(self._config.device,
                                                self._config.host_type_factory,
                                                f)

  def add_modules(self, modules):
    """Registers additional modules with a dynamic context.

    Every name is validated before anything is registered, so a rejected call
    leaves the context unchanged.

    Raises:
      AssertionError: (via `assert`) if the context is static.
      ValueError: If a module name is already registered or repeated in
        `modules`.
    """
    assert self._is_dynamic, "Cannot 'add_module' on a static context"
    # Materialize once: `modules` is traversed more than once below.
    modules = tuple(modules)
    if not modules:
      return
    new_names = set()
    for m in modules:
      name = m.name
      if name in self._modules or name in new_names:
        raise ValueError("Attempt to register duplicate module: '%s'" % (name,))
      new_names.add(name)
    self._vm_context.register_modules(modules)
    for m in modules:
      self._modules[m.name] = BoundModule(self, m)

  def add_module(self, module):
    self.add_modules((module,))
+
+
def load_modules(*modules, config: Optional[Config] = None):
  """Creates a new static SystemContext holding `modules` and returns them.

  Returns:
    A list with one BoundModule per argument, in argument order.
  """
  context = SystemContext(modules=modules, config=config)
  return [context.modules[vm_module.name] for vm_module in modules]
+
+
def load_module(module, **kwargs):
  """Loads a single module into a new context and returns its BoundModule.

  Keyword arguments (e.g. `config`) are forwarded to load_modules.
  """
  (bound_module,) = load_modules(module, **kwargs)
  return bound_module
diff --git a/bindings/python/pyiree/rt/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
new file mode 100644
index 0000000..adb70bb
--- /dev/null
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -0,0 +1,104 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=unused-variable
+
+import re
+
+from absl.testing import absltest
+import numpy as np
+from pyiree import compiler
+from pyiree import rt
+
+
def create_simple_mul_module():
  """Compiles an `arithmetic` module exporting `simple_mul` (4xf32 * 4xf32)."""
  compiler_context = compiler.Context()
  input_module = compiler_context.parse_asm("""
    module @arithmetic {
      func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
            attributes { iree.module.export } {
          %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
          return %0 : tensor<4xf32>
      }
    }
    """)
  return rt.VmModule.from_flatbuffer(input_module.compile())
+
+
class SystemApiTest(absltest.TestCase):
  """Tests for the high-level rt system API (Config, SystemContext, load_*)."""

  def test_non_existing_driver(self):
    with self.assertRaisesRegex(RuntimeError,
                                "Could not create any requested driver"):
      config = rt.Config("nothere1,nothere2")

  def test_subsequent_driver(self):
    # Unregistered names are skipped and the next preference is used.
    config = rt.Config("nothere1,interpreter")
    self.assertIsNotNone(config.driver)

  def test_empty_dynamic(self):
    ctx = rt.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    self.assertIn("hal", ctx.modules)
    self.assertEqual(ctx.modules.hal.name, "hal")

  def test_empty_static(self):
    ctx = rt.SystemContext(modules=())
    self.assertFalse(ctx.is_dynamic)
    self.assertIn("hal", ctx.modules)
    self.assertEqual(ctx.modules.hal.name, "hal")

  def test_custom_dynamic(self):
    ctx = rt.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    f_repr = repr(f)
    print(f_repr)
    self.assertRegex(
        f_repr,
        re.escape(
            "(Buffer<float32[4]>, Buffer<float32[4]>) -> (Buffer<float32[4]>)"))

  def test_duplicate_module(self):
    ctx = rt.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_module(create_simple_mul_module())
    with self.assertRaisesRegex(ValueError, "arithmetic"):
      ctx.add_module(create_simple_mul_module())

  def test_dynamic_invoke(self):
    ctx = rt.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = f(arg0, arg1)
    np.testing.assert_allclose(results, [4., 10., 18., 28.])

  def test_static_invoke(self):
    # Passing modules at construction produces a static context.
    ctx = rt.SystemContext(modules=[create_simple_mul_module()])
    self.assertFalse(ctx.is_dynamic)
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = f(arg0, arg1)
    np.testing.assert_allclose(results, [4., 10., 18., 28.])

  def test_load_module(self):
    arithmetic = rt.load_module(create_simple_mul_module())
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    results = arithmetic.simple_mul(arg0, arg1)
    np.testing.assert_allclose(results, [4., 10., 18., 28.])
+
+
if __name__ == "__main__":
  # Run the test cases defined in this file.
  absltest.main()
diff --git a/bindings/python/pyiree/rt/vm.cc b/bindings/python/pyiree/rt/vm.cc
new file mode 100644
index 0000000..ffdb76a
--- /dev/null
+++ b/bindings/python/pyiree/rt/vm.cc
@@ -0,0 +1,249 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bindings/python/pyiree/rt/vm.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/function_abi.h"
+#include "iree/base/api.h"
+#include "iree/modules/hal/hal_module.h"
+#include "iree/vm/invocation.h"
+#include "iree/vm/module.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+VmModule CreateHalModule(HalDevice* device) {
+ iree_vm_module_t* module;
+ CheckApiStatus(
+ iree_hal_module_create(device->raw_ptr(), IREE_ALLOCATOR_SYSTEM, &module),
+ "Error creating hal module");
+ return VmModule::CreateRetained(module);
+}
+
+} // namespace
+
+//------------------------------------------------------------------------------
+// VmInstance
+//------------------------------------------------------------------------------
+
+VmInstance VmInstance::Create() {
+ iree_vm_instance_t* instance;
+ auto status = iree_vm_instance_create(IREE_ALLOCATOR_SYSTEM, &instance);
+ CheckApiStatus(status, "Error creating instance");
+ return VmInstance::CreateRetained(instance);
+}
+
+//------------------------------------------------------------------------------
+// VmContext
+//------------------------------------------------------------------------------
+
+VmContext VmContext::Create(VmInstance* instance,
+ absl::optional<std::vector<VmModule*>> modules) {
+ iree_vm_context_t* context;
+ if (!modules) {
+ // Simple create with open allowed modules.
+ auto status = iree_vm_context_create(instance->raw_ptr(),
+ IREE_ALLOCATOR_SYSTEM, &context);
+ CheckApiStatus(status, "Error creating vm context");
+ } else {
+ // Closed set of modules.
+ absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
+ module_handles.resize(modules->size());
+ for (size_t i = 0, e = module_handles.size(); i < e; ++i) {
+ module_handles[i] = (*modules)[i]->raw_ptr();
+ }
+ auto status = iree_vm_context_create_with_modules(
+ instance->raw_ptr(), module_handles.data(), module_handles.size(),
+ IREE_ALLOCATOR_SYSTEM, &context);
+ CheckApiStatus(status, "Error creating vm context with modules");
+ }
+
+ CHECK(context);
+ return VmContext::CreateRetained(context);
+}
+
+void VmContext::RegisterModules(std::vector<VmModule*> modules) {
+ absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
+ module_handles.resize(modules.size());
+ for (size_t i = 0, e = module_handles.size(); i < e; ++i) {
+ module_handles[i] = modules[i]->raw_ptr();
+ }
+ auto status = iree_vm_context_register_modules(raw_ptr(), &module_handles[0],
+ module_handles.size());
+ CheckApiStatus(status, "Error registering modules");
+}
+
+std::unique_ptr<FunctionAbi> VmContext::CreateFunctionAbi(
+ HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
+ iree_vm_function_t f) {
+ // Resolve attrs.
+ absl::InlinedVector<std::pair<iree_string_view_t, iree_string_view_t>, 4>
+ attrs;
+ for (int i = 0;; ++i) {
+ attrs.push_back({});
+ auto status = iree_vm_get_function_reflection_attr(
+ f, i, &attrs.back().first, &attrs.back().second);
+ if (status == IREE_STATUS_NOT_FOUND) {
+ attrs.pop_back();
+ break;
+ }
+ CheckApiStatus(status, "Error getting reflection attr");
+ }
+ auto attr_lookup =
+ [&attrs](absl::string_view key) -> absl::optional<absl::string_view> {
+ for (const auto& attr : attrs) {
+ absl::string_view found_key(attr.first.data, attr.first.size);
+ absl::string_view found_value(attr.second.data, attr.second.size);
+ if (found_key == key) return found_value;
+ }
+ return absl::nullopt;
+ };
+
+ return FunctionAbi::Create(device, std::move(host_type_factory), attr_lookup);
+}
+
+void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs,
+ VmVariantList& outputs) {
+ CheckApiStatus(iree_vm_invoke(raw_ptr(), f, nullptr, inputs.raw_ptr(),
+ outputs.raw_ptr(), IREE_ALLOCATOR_SYSTEM),
+ "Error invoking function");
+}
+
+//------------------------------------------------------------------------------
+// VmModule
+//------------------------------------------------------------------------------
+
+VmModule VmModule::FromFlatbufferBlob(py::buffer flatbuffer_blob) {
+ auto buffer_info = flatbuffer_blob.request();
+ iree_vm_module_t* module;
+
+ // Bridge to the C-based deallocator API.
+ auto* raw_ptr = flatbuffer_blob.ptr();
+ auto free_fn = +([](void* self, void*) -> iree_status_t {
+ PyObject* self_ptr = static_cast<PyObject*>(self);
+ Py_XDECREF(self_ptr);
+ return IREE_STATUS_OK;
+ });
+ flatbuffer_blob.inc_ref();
+ iree_allocator_t deallocator{raw_ptr /* self */, nullptr /* alloc */,
+ free_fn /* dealloc */};
+
+ auto status = iree_vm_bytecode_module_create(
+ {static_cast<const uint8_t*>(buffer_info.ptr),
+ static_cast<iree_host_size_t>(buffer_info.size)},
+ deallocator, IREE_ALLOCATOR_SYSTEM, &module);
+ if (status != IREE_STATUS_OK) {
+ deallocator.free(raw_ptr, nullptr);
+ }
+
+ CheckApiStatus(status, "Error creating vm module from flatbuffer");
+ return VmModule::CreateRetained(module);
+}
+
+absl::optional<iree_vm_function_t> VmModule::LookupFunction(
+ const std::string& name, iree_vm_function_linkage_t linkage) {
+ iree_vm_function_t f;
+ auto status = iree_vm_module_lookup_function_by_name(
+ raw_ptr(), linkage, {name.data(), name.size()}, &f);
+ if (status == IREE_STATUS_NOT_FOUND) {
+ return absl::nullopt;
+ }
+ CheckApiStatus(status, "Error looking up function");
+ return f;
+}
+
+//------------------------------------------------------------------------------
+// VmVariantList
+//------------------------------------------------------------------------------
+
// Renders the list for debugging (bound to Python `__repr__`) as
// "<VmVariantList(N): [e0, e1, ...]>". Value variants print their i32 field,
// HAL buffer refs print "HalBuffer(<byte length>)", other refs print
// "Unknown(<ref_type>)", and anything else prints "None".
// NOTE(review): non-i32 value variants would also be printed via `i32`.
std::string VmVariantList::DebugString() const {
  // The variant list API requires mutability, so we const cast to it internally
  // so we can maintain a const DebugString() for callers.
  auto mutable_this = const_cast<VmVariantList*>(this);
  std::string s;
  absl::StrAppend(&s, "<VmVariantList(", size(), "): [");

  for (iree_host_size_t i = 0, e = size(); i < e; ++i) {
    iree_vm_variant_t* variant =
        iree_vm_variant_list_get(mutable_this->raw_ptr(), i);
    if (i > 0) absl::StrAppend(&s, ", ");

    if (IREE_VM_VARIANT_IS_VALUE(variant)) {
      absl::StrAppend(&s, variant->i32);
    } else if (IREE_VM_VARIANT_IS_REF(variant)) {
      // Pretty print a subset of ABI impacting known types.
      if (iree_hal_buffer_isa(&variant->ref)) {
        auto* hal_buffer = iree_hal_buffer_deref(&variant->ref);
        assert(hal_buffer);
        absl::StrAppend(&s, "HalBuffer(",
                        iree_hal_buffer_byte_length(hal_buffer), ")");
      } else {
        absl::StrAppend(&s, "Unknown(", variant->ref_type, ")");
      }
    } else {
      absl::StrAppend(&s, "None");
    }
  }
  absl::StrAppend(&s, "]>");
  return s;
}
+
// Registers the VM-level Python API on `m`. Built-in ref types are registered
// first (aborting via CHECK_EQ on failure). The function then exposes
// `create_hal_module`, the `Linkage` enum, and the VmVariantList, VmFunction,
// VmInstance, VmContext and VmModule classes.
void SetupVmBindings(pybind11::module m) {
  CHECK_EQ(IREE_STATUS_OK, iree_vm_register_builtin_types());
  CHECK_EQ(IREE_STATUS_OK, iree_hal_module_register_types());

  // Built-in module creation.
  m.def("create_hal_module", &CreateHalModule);

  py::enum_<iree_vm_function_linkage_t>(m, "Linkage")
      .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
      .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
      .value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
      .export_values();

  // Mutation and inspection of the variant list is mostly opaque to python.
  py::class_<VmVariantList>(m, "VmVariantList")
      .def(py::init(&VmVariantList::Create))
      .def_property_readonly("size", &VmVariantList::size)
      .def("__repr__", &VmVariantList::DebugString);

  // Read-only view of a resolved function handle.
  py::class_<iree_vm_function_t>(m, "VmFunction")
      .def_readonly("linkage", &iree_vm_function_t::linkage)
      .def_readonly("ordinal", &iree_vm_function_t::ordinal);

  py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));

  // `modules` defaults to None, which creates an open (dynamic) context.
  py::class_<VmContext>(m, "VmContext")
      .def(py::init(&VmContext::Create), py::arg("instance"),
           py::arg("modules") = absl::optional<std::vector<VmModule*>>())
      .def("register_modules", &VmContext::RegisterModules)
      .def_property_readonly("context_id", &VmContext::context_id)
      .def("create_function_abi", &VmContext::CreateFunctionAbi,
           py::arg("device"), py::arg("host_type_factory"), py::arg("f"))
      .def("invoke", &VmContext::Invoke);

  // lookup_function returns None when the function is not found.
  py::class_<VmModule>(m, "VmModule")
      .def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
      .def_property_readonly("name", &VmModule::name)
      .def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
           py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT);
}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/rt/vm.h b/bindings/python/pyiree/rt/vm.h
new file mode 100644
index 0000000..4688d82
--- /dev/null
+++ b/bindings/python/pyiree/rt/vm.h
@@ -0,0 +1,162 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_VM_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_VM_H_
+
+#include "absl/types/optional.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/rt/host_types.h"
+#include "iree/base/api.h"
+#include "iree/vm/api.h"
+#include "iree/vm/bytecode_module.h"
+#include "iree/vm/variant_list.h"
+
+namespace iree {
+namespace python {
+
+class FunctionAbi;
+
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
// ApiPtrAdapter specializations mapping each VM handle type to its C retain
// and release functions. Presumably consumed by the ApiRefCounted<> wrappers
// declared below (defined in binding.h; TODO confirm).
template <>
struct ApiPtrAdapter<iree_vm_instance_t> {
  static void Retain(iree_vm_instance_t* b) { iree_vm_instance_retain(b); }
  static void Release(iree_vm_instance_t* b) { iree_vm_instance_release(b); }
};

template <>
struct ApiPtrAdapter<iree_vm_context_t> {
  static void Retain(iree_vm_context_t* b) { iree_vm_context_retain(b); }
  static void Release(iree_vm_context_t* b) { iree_vm_context_release(b); }
};

template <>
struct ApiPtrAdapter<iree_vm_module_t> {
  static void Retain(iree_vm_module_t* b) { iree_vm_module_retain(b); }
  static void Release(iree_vm_module_t* b) { iree_vm_module_release(b); }
};

template <>
struct ApiPtrAdapter<iree_vm_invocation_t> {
  static void Retain(iree_vm_invocation_t* b) { iree_vm_invocation_retain(b); }
  static void Release(iree_vm_invocation_t* b) {
    iree_vm_invocation_release(b);
  }
};
+
+//------------------------------------------------------------------------------
+// VmVariantList
+//------------------------------------------------------------------------------
+
+class VmVariantList {
+ public:
+ VmVariantList() : list_(nullptr) {}
+ ~VmVariantList() {
+ if (list_) {
+ CheckApiStatus(iree_vm_variant_list_free(list_), "Error freeing list");
+ }
+ }
+
+ VmVariantList(VmVariantList&& other) {
+ list_ = other.list_;
+ other.list_ = nullptr;
+ }
+
+ VmVariantList& operator=(const VmVariantList&) = delete;
+ VmVariantList(const VmVariantList&) = delete;
+
+ static VmVariantList Create(iree_host_size_t capacity) {
+ iree_vm_variant_list_t* list;
+ CheckApiStatus(
+ iree_vm_variant_list_alloc(capacity, IREE_ALLOCATOR_SYSTEM, &list),
+ "Error allocating variant list");
+ return VmVariantList(list);
+ }
+
+ iree_host_size_t size() const { return iree_vm_variant_list_size(list_); }
+
+ iree_vm_variant_list_t* raw_ptr() { return list_; }
+ const iree_vm_variant_list_t* raw_ptr() const { return list_; }
+
+ void AppendNullRef() {
+ CheckApiStatus(iree_vm_variant_list_append_null_ref(raw_ptr()),
+ "Error appending to list");
+ }
+
+ std::string DebugString() const;
+
+ private:
+ VmVariantList(iree_vm_variant_list_t* list) : list_(list) {}
+ iree_vm_variant_list_t* list_;
+};
+
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
// Reference-counted wrapper around iree_vm_instance_t.
class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> {
 public:
  // Creates a new VM instance.
  static VmInstance Create();
};
+
// Reference-counted wrapper around iree_vm_module_t.
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
 public:
  // Creates a bytecode module from a Python buffer containing a flatbuffer.
  static VmModule FromFlatbufferBlob(py::buffer flatbuffer_blob);

  // Returns the function with the given name and linkage, or nullopt if the
  // module does not contain it.
  absl::optional<iree_vm_function_t> LookupFunction(
      const std::string& name, iree_vm_function_linkage_t linkage);

  // The module's name, copied into a std::string.
  std::string name() const {
    auto name_sv = iree_vm_module_name(raw_ptr());
    return std::string(name_sv.data, name_sv.size);
  }
};
+
// Reference-counted wrapper around iree_vm_context_t.
class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
 public:
  // Creates a context, optionally with modules, which will make the context
  // static, disallowing further module registration (and may be more
  // efficient).
  static VmContext Create(VmInstance* instance,
                          absl::optional<std::vector<VmModule*>> modules);

  // Registers additional modules. Only valid for non-static contexts (i.e.
  // those created without modules).
  void RegisterModules(std::vector<VmModule*> modules);

  // Unique id for this context.
  int context_id() const { return iree_vm_context_id(raw_ptr()); }

  // Synchronously invokes the given function.
  void Invoke(iree_vm_function_t f, VmVariantList& inputs,
              VmVariantList& outputs);

  // Creates a function ABI suitable for marshalling function inputs/results.
  std::unique_ptr<FunctionAbi> CreateFunctionAbi(
      HalDevice& device, std::shared_ptr<HostTypeFactory> host_type_factory,
      iree_vm_function_t f);
};
+
// Reference-counted wrapper around iree_vm_invocation_t. It has no helpers
// yet and is not exposed to Python by SetupVmBindings.
class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> {
};
+
// Registers the VM classes and helpers (VmInstance, VmContext, VmModule, ...)
// on the given Python module.
void SetupVmBindings(pybind11::module m);
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_VM_H_
diff --git a/bindings/python/pyiree/rt/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
new file mode 100644
index 0000000..e947000
--- /dev/null
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -0,0 +1,104 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=unused-variable
+
+from absl.testing import absltest
+import numpy as np
+from pyiree import compiler
+from pyiree import rt
+
+
def create_simple_mul_module():
  """Compiles a module exporting `simple_mul` (elementwise 4xf32 multiply)."""
  compiler_context = compiler.Context()
  input_module = compiler_context.parse_asm("""
    func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
          attributes { iree.module.export } {
        %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
        return %0 : tensor<4xf32>
    }
    """)
  return rt.VmModule.from_flatbuffer(input_module.compile())
+
+
class VmTest(absltest.TestCase):
  """Tests for the low-level VM bindings (instance, context, module, invoke)."""

  # Drivers tried, in order, when creating the shared test device.
  _PREFERRED_DRIVERS = ("vulkan", "interpreter")

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    driver_names = rt.HalDriver.query()
    print("DRIVER_NAMES =", driver_names)
    cls.driver = cls._create_driver(driver_names)
    cls.device = cls.driver.create_default_device()
    cls.hal_module = rt.create_hal_module(cls.device)
    cls.htf = rt.HostTypeFactory.get_numpy()

  @classmethod
  def _create_driver(cls, available_driver_names):
    """Creates the first preferred driver that is registered and creatable.

    Falls back from Vulkan to the interpreter so the tests can run on hosts
    without a usable Vulkan device.

    Raises:
      RuntimeError: If no preferred driver can be created.
    """
    errors = {}
    for name in cls._PREFERRED_DRIVERS:
      if name not in available_driver_names:
        continue
      try:
        return rt.HalDriver.create(name)
      except Exception as ex:  # pylint: disable=broad-except
        errors[name] = ex
    raise RuntimeError("Could not create any of %r (available=%r): %r" %
                       (cls._PREFERRED_DRIVERS, available_driver_names, errors))

  def test_variant_list(self):
    l = rt.VmVariantList(5)
    print(l)
    self.assertEqual(l.size, 0)

  def test_context_id(self):
    instance = rt.VmInstance()
    context1 = rt.VmContext(instance)
    context2 = rt.VmContext(instance)
    self.assertGreater(context2.context_id, context1.context_id)

  def test_module_basics(self):
    m = create_simple_mul_module()
    f = m.lookup_function("simple_mul")
    self.assertGreater(f.ordinal, 0)
    notfound = m.lookup_function("notfound")
    self.assertIs(notfound, None)

  def test_dynamic_module_context(self):
    instance = rt.VmInstance()
    context = rt.VmContext(instance)
    m = create_simple_mul_module()
    context.register_modules([self.hal_module, m])

  def test_static_module_context(self):
    m = create_simple_mul_module()
    print(m)
    instance = rt.VmInstance()
    print(instance)
    context = rt.VmContext(instance, modules=[self.hal_module, m])
    print(context)

  def test_synchronous_invoke_function(self):
    m = create_simple_mul_module()
    instance = rt.VmInstance()
    context = rt.VmContext(instance, modules=[self.hal_module, m])
    f = m.lookup_function("simple_mul")
    abi = context.create_function_abi(self.device, self.htf, f)
    print("INVOKING:", abi)
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    inputs = abi.raw_pack_inputs((arg0, arg1))
    print("INPUTS:", inputs)
    allocated_results = abi.allocate_results(inputs, static_alloc=False)
    print("ALLOCATED RESULTS:", allocated_results)
    print("--- INVOKE:")
    context.invoke(f, inputs, allocated_results)
    print("--- DONE.")
    results = abi.raw_unpack_results(allocated_results)
    print("RESULTS:", results)
    np.testing.assert_allclose(results[0], [4., 10., 18., 28.])
+
+
if __name__ == "__main__":
  # Run the test cases defined in this file.
  absltest.main()