First step of splitting python bindings.
* As discussed on mailing list, introduces pyiree.{compiler,rt,tf.support}
* Separate module shared libraries for compiler and rt
* Does not yet fully separate the TensorFlow compiler from the generic compiler (follow-on)
* Fixes namespace fallout
* This will almost certainly require patching into various platforms and verifying. Please pre-review.
PiperOrigin-RevId: 288361204
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index 9ee6bbc..1cddab5 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -12,252 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load(
- "//iree:build_defs.bzl",
- "NUMPY_DEPS",
- "PLATFORM_VULKAN_DEPS",
- "PYTHON_HEADERS_DEPS",
- "iree_py_extension",
-)
-
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
-
-DEFAULT_COPTS = [
- "-fexceptions",
-]
-
-DEFAULT_FEATURES = [
- "-use_header_modules", # Incompatible with exceptions builds.
-]
-
-COMPILER_DEPS = [
- # Transforms.
- "//iree/compiler/Dialect/Flow/Transforms",
- "//iree/compiler/Dialect/HAL/Transforms",
- "//iree/compiler/Dialect/HAL/Target:ExecutableTarget",
- "//iree/compiler/Dialect/VM/Transforms",
-
- # Targets.
- "//iree/compiler/Dialect/HAL/Target/LegacyInterpreter",
- "//iree/compiler/Dialect/HAL/Target/VMLA",
- "//iree/compiler/Dialect/HAL/Target/VulkanSPIRV",
- "//iree/compiler/Dialect/VM/Target/Bytecode",
-]
-
-DRIVER_DEPS = PLATFORM_VULKAN_DEPS + [
- "//iree/hal/interpreter:interpreter_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
-]
-
-py_binary(
- name = "everything_for_colab",
- srcs = ["dummy.py"],
- main = "dummy.py",
- python_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = [
- ":pyiree", # build_cleaner: keep
- "//bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-py_library(
- name = "pyiree",
- srcs = [
- "__init__.py",
- ],
- srcs_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = [
- ":binding",
- ":compiler",
- ":system_api",
- "//bindings/python:pathsetup", # build_cleaner: keep
- ] + select({
- "//iree:enable_tensorflow": [
- "//bindings/python/pyiree/tf_interop:test_utils",
- "//bindings/python/pyiree/tf_interop:tf_test_driver",
- ],
- "//conditions:default": [
- ],
- }),
-)
-
-py_library(
- name = "compiler",
- srcs = ["compiler.py"],
- srcs_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-py_library(
- name = "system_api",
- srcs = ["system_api.py"],
- srcs_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = [
- ":binding",
- "//bindings/python:pathsetup", # build_cleaner: keep
- ],
-)
-
-cc_library(
- name = "base",
- srcs = [
- "compiler.cc",
- "function_abi.cc",
- "hal.cc",
- "host_types.cc",
- "status_utils.cc",
- "vm.cc",
- ],
- hdrs = [
- "binding.h",
- "compiler.h",
- "function_abi.h",
- "hal.h",
- "host_types.h",
- "status_utils.h",
- "vm.h",
- ],
- copts = DEFAULT_COPTS,
- features = DEFAULT_FEATURES,
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = [
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- "@com_google_absl//absl/types:variant",
- "//iree/compiler/Utils",
- "//iree/modules/hal",
- "//iree/vm",
- "//iree/vm:bytecode_module",
- "//iree/vm:invocation",
- "//iree/vm:module",
- "//iree/vm:ref",
- "//iree/vm:variant_list",
- "@llvm-project//mlir:IR",
- "//iree/base:api",
- "//iree/base:status",
- "//iree/base:signature_mangle",
- "//iree/hal:api",
- "@llvm-project//llvm:support",
- "@llvm-project//mlir:Parser",
- "@llvm-project//mlir:Pass",
- "@iree_pybind11//:pybind11",
- ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS,
-)
-
-iree_py_extension(
- name = "binding",
- srcs = [
- "initialize_module.cc",
- ],
- copts = DEFAULT_COPTS,
- features = DEFAULT_FEATURES,
- linkstatic = 1,
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- win_def_file = "export.def",
- deps = [
- ":base",
- "//bindings/python/pyiree/tf_interop",
- "//iree/base:initializer",
- "//iree/base:tracing",
- "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
- ],
-)
-
-py_test(
- name = "compiler_test",
- srcs = ["compiler_test.py"],
- python_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "//bindings/python/pyiree",
- "@absl_py//absl/testing:absltest",
- ],
-)
-
-py_test(
- name = "function_abi_test",
- srcs = ["function_abi_test.py"],
- python_version = "PY3",
- # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
- tags = [
- "noga",
- "nokokoro",
- "notap",
- ],
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree",
- ],
- # TODO(b/145815906) Get this running in OSS CI.
-)
-
-py_test(
- name = "hal_test",
- srcs = ["hal_test.py"],
- python_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree",
- ],
-)
-
-py_test(
- name = "system_api_test",
- srcs = ["system_api_test.py"],
- python_version = "PY3",
- tags = [
- # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
- "notap",
- # TODO(b/145815906) Get this running in OSS CI.
- "noga",
- "nokokoro",
- ],
- deps = NUMPY_DEPS + [
- ":system_api",
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree",
- ],
-)
-
-py_test(
- name = "vm_test",
- srcs = ["vm_test.py"],
- python_version = "PY3",
- # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
- tags = [
- "noga",
- "nokokoro",
- "notap",
- ],
- deps = NUMPY_DEPS + [
- "//bindings/python:pathsetup", # build_cleaner: keep
- "@absl_py//absl/testing:absltest",
- "//bindings/python/pyiree",
- ],
-)
diff --git a/bindings/python/pyiree/binding.h b/bindings/python/pyiree/binding.h
deleted file mode 100644
index 4346d82..0000000
--- a/bindings/python/pyiree/binding.h
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
-
-#include <vector>
-
-#include "absl/types/optional.h"
-#include "iree/base/api.h"
-#include "pybind11/pybind11.h"
-#include "pybind11/stl.h"
-
-namespace pybind11 {
-namespace detail {
-#if !defined(ABSL_HAVE_STD_OPTIONAL)
-// Make absl::optional act like the future C++17 optional for pybind11.
-// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional
-// and the default type caster is sufficient.
-template <typename T>
-struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {};
-#endif
-} // namespace detail
-} // namespace pybind11
-
-namespace iree {
-namespace python {
-
-namespace py = pybind11;
-
-// Wrapper around a blob of memory.
-// Used to transport blobs back and forth between C++ and Python.
-class OpaqueBlob {
- public:
- OpaqueBlob() : data_(nullptr), size_(0) {}
- OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {}
- virtual ~OpaqueBlob() = default;
-
- void* data() { return data_; }
- const void* data() const { return data_; }
- size_t size() const { return size_; }
-
- // Create a free function from the OpaqueBlob shared pointer.
- using BufferFreeFn = void (*)(void* self, iree_byte_span_t);
- static std::pair<BufferFreeFn, void*> CreateFreeFn(
- std::shared_ptr<OpaqueBlob> blob) {
- // Note that there are more efficient ways to write this which
- // don't bounce through an extra heap alloc, but this is not
- // intended to be a high impact code path.
- struct Holder {
- std::shared_ptr<OpaqueBlob> blob;
- };
- Holder* holder = new Holder{std::move(blob)};
- auto free_fn = +([](void* self, iree_byte_span_t) {
- Holder* self_holder = static_cast<Holder*>(self);
- delete self_holder;
- });
- return {free_fn, holder};
- }
-
- static iree_allocator_t CreateDeallocator(std::shared_ptr<OpaqueBlob> blob) {
- // Note that there are more efficient ways to write this which
- // don't bounce through an extra heap alloc, but this is not
- // intended to be a high impact code path.
- struct Holder {
- std::shared_ptr<OpaqueBlob> blob;
- };
- Holder* holder = new Holder{std::move(blob)};
- auto free_fn = +([](void* self, void*) -> iree_status_t {
- Holder* self_holder = static_cast<Holder*>(self);
- delete self_holder;
- return IREE_STATUS_OK;
- });
- return {holder /* self */, nullptr /* alloc */, free_fn /* free */};
- }
-
- protected:
- void* data_;
- size_t size_;
-};
-
-// Opaque blob that owns a vector.
-class OpaqueByteVectorBlob : public OpaqueBlob {
- public:
- OpaqueByteVectorBlob(std::vector<uint8_t> v)
- : OpaqueBlob(), v_(std::move(v)) {
- data_ = v_.data();
- size_ = v_.size();
- }
-
- private:
- std::vector<uint8_t> v_;
-};
-
-class OpaqueStringBlob : public OpaqueBlob {
- public:
- OpaqueStringBlob(std::string s) : OpaqueBlob(), s_(std::move(s)) {
- data_ = &s_[0];
- size_ = s_.size();
- }
-
- private:
- std::string s_;
-};
-
-template <typename T>
-struct ApiPtrAdapter {};
-
-template <typename Self, typename T>
-class ApiRefCounted {
- public:
- ApiRefCounted() : instance_(nullptr) {}
- ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
- other.instance_ = nullptr;
- }
- ApiRefCounted& operator=(ApiRefCounted&& other) {
- instance_ = other.instance_;
- other.instance_ = nullptr;
- return *this;
- }
- void operator=(const ApiRefCounted&) = delete;
-
- ~ApiRefCounted() { Release(); }
-
- // Creates an instance of the ref counted wrapper based on an instance
- // that has already been retained. Ownership is transferred to the
- // wrapper.
- static Self CreateRetained(T* retained_inst) {
- auto self = Self();
- self.instance_ = retained_inst;
- return self;
- }
-
- // Creates a new instance, retaining the underlying object.
- static Self RetainAndCreate(T* non_retained_inst) {
- auto self = Self();
- self.instance_ = non_retained_inst;
- if (non_retained_inst) {
- ApiPtrAdapter<T>::Retain(non_retained_inst);
- }
- return self;
- }
-
- // Whether it is nullptr.
- operator bool() const { return instance_; }
-
- T* steal_raw_ptr() {
- T* ret = instance_;
- instance_ = nullptr;
- return ret;
- }
-
- T* raw_ptr() {
- if (!instance_) {
- throw std::invalid_argument("API object is null");
- }
- return instance_;
- }
-
- const T* raw_ptr() const {
- return const_cast<ApiRefCounted*>(this)->raw_ptr();
- }
-
- void Retain() {
- if (instance_) {
- ApiPtrAdapter<T>::Retain(instance_);
- }
- }
- void Release() {
- if (instance_) {
- ApiPtrAdapter<T>::Release(instance_);
- }
- }
-
- private:
- T* instance_;
-};
-
-} // namespace python
-} // namespace iree
-
-#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/bindings/python/pyiree/build_defs.bzl b/bindings/python/pyiree/build_defs.bzl
new file mode 100644
index 0000000..3a496e0
--- /dev/null
+++ b/bindings/python/pyiree/build_defs.bzl
@@ -0,0 +1,63 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macros for building IREE python extensions."""
+
+# Re-export various top-level things.
+# We do this to keep the python bindings fairly self-contained and
+# able to move to a dedicated repo more easily.
+
+load(
+ "//iree:build_defs.bzl",
+ "cc_library",
+ _NUMPY_DEPS = "NUMPY_DEPS",
+ _PLATFORM_VULKAN_DEPS = "PLATFORM_VULKAN_DEPS",
+ _PYTHON_HEADER_DEPS = "PYTHON_HEADERS_DEPS",
+ _iree_py_extension = "iree_py_extension",
+)
+
+NUMPY_DEPS = _NUMPY_DEPS
+PLATFORM_VULKAN_DEPS = _PLATFORM_VULKAN_DEPS
+PYTHON_HEADER_DEPS = _PYTHON_HEADER_DEPS
+iree_py_extension = _iree_py_extension
+
+PYBIND_COPTS = [
+ "-fexceptions",
+]
+
+PYBIND_FEATURES = [
+ "-use_header_modules", # Incompatible with exceptions builds.
+]
+
+PYBIND_EXTENSION_COPTS = [
+ "-fvisibility=hidden",
+]
+
+def pybind_cc_library(
+ name,
+ copts = [],
+ features = [],
+ deps = [
+ ],
+ **kwargs):
+ """Wrapper cc_library for deps that are part of the python bindings."""
+ cc_library(
+ name = name,
+ copts = PYBIND_COPTS,
+ features = PYBIND_FEATURES,
+ deps = [
+ "@iree_pybind11//:pybind11",
+ ] + deps + PYTHON_HEADER_DEPS,
+ **kwargs
+ )
diff --git a/bindings/python/pyiree/common/BUILD b/bindings/python/pyiree/common/BUILD
new file mode 100644
index 0000000..c261390
--- /dev/null
+++ b/bindings/python/pyiree/common/BUILD
@@ -0,0 +1,40 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//bindings/python/pyiree:build_defs.bzl",
+ "pybind_cc_library",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+pybind_cc_library(
+ name = "common",
+ srcs = [
+ "status_utils.cc",
+ ],
+ hdrs = [
+ "binding.h",
+ "status_utils.h",
+ ],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
diff --git a/bindings/python/pyiree/common/binding.h b/bindings/python/pyiree/common/binding.h
new file mode 100644
index 0000000..2a81b2c
--- /dev/null
+++ b/bindings/python/pyiree/common/binding.h
@@ -0,0 +1,118 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "iree/base/api.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace pybind11 {
+namespace detail {
+#if !defined(ABSL_HAVE_STD_OPTIONAL)
+// Make absl::optional act like the future C++17 optional for pybind11.
+// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional
+// and the default type caster is sufficient.
+template <typename T>
+struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {};
+#endif
+} // namespace detail
+} // namespace pybind11
+
+namespace iree {
+namespace python {
+
+namespace py = pybind11;
+
+template <typename T>
+struct ApiPtrAdapter {};
+
+template <typename Self, typename T>
+class ApiRefCounted {
+ public:
+ ApiRefCounted() : instance_(nullptr) {}
+ ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
+ other.instance_ = nullptr;
+ }
+ ApiRefCounted& operator=(ApiRefCounted&& other) {
+ instance_ = other.instance_;
+ other.instance_ = nullptr;
+ return *this;
+ }
+ void operator=(const ApiRefCounted&) = delete;
+
+ ~ApiRefCounted() { Release(); }
+
+ // Creates an instance of the ref counted wrapper based on an instance
+ // that has already been retained. Ownership is transferred to the
+ // wrapper.
+ static Self CreateRetained(T* retained_inst) {
+ auto self = Self();
+ self.instance_ = retained_inst;
+ return self;
+ }
+
+ // Creates a new instance, retaining the underlying object.
+ static Self RetainAndCreate(T* non_retained_inst) {
+ auto self = Self();
+ self.instance_ = non_retained_inst;
+ if (non_retained_inst) {
+ ApiPtrAdapter<T>::Retain(non_retained_inst);
+ }
+ return self;
+ }
+
+ // Whether it is nullptr.
+ operator bool() const { return instance_; }
+
+ T* steal_raw_ptr() {
+ T* ret = instance_;
+ instance_ = nullptr;
+ return ret;
+ }
+
+ T* raw_ptr() {
+ if (!instance_) {
+ throw std::invalid_argument("API object is null");
+ }
+ return instance_;
+ }
+
+ const T* raw_ptr() const {
+ return const_cast<ApiRefCounted*>(this)->raw_ptr();
+ }
+
+ void Retain() {
+ if (instance_) {
+ ApiPtrAdapter<T>::Retain(instance_);
+ }
+ }
+ void Release() {
+ if (instance_) {
+ ApiPtrAdapter<T>::Release(instance_);
+ }
+ }
+
+ private:
+ T* instance_;
+};
+
+} // namespace python
+} // namespace iree
+
+#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/bindings/python/pyiree/status_utils.cc b/bindings/python/pyiree/common/status_utils.cc
similarity index 82%
rename from bindings/python/pyiree/status_utils.cc
rename to bindings/python/pyiree/common/status_utils.cc
index 444b4df..e1b7a93 100644
--- a/bindings/python/pyiree/status_utils.cc
+++ b/bindings/python/pyiree/common/status_utils.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/common/status_utils.h"
#include "absl/strings/str_cat.h"
@@ -34,8 +34,8 @@
}
}
-PyObject* ApiStatusToPyExcClass(iree_status_code_t status_code) {
- switch (status_code) {
+PyObject* ApiStatusToPyExcClass(iree_status_t status) {
+ switch (status) {
case IREE_STATUS_INVALID_ARGUMENT:
return PyExc_ValueError;
case IREE_STATUS_OUT_OF_RANGE:
@@ -57,11 +57,9 @@
pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
const char* message) {
- assert(!iree_status_is_ok(status));
- auto full_message = absl::StrCat(
- message, ": ", iree_status_code_string(iree_status_code(status)));
- PyErr_SetString(ApiStatusToPyExcClass(iree_status_code(status)),
- full_message.c_str());
+ assert(status != IREE_STATUS_OK);
+ auto full_message = absl::StrCat(message, ": ", static_cast<int>(status));
+ PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str());
return pybind11::error_already_set();
}
diff --git a/bindings/python/pyiree/status_utils.h b/bindings/python/pyiree/common/status_utils.h
similarity index 89%
rename from bindings/python/pyiree/status_utils.h
rename to bindings/python/pyiree/common/status_utils.h
index 23559b5..3f4956d 100644
--- a/bindings/python/pyiree/status_utils.h
+++ b/bindings/python/pyiree/common/status_utils.h
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMMON_STATUS_UTILS_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_COMMON_STATUS_UTILS_H_
#include "iree/base/api.h"
#include "iree/base/status.h"
-#include "pybind11/pytypes.h"
+#include "pybind11/pybind11.h"
namespace iree {
namespace python {
@@ -57,7 +57,7 @@
const char* message);
inline void CheckApiStatus(iree_status_t status, const char* message) {
- if (iree_status_is_ok(status)) {
+ if (status == IREE_STATUS_OK) {
return;
}
throw ApiStatusToPyExc(status, message);
@@ -72,4 +72,4 @@
} // namespace python
} // namespace iree
-#endif // IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
+#endif // IREE_BINDINGS_PYTHON_PYIREE_COMMON_STATUS_UTILS_H_
diff --git a/bindings/python/pyiree/compiler.py b/bindings/python/pyiree/compiler.py
deleted file mode 100644
index 80a9ad5..0000000
--- a/bindings/python/pyiree/compiler.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""High level compiler API.
-
-This imports parts of the native bindings as appropriate.
-"""
-
-from typing import Collection, Optional, Sequence
-
-from . import binding as _binding
-
-# Native aliases.
-Context = _binding.compiler.CompilerContext
-Module = _binding.compiler.CompilerModule
-CompileOptions = _binding.compiler.CompileOptions
-OutputFormat = _binding.compiler.OutputFormat
-
-# Conditionally import TensorFlow interop aliases.
-HAS_TENSORFLOW = hasattr(_binding, "tf_interop")
-
-if HAS_TENSORFLOW:
- # Pass pipeline that should run to lower a TF saved_model to a form suitable
- # for input to the IREE compiler.
- TF_IMPORT_PASS_PIPELINE = (
- # Clean up tf_executor and extraneous unused functions.
- "tf-saved-model-delete-unused-funcs",
- "tf-executor-graph-pruning",
- "tf-standard-pipeline",
- "canonicalize",
-
- # Clean up control flow
- "tf-functional-control-flow-to-cfg",
- "inline",
- "tf-saved-model-delete-unused-funcs",
- "canonicalize",
-
- # Legalize to XLA
- "xla-legalize-tf{allow-partial-conversion=true}",
- "canonicalize",
-
- # Now that the IR is starting to look nice, optimize global tensors.
- "tf-saved-model-optimize-global-tensors",
-
- # Adopt saved_model exports into IREE.
- "iree-tf-saved-model-adopt-exports",
- )
-
- def tf_load_saved_model(
- saved_model_dir: str,
- compiler_context: Optional[Context] = None,
- exported_names: Collection[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE) -> Module:
- """Loads a TensorFlow saved model from its persistent representation.
-
- See also tf_compile_saved_model() for a one-shot API to load and compile.
-
- Args:
- saved_model_dir: Directory of the saved model.
- compiler_context: The pyiree.compiler.Context() backing the module.
- exported_names: Optional tuple of strings representing the exported names
- to keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- input_module = _binding.tf_interop.load_saved_model(
- compiler_context, saved_model_dir, exported_names=exported_names)
- if pass_pipeline:
- input_module.run_pass_pipeline(pass_pipeline)
- return input_module
-
- def tf_compile_saved_model(
- saved_model_dir: str,
- compiler_context: Optional[Context] = None,
- exported_names: Collection[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- print_mlir: bool = False,
- target_backends: Collection[str] = ()
- ) -> _binding.OpaqueBlob:
- """Loads and compiles a TensorFlow saved model in one shot.
-
- Args:
- saved_model_dir: Directory of the saved model.
- compiler_context: The pyiree.compiler.Context() backing the module.
- exported_names: Optional tuple of strings representing the exported names
- to keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- print_mlir: Whether to print intermediate MLIR after each pass.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
-
- Returns:
- An OpaqueBlob representing the compiled module.
- """
- input_module = tf_load_saved_model(saved_model_dir, compiler_context,
- exported_names, pass_pipeline)
- return input_module.compile_to_sequencer_blob(
- print_mlir=print_mlir, target_backends=target_backends)
diff --git a/bindings/python/pyiree/compiler/BUILD b/bindings/python/pyiree/compiler/BUILD
new file mode 100644
index 0000000..5c2765d
--- /dev/null
+++ b/bindings/python/pyiree/compiler/BUILD
@@ -0,0 +1,99 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//bindings/python/pyiree:build_defs.bzl",
+ "NUMPY_DEPS",
+ "PYBIND_COPTS",
+ "PYBIND_EXTENSION_COPTS",
+ "PYBIND_FEATURES",
+ "iree_py_extension",
+ "pybind_cc_library",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+COMPILER_DEPS = [
+ # Transforms.
+ "//iree/compiler/Dialect/Flow/Transforms",
+ "//iree/compiler/Dialect/HAL/Transforms",
+ "//iree/compiler/Dialect/HAL/Target:ExecutableTarget",
+ "//iree/compiler/Dialect/VM/Transforms",
+
+ # Targets.
+ "//iree/compiler/Dialect/HAL/Target/LegacyInterpreter",
+ "//iree/compiler/Dialect/HAL/Target/VMLA",
+ "//iree/compiler/Dialect/HAL/Target/VulkanSPIRV",
+ "//iree/compiler/Dialect/VM/Target/Bytecode",
+]
+
+py_library(
+ name = "compiler",
+ srcs = [
+ "__init__.py",
+ "conditional_tensorflow.py",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ ],
+)
+
+iree_py_extension(
+ name = "binding",
+ srcs = [
+ "initialize_module.cc",
+ ],
+ copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
+ features = PYBIND_FEATURES,
+ linkstatic = 1,
+ win_def_file = "export.def",
+ deps = [
+ ":compiler_library",
+ "//bindings/python/pyiree/common",
+ "//bindings/python/pyiree/compiler/tf_interop",
+ ],
+)
+
+pybind_cc_library(
+ name = "compiler_library",
+ srcs = [
+ "compiler.cc",
+ ],
+ hdrs = [
+ "compiler.h",
+ ],
+ deps = COMPILER_DEPS + [
+ "//bindings/python/pyiree/common",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ ],
+)
+
+py_test(
+ name = "compiler_test",
+ srcs = ["compiler_test.py"],
+ python_version = "PY3",
+ deps = NUMPY_DEPS + [
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "@absl_py//absl/testing:absltest",
+ "//bindings/python/pyiree/compiler",
+ ],
+)
diff --git a/bindings/python/pyiree/compiler/__init__.py b/bindings/python/pyiree/compiler/__init__.py
new file mode 100644
index 0000000..b1c5fa0
--- /dev/null
+++ b/bindings/python/pyiree/compiler/__init__.py
@@ -0,0 +1,33 @@
+"""Module init for the python bindings."""
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-multiple-import
+# pylint: disable=g-bad-import-order
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+
+from . import binding as binding
+
+# Native aliases.
+Context = binding.CompilerContext
+Module = binding.CompilerModule
+CompileOptions = binding.CompileOptions
+OutputFormat = binding.OutputFormat
+
+# Conditionally import TensorFlow interop aliases.
+HAS_TENSORFLOW = hasattr(binding, "tf_interop")
+if HAS_TENSORFLOW:
+ from .conditional_tensorflow import *
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
similarity index 98%
rename from bindings/python/pyiree/compiler.cc
rename to bindings/python/pyiree/compiler/compiler.cc
index a695194..768574b 100644
--- a/bindings/python/pyiree/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/compiler.h"
+#include "bindings/python/pyiree/compiler/compiler.h"
#include <stdexcept>
#include <string>
-#include "bindings/python/pyiree/binding.h"
-#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/common/status_utils.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/ExecutableTarget.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler/compiler.h
similarity index 67%
rename from bindings/python/pyiree/compiler.h
rename to bindings/python/pyiree/compiler/compiler.h
index 781cdb4..c08f94d 100644
--- a/bindings/python/pyiree/compiler.h
+++ b/bindings/python/pyiree/compiler/compiler.h
@@ -18,7 +18,7 @@
#include <mutex> // NOLINT
#include <string>
-#include "bindings/python/pyiree/binding.h"
+#include "bindings/python/pyiree/common/binding.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
@@ -27,6 +27,81 @@
namespace iree {
namespace python {
+// Wrapper around a blob of memory.
+// Used to transport blobs back and forth between C++ and Python.
+class OpaqueBlob {
+ public:
+ OpaqueBlob() : data_(nullptr), size_(0) {}
+ OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {}
+ virtual ~OpaqueBlob() = default;
+
+ void* data() { return data_; }
+ const void* data() const { return data_; }
+ size_t size() const { return size_; }
+
+ // Create a free function from the OpaqueBlob shared pointer.
+ using BufferFreeFn = void (*)(void* self, iree_byte_span_t);
+ static std::pair<BufferFreeFn, void*> CreateFreeFn(
+ std::shared_ptr<OpaqueBlob> blob) {
+ // Note that there are more efficient ways to write this which
+ // don't bounce through an extra heap alloc, but this is not
+ // intended to be a high impact code path.
+ struct Holder {
+ std::shared_ptr<OpaqueBlob> blob;
+ };
+ Holder* holder = new Holder{std::move(blob)};
+ auto free_fn = +([](void* self, iree_byte_span_t) {
+ Holder* self_holder = static_cast<Holder*>(self);
+ delete self_holder;
+ });
+ return {free_fn, holder};
+ }
+
+ static iree_allocator_t CreateDeallocator(std::shared_ptr<OpaqueBlob> blob) {
+ // Note that there are more efficient ways to write this which
+ // don't bounce through an extra heap alloc, but this is not
+ // intended to be a high impact code path.
+ struct Holder {
+ std::shared_ptr<OpaqueBlob> blob;
+ };
+ Holder* holder = new Holder{std::move(blob)};
+ auto free_fn = +([](void* self, void*) -> iree_status_t {
+ Holder* self_holder = static_cast<Holder*>(self);
+ delete self_holder;
+ return IREE_STATUS_OK;
+ });
+ return {holder /* self */, nullptr /* alloc */, free_fn /* free */};
+ }
+
+ protected:
+ void* data_;
+ size_t size_;
+};
+
+// Opaque blob that owns a vector.
+class OpaqueByteVectorBlob : public OpaqueBlob {
+ public:
+ OpaqueByteVectorBlob(std::vector<uint8_t> v)
+ : OpaqueBlob(), v_(std::move(v)) {
+ data_ = v_.data();
+ size_ = v_.size();
+ }
+
+ private:
+ std::vector<uint8_t> v_;
+};
+
+class OpaqueStringBlob : public OpaqueBlob {
+ public:
+ OpaqueStringBlob(std::string s) : OpaqueBlob(), s_(std::move(s)) {
+ data_ = &s_[0];
+ size_ = s_.size();
+ }
+
+ private:
+ std::string s_;
+};
+
class CompilerContextBundle;
class CompilerModuleBundle;
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler/compiler_test.py
similarity index 81%
rename from bindings/python/pyiree/compiler_test.py
rename to bindings/python/pyiree/compiler/compiler_test.py
index 2bc02ea..2b68f0c 100644
--- a/bindings/python/pyiree/compiler_test.py
+++ b/bindings/python/pyiree/compiler/compiler_test.py
@@ -14,7 +14,7 @@
# limitations under the License.
from absl.testing import absltest
-import pyiree
+from pyiree import compiler
SIMPLE_MUL_ASM = """
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
@@ -28,12 +28,12 @@
class CompilerTest(absltest.TestCase):
def testParseError(self):
- ctx = pyiree.compiler.Context()
+ ctx = compiler.Context()
with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")
def testParseAndCompileToFlatbuffer(self):
- ctx = pyiree.compiler.Context()
+ ctx = compiler.Context()
input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
binary = input_module.compile()
b = binary.bytes
@@ -41,19 +41,19 @@
self.assertTrue(binary.bytes)
def testParseAndCompileToFlatbufferText(self):
- ctx = pyiree.compiler.Context()
+ ctx = compiler.Context()
input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
- options = pyiree.compiler.CompileOptions()
- options.output_format = pyiree.compiler.OutputFormat.FLATBUFFER_TEXT
+ options = compiler.CompileOptions()
+ options.output_format = compiler.OutputFormat.FLATBUFFER_TEXT
blob = input_module.compile(options=options)
text = blob.text
self.assertTrue(text)
def testParseAndCompileToMlirText(self):
- ctx = pyiree.compiler.Context()
+ ctx = compiler.Context()
input_module = ctx.parse_asm(SIMPLE_MUL_ASM)
- options = pyiree.compiler.CompileOptions()
- options.output_format = pyiree.compiler.OutputFormat.MLIR_TEXT
+ options = compiler.CompileOptions()
+ options.output_format = compiler.OutputFormat.MLIR_TEXT
blob = input_module.compile(options=options)
text = blob.text
self.assertTrue(text)
diff --git a/bindings/python/pyiree/compiler/conditional_tensorflow.py b/bindings/python/pyiree/compiler/conditional_tensorflow.py
new file mode 100644
index 0000000..5e6034c
--- /dev/null
+++ b/bindings/python/pyiree/compiler/conditional_tensorflow.py
@@ -0,0 +1,114 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""High level compiler API.
+
+This imports parts of the native bindings as appropriate.
+"""
+
+__all__ = [
+ "TF_IMPORT_PASS_PIPELINE",
+ "tf_load_saved_model",
+ "tf_compile_saved_model",
+]
+
+from typing import Collection, Optional, Sequence
+
+from . import binding as _binding
+from .binding import CompilerContext as Context
+from .binding import CompilerModule as Module
+
+# Pass pipeline that should run to lower a TF saved_model to a form suitable
+# for input to the IREE compiler.
+TF_IMPORT_PASS_PIPELINE = (
+ # Clean up tf_executor and extraneous unused functions.
+ "tf-saved-model-delete-unused-funcs",
+ "tf-executor-graph-pruning",
+ "tf-standard-pipeline",
+ "canonicalize",
+
+ # Clean up control flow
+ "tf-functional-control-flow-to-cfg",
+ "inline",
+ "tf-saved-model-delete-unused-funcs",
+ "canonicalize",
+
+ # Legalize to XLA
+ "xla-legalize-tf{allow-partial-conversion=true}",
+ "canonicalize",
+
+ # Now that the IR is starting to look nice, optimize global tensors.
+ "tf-saved-model-optimize-global-tensors",
+
+ # Adopt saved_model exports into IREE.
+ "iree-tf-saved-model-adopt-exports",
+)
+
+
+def tf_load_saved_model(
+ saved_model_dir: str,
+ compiler_context: Optional[Context] = None,
+ exported_names: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE) -> Module:
+ """Loads a TensorFlow saved model from its persistent representation.
+
+ See also tf_compile_saved_model() for a one-shot API to load and compile.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+ if not compiler_context:
+ compiler_context = Context()
+ input_module = _binding.tf_interop.load_saved_model(
+ compiler_context, saved_model_dir, exported_names=exported_names)
+ if pass_pipeline:
+ input_module.run_pass_pipeline(pass_pipeline)
+ return input_module
+
+
+def tf_compile_saved_model(
+ saved_model_dir: str,
+ compiler_context: Optional[Context] = None,
+ exported_names: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ target_backends: Collection[str] = ()
+) -> _binding.OpaqueBlob:
+ """Loads and compiles a TensorFlow saved model in one shot.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ target_backends: The specific target backends to compile for (defaults to
+ all compiled in targets).
+
+ Returns:
+ An OpaqueBlob representing the compiled module.
+ """
+ input_module = tf_load_saved_model(saved_model_dir, compiler_context,
+ exported_names, pass_pipeline)
+ return input_module.compile(target_backends=target_backends)
diff --git a/bindings/python/pyiree/compiler/initialize_module.cc b/bindings/python/pyiree/compiler/initialize_module.cc
new file mode 100644
index 0000000..274c3a8
--- /dev/null
+++ b/bindings/python/pyiree/compiler/initialize_module.cc
@@ -0,0 +1,59 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <mutex> // NOLINT
+
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/compiler/compiler.h"
+#include "bindings/python/pyiree/compiler/tf_interop/register_tensorflow.h"
+
+namespace iree {
+namespace python {
+
+PYBIND11_MODULE(binding, m) {
+ m.doc() = "IREE Compiler Interface";
+ py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob",
+ py::buffer_protocol())
+ .def_buffer([](OpaqueBlob* self) -> py::buffer_info {
+ return py::buffer_info(
+ self->data(), // Pointer to buffer
+ sizeof(uint8_t), // Size of one scalar
+ py::format_descriptor<uint8_t>::value, // Python struct-style
+ // format
+ 1, // Number of dimensions
+ {self->size()}, // Buffer dimensions
+ {self->size()} // Strides
+ );
+ })
+ .def_property_readonly("bytes",
+ [](OpaqueBlob* self) -> py::bytes {
+ return py::bytes(
+ static_cast<const char*>(self->data()),
+ self->size());
+ })
+ .def_property_readonly("text", [](OpaqueBlob* self) -> py::str {
+ return py::str(static_cast<const char*>(self->data()), self->size());
+ });
+
+ SetupCompilerBindings(m);
+
+// TensorFlow.
+#if defined(IREE_TENSORFLOW_ENABLED)
+ auto tf_m = m.def_submodule("tf_interop", "IREE TensorFlow interop");
+ SetupTensorFlowBindings(tf_m);
+#endif
+}
+
+} // namespace python
+} // namespace iree
diff --git a/bindings/python/pyiree/tf_interop/BUILD b/bindings/python/pyiree/compiler/tf_interop/BUILD
similarity index 76%
rename from bindings/python/pyiree/tf_interop/BUILD
rename to bindings/python/pyiree/compiler/tf_interop/BUILD
index c97e92d..aef5b7c 100644
--- a/bindings/python/pyiree/tf_interop/BUILD
+++ b/bindings/python/pyiree/compiler/tf_interop/BUILD
@@ -13,6 +13,10 @@
# limitations under the License.
load(
+ "//bindings/python/pyiree:build_defs.bzl",
+ "pybind_cc_library",
+)
+load(
"//iree:build_defs.bzl",
"INTREE_TENSORFLOW_PY_DEPS",
)
@@ -30,19 +34,17 @@
"-use_header_modules", # Incompatible with exceptions builds.
]
-cc_library(
+pybind_cc_library(
name = "tf_interop",
hdrs = [
"register_tensorflow.h",
],
- copts = DEFAULT_COPTS,
defines = select({
"//iree:enable_tensorflow": [
"IREE_TENSORFLOW_ENABLED",
],
"//conditions:default": [],
}),
- features = DEFAULT_FEATURES,
# TODO(b/145815906) Get this running in OSS CI.
tags = ["nokokoro"],
deps = select({
@@ -53,7 +55,7 @@
":tensorflow_disabled",
],
}) + [
- "//bindings/python/pyiree:base",
+ "//bindings/python/pyiree/common",
],
)
@@ -67,7 +69,7 @@
"@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
]
-cc_library(
+pybind_cc_library(
name = "tensorflow_impl",
srcs = [
"register_tensorflow.cc",
@@ -75,8 +77,6 @@
hdrs = [
"register_tensorflow.h",
],
- copts = DEFAULT_COPTS,
- features = DEFAULT_FEATURES,
# TODO(b/145815906) Get this running in OSS CI.
tags = ["nokokoro"],
visibility = ["//visibility:private"],
@@ -86,7 +86,8 @@
"@org_tensorflow//tensorflow/cc/saved_model:loader_lite",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
- "//bindings/python/pyiree:base",
+ "//bindings/python/pyiree/common",
+ "//bindings/python/pyiree/compiler:compiler_library",
] + SAVED_MODEL_TF_RUNTIME_DEPS + TF_XLA_PASS_DEPS,
)
@@ -103,29 +104,7 @@
# TODO(b/145815906) Get this running in OSS CI.
tags = ["nokokoro"],
deps = [
- "//bindings/python/pyiree:base",
- ],
-)
-
-py_library(
- name = "tf_test_driver",
- srcs = ["tf_test_driver.py"],
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//bindings/python/pyiree:binding",
- ],
-)
-
-py_library(
- name = "test_utils",
- srcs = ["test_utils.py"],
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["nokokoro"],
- deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//bindings/python/pyiree:compiler",
- "//bindings/python/pyiree:binding",
- "//bindings/python/pyiree:system_api",
+ "//bindings/python/pyiree/common",
],
)
@@ -136,7 +115,7 @@
# TODO(b/145815906) Get this running in OSS CI.
tags = ["nokokoro"],
deps = [
- "//bindings/python/pyiree",
+ "//bindings/python/pyiree/compiler",
"//bindings/python:pathsetup", # build_cleaner: keep
] + INTREE_TENSORFLOW_PY_DEPS,
)
diff --git a/bindings/python/pyiree/tf_interop/register_tensorflow.cc b/bindings/python/pyiree/compiler/tf_interop/register_tensorflow.cc
similarity index 92%
rename from bindings/python/pyiree/tf_interop/register_tensorflow.cc
rename to bindings/python/pyiree/compiler/tf_interop/register_tensorflow.cc
index 2a1622f..9cc691c 100644
--- a/bindings/python/pyiree/tf_interop/register_tensorflow.cc
+++ b/bindings/python/pyiree/compiler/tf_interop/register_tensorflow.cc
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/tf_interop/register_tensorflow.h"
+#include "bindings/python/pyiree/compiler/tf_interop/register_tensorflow.h"
#include <string>
#include <vector>
-#include "bindings/python/pyiree/compiler.h"
-#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/compiler/compiler.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
diff --git a/bindings/python/pyiree/tf_interop/register_tensorflow.h b/bindings/python/pyiree/compiler/tf_interop/register_tensorflow.h
similarity index 94%
rename from bindings/python/pyiree/tf_interop/register_tensorflow.h
rename to bindings/python/pyiree/compiler/tf_interop/register_tensorflow.h
index 7a54404..6df0a8c 100644
--- a/bindings/python/pyiree/tf_interop/register_tensorflow.h
+++ b/bindings/python/pyiree/compiler/tf_interop/register_tensorflow.h
@@ -17,7 +17,7 @@
#include <string>
-#include "bindings/python/pyiree/binding.h"
+#include "bindings/python/pyiree/common/binding.h"
namespace iree {
namespace python {
diff --git a/bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc b/bindings/python/pyiree/compiler/tf_interop/register_tensorflow_noop.cc
similarity index 90%
rename from bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc
rename to bindings/python/pyiree/compiler/tf_interop/register_tensorflow_noop.cc
index e5f83aa..9422395 100644
--- a/bindings/python/pyiree/tf_interop/register_tensorflow_noop.cc
+++ b/bindings/python/pyiree/compiler/tf_interop/register_tensorflow_noop.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/tf_interop/register_tensorflow.h"
+#include "bindings/python/pyiree/compiler/tf_interop/register_tensorflow.h"
namespace iree {
namespace python {
diff --git a/bindings/python/pyiree/tf_interop/saved_model_test.py b/bindings/python/pyiree/compiler/tf_interop/saved_model_test.py
similarity index 94%
rename from bindings/python/pyiree/tf_interop/saved_model_test.py
rename to bindings/python/pyiree/compiler/tf_interop/saved_model_test.py
index 0f48c74..af317e5 100644
--- a/bindings/python/pyiree/tf_interop/saved_model_test.py
+++ b/bindings/python/pyiree/compiler/tf_interop/saved_model_test.py
@@ -21,10 +21,10 @@
import sys
import tempfile
-import pyiree
+from pyiree import compiler
# Determine if compiled with tf_interop support.
-if not hasattr(pyiree, "tf_interop"):
+if not hasattr(compiler, "tf_interop"):
print("Not running tests because tf_interop support not compiled")
sys.exit(0)
@@ -71,7 +71,7 @@
tf.saved_model.save(my_module, sm_dir, options=options)
# Load it up.
- input_module = pyiree.compiler.tf_load_saved_model(sm_dir)
+ input_module = compiler.tf_load_saved_model(sm_dir)
xla_asm = input_module.to_asm()
print("XLA ASM:", xla_asm)
self.assertRegex(xla_asm, "xla_hlo.tanh")
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
new file mode 100644
index 0000000..56d2c1d
--- /dev/null
+++ b/bindings/python/pyiree/rt/BUILD
@@ -0,0 +1,291 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//bindings/python/pyiree:build_defs.bzl",
+ "NUMPY_DEPS",
+ "PLATFORM_VULKAN_DEPS",
+ "PYBIND_COPTS",
+ "PYBIND_EXTENSION_COPTS",
+ "PYBIND_FEATURES",
+ "iree_py_extension",
+ "pybind_cc_library",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+DRIVER_DEPS = PLATFORM_VULKAN_DEPS + [
+ "//iree/hal/interpreter:interpreter_driver_module",
+ "//iree/hal/vulkan:vulkan_driver_module",
+]
+
+py_library(
+ name = "rt",
+ srcs = [
+ "__init__.py",
+ "system_api.py",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ ],
+)
+
+iree_py_extension(
+ name = "binding",
+ srcs = [
+ "initialize_module.cc",
+ ],
+ copts = PYBIND_COPTS + PYBIND_EXTENSION_COPTS,
+ features = PYBIND_FEATURES,
+ linkstatic = 1,
+ # TODO(b/145815906) Get this running in OSS CI.
+ tags = ["nokokoro"],
+ win_def_file = "export.def",
+ deps = DRIVER_DEPS + [
+ ":rt_library",
+ "//bindings/python/pyiree/common",
+ "//iree/base:initializer",
+ "//iree/base:tracing",
+ "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+ ],
+)
+
+pybind_cc_library(
+ name = "rt_library",
+ srcs = [
+ "function_abi.cc",
+ "hal.cc",
+ "host_types.cc",
+ "vm.cc",
+ ],
+ hdrs = [
+ "function_abi.h",
+ "hal.h",
+ "host_types.h",
+ "vm.h",
+ ],
+ deps = [
+ "//bindings/python/pyiree/common",
+ "//iree/base:api",
+ "//iree/base:signature_mangle",
+ "//iree/hal:api",
+ "//iree/modules/hal",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
+ "//iree/vm:invocation",
+ "//iree/vm:module",
+ "//iree/vm:ref",
+ "//iree/vm:variant_list",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+py_library(
+ name = "system_api",
+ srcs = ["system_api.py"],
+ srcs_version = "PY3",
+ # TODO(b/145815906) Get this running in OSS CI.
+ tags = ["nokokoro"],
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ ],
+)
+
+py_test(
+ name = "function_abi_test",
+ srcs = ["function_abi_test.py"],
+ python_version = "PY3",
+ # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+ tags = [
+ "noga",
+ "nokokoro",
+ "notap",
+ ],
+ deps = NUMPY_DEPS + [
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "@absl_py//absl/testing:absltest",
+ "//bindings/python/pyiree/rt",
+ ],
+ # TODO(b/145815906) Get this running in OSS CI.
+)
+
+py_test(
+ name = "hal_test",
+ srcs = ["hal_test.py"],
+ python_version = "PY3",
+ # TODO(b/145815906) Get this running in OSS CI.
+ tags = ["nokokoro"],
+ deps = NUMPY_DEPS + [
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "@absl_py//absl/testing:absltest",
+ "//bindings/python/pyiree/rt",
+ ],
+)
+
+py_test(
+ name = "system_api_test",
+ srcs = ["system_api_test.py"],
+ python_version = "PY3",
+ tags = [
+ # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+ "notap",
+ # TODO(b/145815906) Get this running in OSS CI.
+ "noga",
+ "nokokoro",
+ ],
+ deps = NUMPY_DEPS + [
+ ":system_api",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "@absl_py//absl/testing:absltest",
+ "//bindings/python/pyiree/compiler",
+ "//bindings/python/pyiree/rt",
+ ],
+)
+
+py_test(
+ name = "vm_test",
+ srcs = ["vm_test.py"],
+ python_version = "PY3",
+ # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+ tags = [
+ "noga",
+ "nokokoro",
+ "notap",
+ ],
+ deps = NUMPY_DEPS + [
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "@absl_py//absl/testing:absltest",
+ "//bindings/python/pyiree/compiler",
+ "//bindings/python/pyiree/rt",
+ ],
+)
+
+# py_library(
+# name = "pyiree",
+# srcs = [
+# "__init__.py",
+# ],
+# srcs_version = "PY3",
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# deps = [
+# ":binding",
+# ":compiler",
+# ":system_api",
+# "//bindings/python:pathsetup", # build_cleaner: keep
+# ] + select({
+# "//iree:enable_tensorflow": [
+# "//bindings/python/pyiree/tf_interop:test_utils",
+# "//bindings/python/pyiree/tf_interop:tf_test_driver",
+# ],
+# "//conditions:default": [
+# ],
+# }),
+# )
+#
+#
+# cc_library(
+# name = "base",
+# srcs = [
+# "compiler.cc",
+# "function_abi.cc",
+# "hal.cc",
+# "host_types.cc",
+# "status_utils.cc",
+# "vm.cc",
+# ],
+# hdrs = [
+# "binding.h",
+# "compiler.h",
+# "function_abi.h",
+# "hal.h",
+# "host_types.h",
+# "status_utils.h",
+# "vm.h",
+# ],
+# copts = DEFAULT_COPTS,
+# features = DEFAULT_FEATURES,
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# deps = [
+# "@com_google_absl//absl/container:inlined_vector",
+# "@com_google_absl//absl/memory",
+# "@com_google_absl//absl/strings",
+# "@com_google_absl//absl/types:optional",
+# "@com_google_absl//absl/types:span",
+# "@com_google_absl//absl/types:variant",
+# "//iree/compiler/Utils",
+# "//iree/modules/hal",
+# "//iree/vm",
+# "//iree/vm:bytecode_module",
+# "//iree/vm:invocation",
+# "//iree/vm:module",
+# "//iree/vm:ref",
+# "//iree/vm:variant_list",
+# "//third_party/llvm/llvm/projects/google_mlir:IR",
+# "//iree/base:api",
+# "//iree/base:status",
+# "//iree/base:signature_mangle",
+# "//iree/hal:api",
+# "//iree/vm:api",
+# "@llvm-project//llvm:support",
+# "//third_party/llvm/llvm/projects/google_mlir:Parser",
+# "//third_party/llvm/llvm/projects/google_mlir:Pass",
+# "@iree_pybind11//:pybind11",
+# ] + COMPILER_DEPS + DRIVER_DEPS + PYTHON_HEADERS_DEPS,
+# )
+#
+# iree_py_extension(
+# name = "binding",
+# srcs = [
+# "initialize_module.cc",
+# ],
+# copts = DEFAULT_COPTS,
+# features = DEFAULT_FEATURES,
+# linkstatic = 1,
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# win_def_file = "export.def",
+# deps = [
+# ":base",
+# "//bindings/python/pyiree/tf_interop",
+# "//iree/base:initializer",
+# "//iree/base:tracing",
+# "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp",
+# ],
+# )
+#
+# py_test(
+# name = "compiler_test",
+# srcs = ["compiler_test.py"],
+# python_version = "PY3",
+# # TODO(b/145815906) Get this running in OSS CI.
+# tags = ["nokokoro"],
+# deps = [
+# "//bindings/python:pathsetup", # build_cleaner: keep
+# "@absl_py//absl/testing:absltest",
+# "//bindings/python/pyiree",
+# ],
+# )
+#
diff --git a/bindings/python/pyiree/rt/__init__.py b/bindings/python/pyiree/rt/__init__.py
new file mode 100644
index 0000000..b3c5533
--- /dev/null
+++ b/bindings/python/pyiree/rt/__init__.py
@@ -0,0 +1,33 @@
+"""Module init for the python bindings."""
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-multiple-import
+# pylint: disable=g-bad-import-order
+# pylint: disable=wildcard-import
+
+from . import binding
+
+# Pull some of the native symbols into the public API.
+# FunctionAbi imports
+from .binding import FunctionAbi
+# Hal imports
+from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, MemoryAccess, MemoryType, Shape
+# HostTypeFactory imports
+from .binding import HostTypeFactory
+# Vm imports
+from .binding import create_hal_module, Linkage, VmVariantList, VmFunction, VmInstance, VmContext, VmModule
+# SystemApi
+from .system_api import *
diff --git a/bindings/python/pyiree/export.def b/bindings/python/pyiree/rt/export.def
similarity index 100%
rename from bindings/python/pyiree/export.def
rename to bindings/python/pyiree/rt/export.def
diff --git a/bindings/python/pyiree/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
similarity index 98%
rename from bindings/python/pyiree/function_abi.cc
rename to bindings/python/pyiree/rt/function_abi.cc
index 1fad6e6..1a3efd0 100644
--- a/bindings/python/pyiree/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/function_abi.h"
+#include "bindings/python/pyiree/rt/function_abi.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
-#include "bindings/python/pyiree/hal.h"
-#include "bindings/python/pyiree/status_utils.h"
-#include "bindings/python/pyiree/vm.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "bindings/python/pyiree/rt/vm.h"
#include "iree/base/api.h"
#include "iree/base/signature_mangle.h"
#include "iree/hal/api.h"
@@ -398,8 +398,8 @@
}
void SetupFunctionAbiBindings(pybind11::module m) {
- m.def("create", &PyCreateAbi);
py::class_<FunctionAbi, std::unique_ptr<FunctionAbi>>(m, "FunctionAbi")
+ .def(py::init(&PyCreateAbi))
.def("__repr__", &FunctionAbi::DebugString)
.def_property_readonly("raw_input_arity", &FunctionAbi::raw_input_arity)
.def_property_readonly("raw_result_arity", &FunctionAbi::raw_result_arity)
diff --git a/bindings/python/pyiree/function_abi.h b/bindings/python/pyiree/rt/function_abi.h
similarity index 92%
rename from bindings/python/pyiree/function_abi.h
rename to bindings/python/pyiree/rt/function_abi.h
index 7e20183..62b3d0a 100644
--- a/bindings/python/pyiree/function_abi.h
+++ b/bindings/python/pyiree/rt/function_abi.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_FUNCTION_ABI_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_FUNCTION_ABI_H_
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_FUNCTION_ABI_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_FUNCTION_ABI_H_
#include <utility>
#include <vector>
@@ -22,11 +22,10 @@
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
-#include "absl/types/variant.h"
-#include "bindings/python/pyiree/binding.h"
-#include "bindings/python/pyiree/hal.h"
-#include "bindings/python/pyiree/host_types.h"
-#include "bindings/python/pyiree/vm.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "bindings/python/pyiree/rt/host_types.h"
+#include "bindings/python/pyiree/rt/vm.h"
#include "iree/base/signature_mangle.h"
namespace iree {
@@ -115,4 +114,4 @@
} // namespace python
} // namespace iree
-#endif // IREE_BINDINGS_PYTHON_PYIREE_FUNCTION_ABI_H_
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_FUNCTION_ABI_H_
diff --git a/bindings/python/pyiree/function_abi_test.py b/bindings/python/pyiree/rt/function_abi_test.py
similarity index 76%
rename from bindings/python/pyiree/function_abi_test.py
rename to bindings/python/pyiree/rt/function_abi_test.py
index 9221a62..204c281 100644
--- a/bindings/python/pyiree/function_abi_test.py
+++ b/bindings/python/pyiree/rt/function_abi_test.py
@@ -19,7 +19,7 @@
from absl.testing import absltest
import numpy as np
-import pyiree
+from pyiree import rt
ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1 = (
("fv", "1"),
@@ -39,7 +39,7 @@
class HostTypeFactory(absltest.TestCase):
def test_baseclass(self):
- htf = pyiree.binding.host_types.HostTypeFactory()
+ htf = rt.HostTypeFactory()
print(htf)
@@ -48,19 +48,18 @@
@classmethod
def setUpClass(cls):
super().setUpClass()
- driver_names = pyiree.binding.hal.HalDriver.query()
+ driver_names = rt.HalDriver.query()
print("DRIVER_NAMES =", driver_names)
- cls.driver = pyiree.binding.hal.HalDriver.create("vulkan")
+ cls.driver = rt.HalDriver.create("vulkan")
cls.device = cls.driver.create_default_device()
def setUp(self):
super().setUp()
- self.htf = pyiree.binding.host_types.HostTypeFactory.get_numpy()
+ self.htf = rt.HostTypeFactory.get_numpy()
def test_static_arg_success(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
print(fabi)
self.assertEqual(
"<FunctionAbi (Buffer<float32[10x128x64]>) -> "
@@ -74,9 +73,8 @@
self.assertEqual("<VmVariantList(1): [HalBuffer(327680)]>", repr(packed))
def test_static_result_success(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
arg = np.zeros((10, 128, 64), dtype=np.float32)
f_args = fabi.raw_pack_inputs([arg])
f_results = fabi.allocate_results(f_args)
@@ -87,9 +85,8 @@
self.assertEqual((32, 8, 64), py_result.shape)
def test_dynamic_alloc_result_success(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
arg = np.zeros((10, 128, 64), dtype=np.float32)
f_args = fabi.raw_pack_inputs([arg])
f_results = fabi.allocate_results(f_args, static_alloc=False)
@@ -97,9 +94,8 @@
self.assertEqual("<VmVariantList(0): []>", repr(f_results))
def test_dynamic_arg_success(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
print(fabi)
self.assertEqual(
"<FunctionAbi (Buffer<float32[?x128x64]>) -> "
@@ -118,9 +114,8 @@
# repr(packed))
def test_static_arg_rank_mismatch(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
print(fabi)
arg = np.zeros((10,), dtype=np.float32)
with self.assertRaisesRegex(
@@ -129,9 +124,8 @@
fabi.raw_pack_inputs([arg])
def test_static_arg_eltsize_mismatch(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
print(fabi)
arg = np.zeros((10, 128, 64), dtype=np.float64)
with self.assertRaisesRegex(
@@ -140,9 +134,8 @@
fabi.raw_pack_inputs([arg])
def test_static_arg_dtype_mismatch(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
print(fabi)
arg = np.zeros((10, 128, 64), dtype=np.int32)
with self.assertRaisesRegex(
@@ -151,9 +144,8 @@
fabi.raw_pack_inputs([arg])
def test_static_arg_static_dim_mismatch(self):
- fabi = pyiree.binding.function_abi.create(
- self.device, self.htf,
- ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
+ fabi = rt.FunctionAbi(self.device, self.htf,
+ ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
print(fabi)
arg = np.zeros((10, 32, 64), dtype=np.float32)
with self.assertRaisesRegex(
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/rt/hal.cc
similarity index 98%
rename from bindings/python/pyiree/hal.cc
rename to bindings/python/pyiree/rt/hal.cc
index a78f46b..49d5f58 100644
--- a/bindings/python/pyiree/hal.cc
+++ b/bindings/python/pyiree/rt/hal.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/hal.h"
+#include "bindings/python/pyiree/rt/hal.h"
#include "absl/container/inlined_vector.h"
#include "iree/hal/api.h"
@@ -163,7 +163,7 @@
.def("map", HalMappedMemory::Create);
py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol())
.def_buffer(&HalMappedMemory::ToBufferInfo);
- py::class_<HalBuffer>(m, "Buffer")
+ py::class_<HalBuffer>(m, "HalBuffer")
.def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer,
py::arg("memory_type"), py::arg("usage"),
py::arg("allocation_size"))
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/rt/hal.h
similarity index 94%
rename from bindings/python/pyiree/hal.h
rename to bindings/python/pyiree/rt/hal.h
index 347f061..918193a 100644
--- a/bindings/python/pyiree/hal.h
+++ b/bindings/python/pyiree/rt/hal.h
@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_HAL_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_HAL_H_
-#include "bindings/python/pyiree/binding.h"
-#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/common/status_utils.h"
#include "iree/hal/api.h"
namespace iree {
@@ -133,4 +133,4 @@
} // namespace python
} // namespace iree
-#endif // IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_HAL_H_
diff --git a/bindings/python/pyiree/hal_test.py b/bindings/python/pyiree/rt/hal_test.py
similarity index 68%
rename from bindings/python/pyiree/hal_test.py
rename to bindings/python/pyiree/rt/hal_test.py
index 4188e0a..b7ab59b 100644
--- a/bindings/python/pyiree/hal_test.py
+++ b/bindings/python/pyiree/rt/hal_test.py
@@ -19,33 +19,33 @@
from absl.testing import absltest
import numpy as np
-import pyiree
+from pyiree import rt
class HalTest(absltest.TestCase):
def testEnums(self):
- print("MemoryType =", pyiree.binding.hal.MemoryType)
- print("HOST_VISIBLE =", int(pyiree.binding.hal.MemoryType.HOST_VISIBLE))
+ print("MemoryType =", rt.MemoryType)
+ print("HOST_VISIBLE =", int(rt.MemoryType.HOST_VISIBLE))
def testAllocateHeap(self):
- b = pyiree.binding.hal.Buffer.allocate_heap(
- memory_type=int(pyiree.binding.hal.MemoryType.HOST_LOCAL),
- usage=int(pyiree.binding.hal.BufferUsage.ALL),
+ b = rt.HalBuffer.allocate_heap(
+ memory_type=int(rt.MemoryType.HOST_LOCAL),
+ usage=int(rt.BufferUsage.ALL),
allocation_size=4096)
self.assertIsNot(b, None)
b.fill_zero(0, 4096)
- shape = pyiree.binding.hal.Shape([1, 1024])
+ shape = rt.Shape([1, 1024])
unused_bv = b.create_view(shape, 4)
def testStrideCalculation(self):
- b = pyiree.binding.hal.Buffer.allocate_heap(
- memory_type=int(pyiree.binding.hal.MemoryType.HOST_LOCAL),
- usage=int(pyiree.binding.hal.BufferUsage.ALL),
+ b = rt.HalBuffer.allocate_heap(
+ memory_type=int(rt.MemoryType.HOST_LOCAL),
+ usage=int(rt.BufferUsage.ALL),
allocation_size=4096)
self.assertIsNot(b, None)
b.fill_zero(0, 4096)
- shape = pyiree.binding.hal.Shape([16, 1, 8, 4, 2])
+ shape = rt.Shape([16, 1, 8, 4, 2])
bv = b.create_view(shape, 4)
self.assertEqual(
np.array(bv.map()).strides,
diff --git a/bindings/python/pyiree/host_types.cc b/bindings/python/pyiree/rt/host_types.cc
similarity index 97%
rename from bindings/python/pyiree/host_types.cc
rename to bindings/python/pyiree/rt/host_types.cc
index 4666381..8be929c 100644
--- a/bindings/python/pyiree/host_types.cc
+++ b/bindings/python/pyiree/rt/host_types.cc
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/host_types.h"
+#include "bindings/python/pyiree/rt/host_types.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
-#include "bindings/python/pyiree/hal.h"
-#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/hal.h"
#include "iree/base/signature_mangle.h"
#include "pybind11/numpy.h"
diff --git a/bindings/python/pyiree/host_types.h b/bindings/python/pyiree/rt/host_types.h
similarity index 86%
rename from bindings/python/pyiree/host_types.h
rename to bindings/python/pyiree/rt/host_types.h
index 7dd2964..00a9c42 100644
--- a/bindings/python/pyiree/host_types.h
+++ b/bindings/python/pyiree/rt/host_types.h
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_HOST_TYPES_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_HOST_TYPES_H_
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_HOST_TYPES_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_HOST_TYPES_H_
#include <array>
#include "absl/types/span.h"
-#include "bindings/python/pyiree/binding.h"
-#include "bindings/python/pyiree/hal.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/rt/hal.h"
#include "iree/base/signature_mangle.h"
namespace iree {
@@ -53,4 +53,4 @@
} // namespace python
} // namespace iree
-#endif // IREE_BINDINGS_PYTHON_PYIREE_HOST_TYPES_H_
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_HOST_TYPES_H_
diff --git a/bindings/python/pyiree/initialize_module.cc b/bindings/python/pyiree/rt/initialize_module.cc
similarity index 66%
rename from bindings/python/pyiree/initialize_module.cc
rename to bindings/python/pyiree/rt/initialize_module.cc
index 40dd884..a5dea72 100644
--- a/bindings/python/pyiree/initialize_module.cc
+++ b/bindings/python/pyiree/rt/initialize_module.cc
@@ -14,14 +14,12 @@
#include <mutex> // NOLINT
-#include "bindings/python/pyiree/binding.h"
-#include "bindings/python/pyiree/compiler.h"
-#include "bindings/python/pyiree/function_abi.h"
-#include "bindings/python/pyiree/hal.h"
-#include "bindings/python/pyiree/host_types.h"
-#include "bindings/python/pyiree/status_utils.h"
-#include "bindings/python/pyiree/tf_interop/register_tensorflow.h"
-#include "bindings/python/pyiree/vm.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/function_abi.h"
+#include "bindings/python/pyiree/rt/hal.h"
+#include "bindings/python/pyiree/rt/host_types.h"
+#include "bindings/python/pyiree/rt/vm.h"
#include "iree/base/initializer.h"
#include "iree/base/tracing.h"
#include "wtf/event.h"
@@ -107,40 +105,13 @@
IREE_RUN_MODULE_INITIALIZERS();
m.doc() = "IREE Binding Backend Helpers";
- py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob")
- .def_property_readonly("bytes",
- [](OpaqueBlob* self) -> py::bytes {
- return py::bytes(
- static_cast<const char*>(self->data()),
- self->size());
- })
- .def_property_readonly("text", [](OpaqueBlob* self) -> py::str {
- return py::str(static_cast<const char*>(self->data()), self->size());
- });
- auto compiler_m = m.def_submodule("compiler", "IREE compiler support");
- SetupCompilerBindings(compiler_m);
-
- auto function_abi = m.def_submodule("function_abi", "Function ABI support");
- SetupFunctionAbiBindings(function_abi);
-
- auto host_types =
- m.def_submodule("host_types", "Utilities for manipulating host types");
- SetupHostTypesBindings(host_types);
-
- auto hal_m = m.def_submodule("hal", "IREE HAL support");
- SetupHalBindings(hal_m);
-
- auto vm_m = m.def_submodule("vm", "IREE VM api");
- SetupVmBindings(vm_m);
+ SetupFunctionAbiBindings(m);
+ SetupHostTypesBindings(m);
+ SetupHalBindings(m);
+ SetupVmBindings(m);
auto tracing_m = m.def_submodule("tracing", "IREE tracing api");
SetupTracingBindings(tracing_m);
-
-// TensorFlow.
-#if defined(IREE_TENSORFLOW_ENABLED)
- auto tf_m = m.def_submodule("tf_interop", "IREE TensorFlow interop");
- SetupTensorFlowBindings(tf_m);
-#endif
}
} // namespace python
diff --git a/bindings/python/pyiree/system_api.py b/bindings/python/pyiree/rt/system_api.py
similarity index 89%
rename from bindings/python/pyiree/system_api.py
rename to bindings/python/pyiree/rt/system_api.py
index 7d3edd0..e4089a9 100644
--- a/bindings/python/pyiree/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -33,7 +33,7 @@
from . import binding as _binding
# Typing aliases (largely used for documentation).
-AnyModule = _binding.vm.VmModule
+AnyModule = _binding.VmModule
# Environment key for a comma-delimitted list of drivers to try to load.
PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"
@@ -43,7 +43,7 @@
def _create_default_iree_driver(
- driver_names: Optional[Sequence[str]] = None) -> _binding.hal.HalDriver:
+ driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
"""Returns a default driver based on environment settings."""
# TODO(laurenzo): Ideally this should take a module and join any explicitly
# provided driver list with environmental constraints and what the module
@@ -54,7 +54,7 @@
if driver_names is None:
driver_names = DEFAULT_IREE_DRIVER_VALUE
driver_names = driver_names.split(",")
- available_driver_names = _binding.hal.HalDriver.query()
+ available_driver_names = _binding.HalDriver.query()
driver_exceptions = {}
for driver_name in driver_names:
if driver_name not in available_driver_names:
@@ -63,7 +63,7 @@
file=sys.stderr)
continue
try:
- driver = _binding.hal.HalDriver.create(driver_name)
+ driver = _binding.HalDriver.create(driver_name)
# TODO(laurenzo): Remove these prints to stderr (for now, more information
# is better and there is no better way to report it yet).
except Exception as ex: # pylint: disable=broad-except
@@ -71,10 +71,8 @@
"Could not create default driver %s: %r" % (driver_name, ex),
file=sys.stderr)
driver_exceptions[driver_name] = ex
- else:
- print(
- "Created IREE driver %s: %r" % (driver_name, driver), file=sys.stderr)
- return driver
+ print("Created IREE driver %s: %r" % (driver_name, driver), file=sys.stderr)
+ return driver
# All failed.
raise RuntimeError("Could not create any requested driver "
@@ -85,19 +83,19 @@
class Config:
"""System configuration."""
- driver: _binding.hal.HalDriver
- device: _binding.hal.HalDevice
- vm_instance: _binding.vm.VmInstance
- host_type_factory: _binding.host_types.HostTypeFactory
+ driver: _binding.HalDriver
+ device: _binding.HalDevice
+ vm_instance: _binding.VmInstance
+ host_type_factory: _binding.HostTypeFactory
default_modules: Tuple[AnyModule]
def __init__(self, driver_name: Optional[str] = None):
- self.vm_instance = _binding.vm.VmInstance()
+ self.vm_instance = _binding.VmInstance()
self.driver = _create_default_iree_driver(
driver_name.split(",") if driver_name is not None else None)
self.device = self.driver.create_default_device()
- hal_module = _binding.vm.create_hal_module(self.device)
- self.host_type_factory = _binding.host_types.HostTypeFactory.get_numpy()
+ hal_module = _binding.create_hal_module(self.device)
+ self.host_type_factory = _binding.HostTypeFactory.get_numpy()
self.default_modules = (hal_module,)
@@ -115,7 +113,7 @@
"""Wraps a VmFunction, VmContext and ABI into a pythonic function."""
def __init__(self, context: "SystemContext",
- vm_function: _binding.vm.VmFunction):
+ vm_function: _binding.VmFunction):
self._context = context
self._vm_function = vm_function
self._abi = context.create_function_abi(vm_function)
@@ -204,7 +202,7 @@
else:
init_modules = None
- self._vm_context = _binding.vm.VmContext(
+ self._vm_context = _binding.VmContext(
instance=self._config.vm_instance, modules=init_modules)
if self._is_dynamic:
@@ -226,15 +224,14 @@
return self._config
@property
- def instance(self) -> _binding.vm.VmInstance:
+ def instance(self) -> _binding.VmInstance:
return self._instance
@property
def modules(self) -> Modules:
return self._modules
- def create_function_abi(
- self, f: _binding.vm.VmFunction) -> _binding.function_abi.FunctionAbi:
+ def create_function_abi(self, f: _binding.VmFunction) -> _binding.FunctionAbi:
return self._vm_context.create_function_abi(self._config.device,
self._config.host_type_factory,
f)
diff --git a/bindings/python/pyiree/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
similarity index 87%
rename from bindings/python/pyiree/system_api_test.py
rename to bindings/python/pyiree/rt/system_api_test.py
index ddaffc4..adb70bb 100644
--- a/bindings/python/pyiree/system_api_test.py
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -19,11 +19,12 @@
from absl.testing import absltest
import numpy as np
-import pyiree
+from pyiree import compiler
+from pyiree import rt
def create_simple_mul_module():
- ctx = pyiree.CompilerContext()
+ ctx = compiler.Context()
input_module = ctx.parse_asm("""
module @arithmetic {
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
@@ -34,7 +35,7 @@
}
""")
binary = input_module.compile()
- m = pyiree.binding.vm.VmModule.from_flatbuffer(binary)
+ m = rt.VmModule.from_flatbuffer(binary)
return m
@@ -43,25 +44,25 @@
def test_non_existing_driver(self):
with self.assertRaisesRegex(RuntimeError,
"Could not create any requested driver"):
- config = pyiree.Config("nothere1,nothere2")
+ config = rt.Config("nothere1,nothere2")
def test_subsequent_driver(self):
- config = pyiree.Config("nothere1,interpreter")
+ config = rt.Config("nothere1,interpreter")
def test_empty_dynamic(self):
- ctx = pyiree.SystemContext()
+ ctx = rt.SystemContext()
self.assertTrue(ctx.is_dynamic)
self.assertIn("hal", ctx.modules)
self.assertEqual(ctx.modules.hal.name, "hal")
def test_empty_static(self):
- ctx = pyiree.SystemContext(modules=())
+ ctx = rt.SystemContext(modules=())
self.assertFalse(ctx.is_dynamic)
self.assertIn("hal", ctx.modules)
self.assertEqual(ctx.modules.hal.name, "hal")
def test_custom_dynamic(self):
- ctx = pyiree.SystemContext()
+ ctx = rt.SystemContext()
self.assertTrue(ctx.is_dynamic)
ctx.add_module(create_simple_mul_module())
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
@@ -74,14 +75,14 @@
"(Buffer<float32[4]>, Buffer<float32[4]>) -> (Buffer<float32[4]>)"))
def test_duplicate_module(self):
- ctx = pyiree.SystemContext()
+ ctx = rt.SystemContext()
self.assertTrue(ctx.is_dynamic)
ctx.add_module(create_simple_mul_module())
with self.assertRaisesRegex(ValueError, "arithmetic"):
ctx.add_module(create_simple_mul_module())
def test_static_invoke(self):
- ctx = pyiree.SystemContext()
+ ctx = rt.SystemContext()
self.assertTrue(ctx.is_dynamic)
ctx.add_module(create_simple_mul_module())
self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
@@ -92,7 +93,7 @@
np.testing.assert_allclose(results, [4., 10., 18., 28.])
def test_load_module(self):
- arithmetic = pyiree.load_module(create_simple_mul_module())
+ arithmetic = rt.load_module(create_simple_mul_module())
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
results = arithmetic.simple_mul(arg0, arg1)
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/rt/vm.cc
similarity index 88%
rename from bindings/python/pyiree/vm.cc
rename to bindings/python/pyiree/rt/vm.cc
index a969159..ffdb76a 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/rt/vm.cc
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "bindings/python/pyiree/vm.h"
+#include "bindings/python/pyiree/rt/vm.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
-#include "bindings/python/pyiree/function_abi.h"
-#include "bindings/python/pyiree/status_utils.h"
+#include "bindings/python/pyiree/common/status_utils.h"
+#include "bindings/python/pyiree/rt/function_abi.h"
#include "iree/base/api.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/invocation.h"
@@ -99,7 +99,7 @@
attrs.push_back({});
auto status = iree_vm_get_function_reflection_attr(
f, i, &attrs.back().first, &attrs.back().second);
- if (iree_status_is_not_found(status)) {
+ if (status == IREE_STATUS_NOT_FOUND) {
attrs.pop_back();
break;
}
@@ -129,16 +129,27 @@
// VmModule
//------------------------------------------------------------------------------
-VmModule VmModule::FromFlatbufferBlob(
- std::shared_ptr<OpaqueBlob> flatbuffer_blob) {
+VmModule VmModule::FromFlatbufferBlob(py::buffer flatbuffer_blob) {
+ auto buffer_info = flatbuffer_blob.request();
iree_vm_module_t* module;
- auto deallocator = OpaqueBlob::CreateDeallocator(flatbuffer_blob);
+
+ // Bridge to the C-based deallocator API.
+ auto* raw_ptr = flatbuffer_blob.ptr();
+ auto free_fn = +([](void* self, void*) -> iree_status_t {
+ PyObject* self_ptr = static_cast<PyObject*>(self);
+ Py_XDECREF(self_ptr);
+ return IREE_STATUS_OK;
+ });
+ flatbuffer_blob.inc_ref();
+ iree_allocator_t deallocator{raw_ptr /* self */, nullptr /* alloc */,
+ free_fn /* dealloc */};
+
auto status = iree_vm_bytecode_module_create(
- {static_cast<const uint8_t*>(flatbuffer_blob->data()),
- flatbuffer_blob->size()},
+ {static_cast<const uint8_t*>(buffer_info.ptr),
+ static_cast<iree_host_size_t>(buffer_info.size)},
deallocator, IREE_ALLOCATOR_SYSTEM, &module);
- if (!iree_status_is_ok(status)) {
- deallocator.free(deallocator.self, nullptr);
+ if (status != IREE_STATUS_OK) {
+ deallocator.free(raw_ptr, nullptr);
}
CheckApiStatus(status, "Error creating vm module from flatbuffer");
@@ -150,7 +161,7 @@
iree_vm_function_t f;
auto status = iree_vm_module_lookup_function_by_name(
raw_ptr(), linkage, {name.data(), name.size()}, &f);
- if (iree_status_is_not_found(status)) {
+ if (status == IREE_STATUS_NOT_FOUND) {
return absl::nullopt;
}
CheckApiStatus(status, "Error looking up function");
@@ -194,8 +205,8 @@
}
void SetupVmBindings(pybind11::module m) {
- IREE_CHECK_OK(iree_vm_register_builtin_types());
- IREE_CHECK_OK(iree_hal_module_register_types());
+ CHECK_EQ(IREE_STATUS_OK, iree_vm_register_builtin_types());
+ CHECK_EQ(IREE_STATUS_OK, iree_hal_module_register_types());
// Built-in module creation.
m.def("create_hal_module", &CreateHalModule);
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/rt/vm.h
similarity index 93%
rename from bindings/python/pyiree/vm.h
rename to bindings/python/pyiree/rt/vm.h
index 0ecbcce..4688d82 100644
--- a/bindings/python/pyiree/vm.h
+++ b/bindings/python/pyiree/rt/vm.h
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_
-#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_
+#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_VM_H_
+#define IREE_BINDINGS_PYTHON_PYIREE_RT_VM_H_
#include "absl/types/optional.h"
-#include "bindings/python/pyiree/binding.h"
-#include "bindings/python/pyiree/host_types.h"
+#include "bindings/python/pyiree/common/binding.h"
+#include "bindings/python/pyiree/rt/host_types.h"
#include "iree/base/api.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -115,8 +115,7 @@
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
public:
- static VmModule FromFlatbufferBlob(
- std::shared_ptr<OpaqueBlob> flatbuffer_blob);
+ static VmModule FromFlatbufferBlob(py::buffer flatbuffer_blob);
absl::optional<iree_vm_function_t> LookupFunction(
const std::string& name, iree_vm_function_linkage_t linkage);
@@ -160,4 +159,4 @@
} // namespace python
} // namespace iree
-#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_
+#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_VM_H_
diff --git a/bindings/python/pyiree/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
similarity index 74%
rename from bindings/python/pyiree/vm_test.py
rename to bindings/python/pyiree/rt/vm_test.py
index f0f3a90..e947000 100644
--- a/bindings/python/pyiree/vm_test.py
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -17,11 +17,12 @@
from absl.testing import absltest
import numpy as np
-import pyiree
+from pyiree import compiler
+from pyiree import rt
def create_simple_mul_module():
- ctx = pyiree.CompilerContext()
+ ctx = compiler.Context()
input_module = ctx.parse_asm("""
func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
attributes { iree.module.export } {
@@ -30,7 +31,7 @@
}
""")
binary = input_module.compile()
- m = pyiree.binding.vm.VmModule.from_flatbuffer(binary)
+ m = rt.VmModule.from_flatbuffer(binary)
return m
@@ -39,22 +40,22 @@
@classmethod
def setUpClass(cls):
super().setUpClass()
- driver_names = pyiree.binding.hal.HalDriver.query()
+ driver_names = rt.HalDriver.query()
print("DRIVER_NAMES =", driver_names)
- cls.driver = pyiree.binding.hal.HalDriver.create("vulkan")
+ cls.driver = rt.HalDriver.create("vulkan")
cls.device = cls.driver.create_default_device()
- cls.hal_module = pyiree.binding.vm.create_hal_module(cls.device)
- cls.htf = pyiree.binding.host_types.HostTypeFactory.get_numpy()
+ cls.hal_module = rt.create_hal_module(cls.device)
+ cls.htf = rt.HostTypeFactory.get_numpy()
def test_variant_list(self):
- l = pyiree.binding.vm.VmVariantList(5)
+ l = rt.VmVariantList(5)
print(l)
self.assertEqual(l.size, 0)
def test_context_id(self):
- instance = pyiree.binding.vm.VmInstance()
- context1 = pyiree.binding.vm.VmContext(instance)
- context2 = pyiree.binding.vm.VmContext(instance)
+ instance = rt.VmInstance()
+ context1 = rt.VmContext(instance)
+ context2 = rt.VmContext(instance)
self.assertGreater(context2.context_id, context1.context_id)
def test_module_basics(self):
@@ -65,25 +66,23 @@
self.assertIs(notfound, None)
def test_dynamic_module_context(self):
- instance = pyiree.binding.vm.VmInstance()
- context = pyiree.binding.vm.VmContext(instance)
+ instance = rt.VmInstance()
+ context = rt.VmContext(instance)
m = create_simple_mul_module()
context.register_modules([self.hal_module, m])
def test_static_module_context(self):
m = create_simple_mul_module()
print(m)
- instance = pyiree.binding.vm.VmInstance()
+ instance = rt.VmInstance()
print(instance)
- context = pyiree.binding.vm.VmContext(
- instance, modules=[self.hal_module, m])
+ context = rt.VmContext(instance, modules=[self.hal_module, m])
print(context)
def test_synchronous_invoke_function(self):
m = create_simple_mul_module()
- instance = pyiree.binding.vm.VmInstance()
- context = pyiree.binding.vm.VmContext(
- instance, modules=[self.hal_module, m])
+ instance = rt.VmInstance()
+ context = rt.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
abi = context.create_function_abi(self.device, self.htf, f)
print("INVOKING:", abi)
diff --git a/bindings/python/pyiree/tf/support/BUILD b/bindings/python/pyiree/tf/support/BUILD
new file mode 100644
index 0000000..beec482
--- /dev/null
+++ b/bindings/python/pyiree/tf/support/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//iree:build_defs.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+py_library(
+ name = "support",
+ srcs = [
+ "tf_test_driver.py",
+ "tf_test_utils.py",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//bindings/python/pyiree/compiler",
+ "//bindings/python/pyiree/rt",
+ ],
+)
diff --git a/bindings/python/pyiree/tf_interop/tf_test_driver.py b/bindings/python/pyiree/tf/support/tf_test_driver.py
similarity index 95%
rename from bindings/python/pyiree/tf_interop/tf_test_driver.py
rename to bindings/python/pyiree/tf/support/tf_test_driver.py
index e7656ea..9c3c336 100644
--- a/bindings/python/pyiree/tf_interop/tf_test_driver.py
+++ b/bindings/python/pyiree/tf/support/tf_test_driver.py
@@ -1,3 +1,4 @@
+# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-# Lint as: python3
"""Utilities for running tests from TensorFlow models."""
import contextlib
@@ -25,10 +24,9 @@
from absl import app
from absl import flags
+from pyiree import compiler
import tensorflow.compat.v2 as tf
-from .. import binding
-
flags.DEFINE_string("filecheck_binary", "filecheck",
"Location of the filecheck binary.")
flags.DEFINE_bool("disable_filecheck", False,
@@ -47,11 +45,11 @@
"""Runs an individual test dict."""
tf_module_builder_lambda = test_dict["tf_module_builder"]
tf_module = tf_module_builder_lambda()
- ctx = binding.compiler.CompilerContext()
+ ctx = compiler.Context()
with tempfile.TemporaryDirectory() as sm_path:
options = tf.saved_model.SaveOptions(save_debug_info=True)
tf.saved_model.save(tf_module, sm_path, options=options)
- input_module = binding.tf_interop.load_saved_model(ctx, sm_path)
+ input_module = compiler.binding.tf_interop.load_saved_model(ctx, sm_path)
passes = test_dict.get("passes")
expect_pass_failure = test_dict.get("expect_pass_failure")
@@ -90,6 +88,7 @@
def _find_filecheck():
+ """Finds the filecheck binary."""
filecheck_binary = FLAGS.filecheck_binary
if os.path.isabs(filecheck_binary):
return filecheck_binary
diff --git a/bindings/python/pyiree/tf_interop/test_utils.py b/bindings/python/pyiree/tf/support/tf_test_utils.py
similarity index 97%
rename from bindings/python/pyiree/tf_interop/test_utils.py
rename to bindings/python/pyiree/tf/support/tf_test_utils.py
index d5f28c5..c28f5ca 100644
--- a/bindings/python/pyiree/tf_interop/test_utils.py
+++ b/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -23,12 +23,11 @@
import re
import tempfile
-from .. import binding
-from .. import compiler
-from .. import system_api
from absl import flags
from absl import logging
import numpy as np
+from pyiree import compiler
+from pyiree import rt
import tensorflow.compat.v2 as tf
flags.DEFINE_string(
@@ -180,8 +179,7 @@
ctor(),
exported_names=exported_names,
target_backends=backend.iree_compiler_targets)
- self._iree_module = binding.vm.VmModule.from_flatbuffer(
- self._iree_module_blob)
+ self._iree_module = rt.VmModule.from_flatbuffer(self._iree_module_blob)
def instantiate(self):
return _IreeModuleInstance(self._backend, self._iree_module_blob,
@@ -197,8 +195,8 @@
self._iree_module = iree_module
self._iree_module_name = self._iree_module.name
- self._system_config = system_api.Config(driver_name=backend.iree_driver)
- self._context = system_api.SystemContext(
+ self._system_config = rt.Config(driver_name=backend.iree_driver)
+ self._context = rt.SystemContext(
modules=[self._iree_module], config=self._system_config)
def __getattr__(self, attr):
@@ -536,7 +534,7 @@
trace_file_name = cls.__name__ + ".wtf-trace"
trace_file = os.path.join(tempfile.gettempdir(), trace_file_name)
print("Flushing trace file to:", trace_file)
- binding.tracing.flush(trace_file)
+ rt.binding.tracing.flush(trace_file)
print("Flush complete")
super().tearDownClass()
diff --git a/build_tools/scripts/start_colab_kernel.py b/build_tools/scripts/start_colab_kernel.py
index 79103b2..bdc8bf9 100644
--- a/build_tools/scripts/start_colab_kernel.py
+++ b/build_tools/scripts/start_colab_kernel.py
@@ -81,10 +81,9 @@
def build():
"""Builds the python bundle."""
print("Building python bindings...")
- subprocess.check_call(
- [bazel_exe, "build", "//bindings/python/pyiree:everything_for_colab"],
- cwd=repo_root,
- env=bazel_env)
+ subprocess.check_call([bazel_exe, "build", "//colab:everything_for_colab"],
+ cwd=repo_root,
+ env=bazel_env)
def run():
@@ -93,7 +92,7 @@
if os.path.sep == "\\":
runfiles_suffix = ".exe.runfiles" # Windows uses a special name
- runfiles_dir = os.path.join(bazel_bin, "bindings", "python", "pyiree",
+ runfiles_dir = os.path.join(bazel_bin, "colab",
"everything_for_colab" + runfiles_suffix)
# Top level directories under the runfiles get added to the sys path.
extra_python_path = []
diff --git a/colab/BUILD.bazel b/colab/BUILD.bazel
new file mode 100644
index 0000000..14c88ee
--- /dev/null
+++ b/colab/BUILD.bazel
@@ -0,0 +1,28 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+py_binary(
+ name = "everything_for_colab",
+ srcs = ["dummy.py"],
+ main = "dummy.py",
+ python_version = "PY3",
+ # TODO(b/145815906) Get this running in OSS CI.
+ tags = ["nokokoro"],
+ deps = [
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "//bindings/python/pyiree/compiler", # build_cleaner: keep
+ "//bindings/python/pyiree/rt", # build_cleaner: keep
+ "//bindings/python/pyiree/tf/support", # build_cleaner: keep
+ ],
+)
diff --git a/bindings/python/pyiree/dummy.py b/colab/dummy.py
similarity index 100%
rename from bindings/python/pyiree/dummy.py
rename to colab/dummy.py
diff --git a/integrations/tensorflow/compiler/test/BUILD b/integrations/tensorflow/compiler/test/BUILD
index 09bf9fb..c985e3f 100644
--- a/integrations/tensorflow/compiler/test/BUILD
+++ b/integrations/tensorflow/compiler/test/BUILD
@@ -48,6 +48,6 @@
# TODO(b/145815906) Get this running in OSS CI.
tags = ["nokokoro"],
deps = INTREE_TENSORFLOW_PY_DEPS + [
- "//bindings/python/pyiree",
+ "//bindings/python/pyiree/tf/support",
],
)
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
index 8a44628..7112191 100644
--- a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -17,7 +17,7 @@
# pylint: disable=missing-docstring
# pylint: disable=line-too-long
-import pyiree
+from pyiree.tf.support import tf_test_driver
import tensorflow.compat.v2 as tf
SAVED_MODEL_IMPORT_PASSES = [
@@ -46,7 +46,7 @@
return a * b
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0001_FlatArgsResultsNoBoundGlobals",
tf_module_builder=T0001_FlatArgsResultsNoBoundGlobals,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -154,33 +154,33 @@
self.v.assign_add(tf.constant([0., 1.]))
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0002a_SimpleVarRead",
tf_module_builder=T0002a_SimpleVarRead,
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True)
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0002b_SimpleVarWrite",
tf_module_builder=T0002b_SimpleVarWrite,
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True)
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0002c_SimpleConst",
tf_module_builder=T0002c_SimpleConst,
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True)
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0002d_VarCompatibleShapeChange",
tf_module_builder=T0002d_VarCompatibleShapeChange,
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True)
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0002e_Error_VarMultipleExportedNames",
tf_module_builder=T0002e_Error_VarMultipleExportedNames,
passes=SAVED_MODEL_IMPORT_PASSES,
print_input_module=True,
expect_pass_failure=True)
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0002f_Error_UnsupportedResourceOp",
tf_module_builder=T0002f_Error_UnsupportedResourceOp,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -206,7 +206,7 @@
return d["x"] * d["y"]
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0003a_StructuredArgs",
tf_module_builder=T0003a_StructuredArgs,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -232,7 +232,7 @@
return {"x": product, "x_squared": product * product}
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0003b_StructuredMultipleDictResult",
tf_module_builder=T0003b_StructuredMultipleDictResult,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -258,7 +258,7 @@
return {"x": product}
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0003c_StructuredSingleDictResult",
tf_module_builder=T0003c_StructuredSingleDictResult,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -284,7 +284,7 @@
return product
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0003d_StructuredSingleResult",
tf_module_builder=T0003d_StructuredSingleResult,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -310,7 +310,7 @@
return product, a, b
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0003e_StructuredSequenceResult",
tf_module_builder=T0003e_StructuredSequenceResult,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -336,7 +336,7 @@
return product, {"a": a, "b": b}
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0003f_StructuredNestedResult",
tf_module_builder=T0003f_StructuredNestedResult,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -363,7 +363,7 @@
T0005_MultipleExportedFuncNames.another_copy = (
T0005_MultipleExportedFuncNames.simple_mul)
-pyiree.tf_test_driver.add_test(
+tf_test_driver.add_test(
test_name="T0005_MultipleExportedFuncNames",
tf_module_builder=T0005_MultipleExportedFuncNames,
passes=SAVED_MODEL_IMPORT_PASSES,
@@ -371,4 +371,4 @@
expect_pass_failure=True)
if __name__ == "__main__":
- pyiree.tf_test_driver.run_tests(__file__, with_filecheck=True)
+ tf_test_driver.run_tests(__file__, with_filecheck=True)
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index b7db32e..65e623f 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -31,7 +31,7 @@
# TODO(b/145815906) Get this running in OSS CI.
tags = ["nokokoro"],
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//bindings/python/pyiree",
+ "//bindings/python/pyiree/tf/support",
],
)
for name in [
@@ -54,6 +54,6 @@
"nokokoro",
],
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//bindings/python/pyiree",
+ "//bindings/python/pyiree/tf/support",
],
)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index 89ac5e0..6cbc24d 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -17,7 +17,7 @@
from __future__ import print_function
import numpy
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
diff --git a/integrations/tensorflow/e2e/exported_names_test.py b/integrations/tensorflow/e2e/exported_names_test.py
index 1a86ef5..2d7e447 100644
--- a/integrations/tensorflow/e2e/exported_names_test.py
+++ b/integrations/tensorflow/e2e/exported_names_test.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
diff --git a/integrations/tensorflow/e2e/keras_lstm_test.py b/integrations/tensorflow/e2e/keras_lstm_test.py
index d97ec67..3c98125 100644
--- a/integrations/tensorflow/e2e/keras_lstm_test.py
+++ b/integrations/tensorflow/e2e/keras_lstm_test.py
@@ -14,7 +14,7 @@
# limitations under the License.
import os
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
# TODO(silvasean): Get this test working on IREE.
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 9bbe33b..4ff5252 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -17,7 +17,7 @@
from __future__ import print_function
import os
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
# TODO(silvasean): Get this working on IREE.
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index e9888ca..0b8930d 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -15,7 +15,7 @@
"""Several baseline e2e simple arithmetic tests."""
import numpy as np
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 52b2307..cc8bfd5 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
diff --git a/integrations/tensorflow/e2e/vulkan_conv_test.py b/integrations/tensorflow/e2e/vulkan_conv_test.py
index e197972..cf7c064 100644
--- a/integrations/tensorflow/e2e/vulkan_conv_test.py
+++ b/integrations/tensorflow/e2e/vulkan_conv_test.py
@@ -14,7 +14,7 @@
# limitations under the License.
import numpy as np
-from pyiree import tf_test_utils
+from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
diff --git a/iree/build_defs.bzl b/iree/build_defs.bzl
index 8971752..3812812 100644
--- a/iree/build_defs.bzl
+++ b/iree/build_defs.bzl
@@ -15,6 +15,7 @@
"""Common Bazel definitions for IREE."""
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library")
load("@iree_native_python//:build_defs.bzl", "py_extension")
load("@iree_core//build_tools/third_party/glslang:build_defs.bzl", "glsl_vulkan")
load("@iree_core//iree:lit_test.bzl", _iree_glob_lit_tests = "iree_glob_lit_tests", _iree_setup_lit_package = "iree_setup_lit_package")
@@ -75,6 +76,10 @@
"//iree/testing:gtest_main",
]
+# Aliases to the Starlark cc rules.
+cc_library = _cc_library
+cc_binary = _cc_binary
+
def iree_py_library(**kwargs):
"""Compatibility py_library which has bazel compatible args."""