Add pythonic high level module loading and invocation API.
* For simple things, load_module() and load_modules() should be all that is needed.
* More complicated cases can bring their own SystemContext and drill through modules and functions.
* I plan to add some kind of thread local, shared SystemContext thingy for the more dynamic eager-like case.
* Look at the test_load_module() test in system_api_test.py for an example of how easy this can be.
PiperOrigin-RevId: 286943328
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index 7816b0e..a8b48d0 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -75,6 +75,7 @@
deps = [
":binding",
":compiler",
+ ":system_api",
"//bindings/python:pathsetup", # build_cleaner: keep
] + select({
"//iree:enable_tensorflow": [
@@ -96,6 +97,16 @@
],
)
+py_library(
+ name = "system_api",
+ srcs = ["system_api.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":binding",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ ],
+)
+
cc_library(
name = "base",
srcs = [
@@ -218,6 +229,20 @@
)
py_test(
+ name = "system_api_test",
+ srcs = ["system_api_test.py"],
+ python_version = "PY3",
+ # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+ tags = ["notap"],
+ deps = NUMPY_DEPS + [
+ ":system_api",
+ "//bindings/python:pathsetup", # build_cleaner: keep
+ "@absl_py//absl/testing:absltest",
+ "//bindings/python/pyiree",
+ ],
+)
+
+py_test(
name = "vm_test",
srcs = ["vm_test.py"],
python_version = "PY3",
diff --git a/bindings/python/pyiree/__init__.py b/bindings/python/pyiree/__init__.py
index 46739ee..3d0701d 100644
--- a/bindings/python/pyiree/__init__.py
+++ b/bindings/python/pyiree/__init__.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name
# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
+# pylint: disable=wildcard-import
# Top-level modules that are imported verbatim.
from . import binding
@@ -31,6 +32,10 @@
# Alias specific native functions.
from .binding.vm import create_module_from_blob
+# system_api explicitly exports the things that should be in the global
+# scope.
+from .system_api import *
+
### Load non-native py_library deps here ###
### Order matters because these typically have a back-reference on this
### module (being constructed). Issues should fail-fast and be easy to see.
diff --git a/bindings/python/pyiree/system_api.py b/bindings/python/pyiree/system_api.py
new file mode 100644
index 0000000..126f3a3
--- /dev/null
+++ b/bindings/python/pyiree/system_api.py
@@ -0,0 +1,215 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Top-level python system API.
+
+This facility layers on top of the underlying binding native facilities and
+exposes them in a way that allows general operation against contexts, modules
+and functions.
+"""
+
+# pylint: disable=protected-access
+# pylint: disable=unused-argument
+# pylint: disable=g-explicit-length-test
+
+__all__ = ["load_module", "load_modules", "Config", "SystemContext"]
+
+from typing import Tuple
+
+from . import binding as _binding
+
+# Typing aliases (largely used for documentation).
+AnyModule = _binding.vm.VmModule
+
+
+class Config:
+
+ vm_instance: _binding.vm.VmInstance
+ host_type_factory: _binding.host_types.HostTypeFactory
+ default_modules: Tuple[AnyModule]
+
+
+class _GlobalConfig(Config):
+ """Singleton of globally configured instances."""
+
+ _instance = None
+
+ def __new__(cls, *args, **kwargs):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._static_init()
+ return cls._instance
+
+ def _static_init(self):
+ self.vm_instance = _binding.vm.VmInstance()
+ self.driver_names = _binding.hal.HalDriver.query()
+ # TODO(laurenzo): More flexible selection of driver and device.
+ self.driver = _binding.hal.HalDriver.create("vulkan")
+ self.device = self.driver.create_default_device()
+ self.hal_module = _binding.vm.create_hal_module(self.device)
+ self.host_type_factory = _binding.host_types.HostTypeFactory.get_numpy()
+ self.default_modules = (self.hal_module,)
+
+
+class BoundFunction:
+ """Wraps a VmFunction, VmContext and ABI into a pythonic function."""
+
+ def __init__(self, context: "SystemContext",
+ vm_function: _binding.vm.VmFunction):
+ self._context = context
+ self._vm_function = vm_function
+ self._abi = context.create_function_abi(vm_function)
+
+ def __call__(self, *args):
+ # NOTE: This is just doing sync dispatch right now. In the future,
+ # this should default to async and potentially have some kind of policy
+ # flag that can allow it to be overriden.
+ inputs = self._abi.raw_pack_inputs(args)
+ results = self._abi.allocate_results(inputs, static_alloc=False)
+ self._context._vm_context.invoke(self._vm_function, inputs, results)
+ unpacked_results = self._abi.raw_unpack_results(results)
+ # TODO(laurenzo): When switching from 'raw' to structured pack/unpack,
+ # the ABI should take care of this one-arg special case.
+ if len(unpacked_results) == 1:
+ return unpacked_results[0]
+ elif len(unpacked_results) == 0:
+ return None
+ else:
+ return unpacked_results
+
+ def __repr__(self):
+ return "<BoundFunction %r (%r)>" % (
+ self._abi,
+ self._vm_function,
+ )
+
+
+class BoundModule:
+ """Wraps a VmModule with its context and provides nice python accessors.
+
+ Resolves item access (["foo"]) as function resolution.
+ """
+
+ def __init__(self, context: "SystemContext", vm_module: AnyModule):
+ self._context = context
+ self._vm_module = vm_module
+ self._lazy_functions = dict()
+
+ @property
+ def name(self):
+ return self._vm_module.name
+
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __getitem__(self, name):
+ vm_function = self._lazy_functions.get(name)
+ if vm_function is not None:
+ return vm_function
+
+ vm_function = self._vm_module.lookup_function(name)
+ if vm_function is None:
+ raise KeyError("Function '%s' not found in module '%s'" %
+ (name, self.name))
+ bound_function = BoundFunction(self._context, vm_function)
+ self._lazy_functions[name] = bound_function
+ return bound_function
+
+ def __repr__(self):
+ return "<BoundModule %r>" % (self._vm_module,)
+
+
+class Modules(dict):
+ """Provides nice python accessors for a dict of modules."""
+
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+
+class SystemContext:
+ """Global system."""
+
+ def __init__(self, modules=None, config: Config = None):
+ self._config = config if config is not None else _GlobalConfig()
+ self._is_dynamic = modules is None
+ if not self._is_dynamic:
+ init_modules = self._config.default_modules + tuple(modules)
+ else:
+ init_modules = None
+
+ self._vm_context = _binding.vm.VmContext(
+ instance=self._config.vm_instance, modules=init_modules)
+
+ if self._is_dynamic:
+ self._vm_context.register_modules(self._config.default_modules)
+ self._modules = Modules([
+ (m.name, BoundModule(self, m)) for m in self._config.default_modules
+ ])
+ else:
+ self._modules = Modules([
+ (m.name, BoundModule(self, m)) for m in init_modules
+ ])
+
+ @property
+ def is_dynamic(self) -> bool:
+ return self._is_dynamic
+
+ @property
+ def config(self) -> Config:
+ return self._config
+
+ @property
+ def instance(self) -> _binding.vm.VmInstance:
+ return self._instance
+
+ @property
+ def modules(self) -> Modules:
+ return self._modules
+
+ def create_function_abi(
+ self, f: _binding.vm.VmFunction) -> _binding.function_abi.FunctionAbi:
+ return self._vm_context.create_function_abi(self._config.device,
+ self._config.host_type_factory,
+ f)
+
+ def add_modules(self, modules):
+ assert self._is_dynamic, "Cannot 'add_module' on a static context"
+ for m in modules:
+ name = m.name
+ if name in self._modules:
+ raise ValueError("Attempt to register duplicate module: '%s'" % (name,))
+ self._modules[m.name] = BoundModule(self, m)
+ self._vm_context.register_modules(modules)
+
+ def add_module(self, module):
+ self.add_modules((module,))
+
+
+def load_modules(*modules):
+ """Loads modules into a new or shared context and returns them."""
+ context = SystemContext(modules=modules)
+ context_modules = context.modules
+ bound_modules = [context_modules[m.name] for m in modules]
+ return bound_modules
+
+
+def load_module(module):
+ """Loads a module into a new or shared context and returns them."""
+ return load_modules(module)[0]
diff --git a/bindings/python/pyiree/system_api_test.py b/bindings/python/pyiree/system_api_test.py
new file mode 100644
index 0000000..c5b7f16
--- /dev/null
+++ b/bindings/python/pyiree/system_api_test.py
@@ -0,0 +1,95 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=unused-variable
+
+import re
+
+from absl.testing import absltest
+import numpy as np
+import pyiree
+
+
+def create_simple_mul_module():
+ ctx = pyiree.CompilerContext()
+ input_module = ctx.parse_asm("""
+ module @arithmetic {
+ func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
+ attributes { iree.module.export } {
+ %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ }
+ """)
+ binary = input_module.compile()
+ m = pyiree.binding.vm.VmModule.from_flatbuffer(binary)
+ return m
+
+
+class SystemApiTest(absltest.TestCase):
+
+ def test_empty_dynamic(self):
+ ctx = pyiree.SystemContext()
+ self.assertTrue(ctx.is_dynamic)
+ self.assertIn("hal", ctx.modules)
+ self.assertEqual(ctx.modules.hal.name, "hal")
+
+ def test_empty_static(self):
+ ctx = pyiree.SystemContext(modules=())
+ self.assertFalse(ctx.is_dynamic)
+ self.assertIn("hal", ctx.modules)
+ self.assertEqual(ctx.modules.hal.name, "hal")
+
+ def test_custom_dynamic(self):
+ ctx = pyiree.SystemContext()
+ self.assertTrue(ctx.is_dynamic)
+ ctx.add_module(create_simple_mul_module())
+ self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+ f = ctx.modules.arithmetic["simple_mul"]
+ f_repr = repr(f)
+ print(f_repr)
+ self.assertRegex(
+ f_repr,
+ re.escape(
+ "(Buffer<float32[4]>, Buffer<float32[4]>) -> (Buffer<float32[4]>)"))
+
+ def test_duplicate_module(self):
+ ctx = pyiree.SystemContext()
+ self.assertTrue(ctx.is_dynamic)
+ ctx.add_module(create_simple_mul_module())
+ with self.assertRaisesRegex(ValueError, "arithmetic"):
+ ctx.add_module(create_simple_mul_module())
+
+ def test_static_invoke(self):
+ ctx = pyiree.SystemContext()
+ self.assertTrue(ctx.is_dynamic)
+ ctx.add_module(create_simple_mul_module())
+ self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+ f = ctx.modules.arithmetic["simple_mul"]
+ arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+ arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+ results = f(arg0, arg1)
+ np.testing.assert_allclose(results, [4., 10., 18., 28.])
+
+ def test_load_module(self):
+ arithmetic = pyiree.load_module(create_simple_mul_module())
+ arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+ arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+ results = arithmetic.simple_mul(arg0, arg1)
+ np.testing.assert_allclose(results, [4., 10., 18., 28.])
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
index 9c7eedc..d8de629 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/vm.cc
@@ -241,6 +241,7 @@
py::class_<VmModule>(m, "VmModule")
.def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
+ .def_property_readonly("name", &VmModule::name)
.def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT);
}
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/vm.h
index 80c0867..ea2fb30 100644
--- a/bindings/python/pyiree/vm.h
+++ b/bindings/python/pyiree/vm.h
@@ -122,6 +122,11 @@
absl::optional<iree_vm_function_t> LookupFunction(
const std::string& name, iree_vm_function_linkage_t linkage);
+
+ std::string name() const {
+ auto name_sv = iree_vm_module_name(raw_ptr());
+ return std::string(name_sv.data, name_sv.size);
+ }
};
class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {