| # Lint as: python3 |
| # Copyright 2019 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Top-level python system API. |
| |
| This facility layers on top of the underlying binding native facilities and |
| exposes them in a way that allows general operation against contexts, modules |
| and functions. |
| """ |
| |
| # pylint: disable=protected-access |
| # pylint: disable=unused-argument |
| # pylint: disable=g-explicit-length-test |
| |
| # TODO(#4131) python>=3.7: Use postponed type annotations. |
| |
| __all__ = ["load_module", "load_modules", "Config", "SystemContext"] |
| |
| import os |
| import sys |
| |
| from typing import Optional, Sequence, Tuple |
| |
| from . import binding as _binding |
| |
| # Typing aliases (largely used for documentation). |
| AnyModule = _binding.VmModule |
| |
| # Environment key for a comma-delimited list of drivers to try to load. |
| PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER" |
| |
| # Default driver list used when IREE_DEFAULT_DRIVER is not set. |
| DEFAULT_IREE_DRIVER_VALUE = "vulkan,vmla" |
| |
| |
def _create_default_iree_driver(
    driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
  """Returns the first requested HAL driver that can actually be used.

  Candidates are tried in order. A candidate is skipped when it is not in the
  list returned by `HalDriver.query()`, when `HalDriver.create()` raises, or
  when the created driver fails to create a default device. Progress and
  failures are reported on stderr.

  Args:
    driver_names: Driver names to try, in priority order. When None, the
      comma-separated value of the environment variable named by
      PREFERRED_DRIVER_ENV_KEY is used, falling back to
      DEFAULT_IREE_DRIVER_VALUE when that variable is unset.

  Returns:
    The first driver that was created and could create a default device.

  Raises:
    RuntimeError: If no candidate could be used. The message lists the
      requested names, the available names and the exceptions collected from
      creation attempts (candidates skipped as unregistered have no entry).
  """
  # TODO(laurenzo): Ideally this should take a module and join any explicitly
  # provided driver list with environmental constraints and what the module
  # was compiled for.
  if driver_names is None:
    # Read from environment, falling back to the built-in default list.
    env_value = os.environ.get(PREFERRED_DRIVER_ENV_KEY)
    if env_value is None:
      env_value = DEFAULT_IREE_DRIVER_VALUE
    driver_names = env_value.split(",")

  available_driver_names = _binding.HalDriver.query()
  driver_exceptions = {}
  for driver_name in driver_names:
    if driver_name not in available_driver_names:
      print(f"Could not create driver {driver_name} (not registered)",
            file=sys.stderr)
      continue

    # TODO(laurenzo): Remove these prints to stderr (for now, more information
    # is better and there is no better way to report it yet).
    try:
      driver = _binding.HalDriver.create(driver_name)
    except Exception as ex:  # pylint: disable=broad-except
      # NOTE: `!r` is the repr conversion; `:!r` would be handed to
      # __format__ as a format spec and raise inside this handler.
      print(f"Could not create default driver {driver_name}: {ex!r}",
            file=sys.stderr)
      driver_exceptions[driver_name] = ex
      continue

    # Sanity check creation of the default device and skip the driver if
    # this fails (this works around issues where the driver is present
    # but there are no devices). This default initialization scheme needs
    # to be improved. The device itself is not needed here.
    try:
      driver.create_default_device()
    except Exception as ex:  # pylint: disable=broad-except
      print(f"Could not create default driver device {driver_name}: {ex!r}",
            file=sys.stderr)
      driver_exceptions[driver_name] = ex
      continue

    print(f"Created IREE driver {driver_name}: {driver!r}", file=sys.stderr)
    return driver

  # All failed.
  raise RuntimeError(
      f"Could not create any requested driver {repr(driver_names)} (available="
      f"{repr(available_driver_names)}) : {repr(driver_exceptions)}")
| |
| |
class Config:
  """System configuration.

  Bundles the objects created once and then shared by contexts: a VM
  instance, a HAL driver with its default device, the host type factory and
  the tuple of modules that every context gets by default.

  Args:
    driver_name: Optional comma-separated list of driver names to try in
      order. When None, driver selection is left entirely to
      `_create_default_iree_driver` (called with no explicit names).
  """

  driver: _binding.HalDriver
  device: _binding.HalDevice
  vm_instance: _binding.VmInstance
  host_type_factory: _binding.HostTypeFactory
  # Homogeneous, variable-length tuple (populated with three modules below).
  default_modules: Tuple[AnyModule, ...]

  def __init__(self, driver_name: Optional[str] = None):
    self.vm_instance = _binding.VmInstance()
    self.driver = _create_default_iree_driver(
        driver_name.split(",") if driver_name is not None else None)
    self.device = self.driver.create_default_device()
    hal_module = _binding.create_hal_module(self.device)
    strings_module = _binding.create_strings_module()
    tensorlist_module = _binding.create_tensorlist_module()
    self.host_type_factory = _binding.HostTypeFactory.get_numpy()
    self.default_modules = (hal_module, strings_module, tensorlist_module)
| |
| |
# Lazily created Config shared by every context that is not given its own.
_global_config = None


def _get_global_config():
  """Returns the shared Config, constructing it on the first call."""
  global _global_config
  if _global_config is not None:
    return _global_config
  _global_config = Config()
  return _global_config
| |
| |
class BoundFunction:
  """Callable wrapper tying a VmFunction to its context and function ABI.

  Calling the object packs the Python arguments through the ABI, invokes the
  function on the context's VM context and returns the unpacked results. The
  serialized inputs and outputs of the most recent call are retained and can
  be read back with `get_serialized_values()`.
  """

  def __init__(self, context: "SystemContext",
               vm_function: _binding.VmFunction):
    # None until the first call stores the serialized values.
    self._serialized_inputs = self._serialized_outputs = None
    self._context = context
    self._vm_function = vm_function
    self._abi = context.create_function_abi(vm_function)

  def __call__(self, *args, **kwargs):
    # NOTE: This is just doing sync dispatch right now. In the future,
    # this should default to async and potentially have some kind of policy
    # flag that can allow it to be overridden.
    abi = self._abi
    packed_inputs = abi.pack_inputs(*args, **kwargs)
    # Inputs are recorded before invoking; outputs only after it returns.
    self._serialized_inputs = tuple(abi.serialize_vm_list(packed_inputs))
    result_list = abi.allocate_results(packed_inputs, static_alloc=False)
    vm_context = self._context._vm_context
    vm_context.invoke(self._vm_function, packed_inputs, result_list)
    self._serialized_outputs = tuple(abi.serialize_vm_list(result_list))
    return abi.unpack_results(result_list)

  def __repr__(self):
    return "<BoundFunction {!r} ({!r})>".format(self._abi, self._vm_function)

  def get_serialized_values(self):
    """Returns (serialized_inputs, serialized_outputs) of the latest call.

    Raises:
      RuntimeError: If no inputs have been recorded yet.
    """
    if self._serialized_inputs is not None:
      return self._serialized_inputs, self._serialized_outputs
    raise RuntimeError("Attempted to call get_serialized_values() before "
                       "any values were passed.")
| |
| |
class BoundModule:
  """Wraps a VmModule with its context and provides nice python accessors.

  Resolves item access (["foo"]) as function resolution; attribute access
  (.foo) is routed through the same lookup. Each resolved function is wrapped
  in a BoundFunction once and cached by name.
  """

  def __init__(self, context: "SystemContext", vm_module: AnyModule):
    self._context = context
    self._vm_module = vm_module
    # name -> BoundFunction, filled on first lookup of each function.
    self._lazy_functions = {}

  @property
  def name(self):
    """Name reported by the wrapped VmModule."""
    return self._vm_module.name

  def __getattr__(self, name):
    # Only reached when normal attribute lookup fails, so real attributes
    # and properties take precedence over module functions.
    try:
      return self[name]
    except KeyError:
      raise AttributeError(name)

  def __getitem__(self, name):
    """Returns the BoundFunction for `name`, creating and caching it.

    Raises:
      KeyError: If the module's `lookup_function` returns None for `name`.
    """
    bound_function = self._lazy_functions.get(name)
    if bound_function is not None:
      return bound_function

    vm_function = self._vm_module.lookup_function(name)
    if vm_function is None:
      raise KeyError(f"Function '{name}' not found in module '{self.name}'")
    bound_function = BoundFunction(self._context, vm_function)
    self._lazy_functions[name] = bound_function
    return bound_function

  def __repr__(self):
    # `!r` is the repr conversion; `:!r` would be treated as a format spec
    # and raise instead of producing a string.
    return f"<BoundModule {self._vm_module!r}>"
| |
| |
class Modules(dict):
  """Dict of modules whose entries are also reachable as attributes."""

  def __getattr__(self, name):
    # Invoked only when regular attribute lookup fails; a missing key is
    # reported as a missing attribute.
    try:
      module = self[name]
    except KeyError:
      raise AttributeError(name)
    else:
      return module
| |
| |
class SystemContext:
  """Global system.

  Owns a VmContext together with BoundModule wrappers for every module
  registered in it.

  Args:
    modules: Modules to load in addition to the config's default modules.
      When None the context is "dynamic": it starts with only the default
      modules and more can be added later via `add_module(s)`. Otherwise the
      context is static and the module set is fixed at construction.
    config: Configuration to use. Defaults to the shared global Config.
  """

  def __init__(self, modules=None, config: Optional[Config] = None):
    self._config = config if config is not None else _get_global_config()
    # repr() is spelled out explicitly. (`{x:!r}` is not the conversion
    # syntax - that is `{x!r}` - and fails when formatting.)
    print(f"SystemContext driver={repr(self._config.driver)}", file=sys.stderr)
    self._is_dynamic = modules is None
    default_modules = self._config.default_modules
    if self._is_dynamic:
      init_modules = None
    else:
      init_modules = default_modules + tuple(modules)

    self._vm_context = _binding.VmContext(instance=self._config.vm_instance,
                                          modules=init_modules)

    if self._is_dynamic:
      # A dynamic context is created empty and populated by registration.
      self._vm_context.register_modules(default_modules)
      bound_sources = default_modules
    else:
      bound_sources = init_modules
    self._modules = Modules([
        (m.name, BoundModule(self, m)) for m in bound_sources
    ])

  @property
  def is_dynamic(self) -> bool:
    return self._is_dynamic

  @property
  def config(self) -> Config:
    return self._config

  @property
  def instance(self) -> _binding.VmInstance:
    # The VM instance lives on the config; this class has no `_instance`.
    return self._config.vm_instance

  @property
  def modules(self) -> Modules:
    return self._modules

  def create_function_abi(self, f: _binding.VmFunction) -> _binding.FunctionAbi:
    """Creates the ABI for `f` using the config's device and type factory."""
    return self._vm_context.create_function_abi(self._config.device,
                                                self._config.host_type_factory,
                                                f)

  def add_modules(self, modules):
    """Registers additional modules on a dynamic context.

    Nothing is recorded in `self.modules` unless every name is unique and
    registration with the VM context succeeds.

    Raises:
      ValueError: If a module name is already registered or is repeated
        within `modules`.
    """
    assert self._is_dynamic, "Cannot 'add_module' on a static context"
    # Validate everything first so a failure leaves the context unchanged.
    pending = {}
    for m in modules:
      name = m.name
      if name in self._modules or name in pending:
        raise ValueError(f"Attempt to register duplicate module: '{name}'")
      pending[name] = BoundModule(self, m)
    self._vm_context.register_modules(modules)
    self._modules.update(pending)

  def add_module(self, module):
    """Registers a single module; see `add_modules`."""
    self.add_modules((module,))
| |
| |
def load_modules(*modules, config: Optional[Config] = None):
  """Loads the given modules into a new SystemContext and returns them bound.

  Args:
    *modules: Modules to load into the context.
    config: Optional Config forwarded to SystemContext.

  Returns:
    A list with the bound module for each input, in argument order, looked up
    by module name in the new context.
  """
  system_context = SystemContext(modules=modules, config=config)
  bound_by_name = system_context.modules
  return [bound_by_name[module.name] for module in modules]
| |
| |
def load_module(module, **kwargs):
  """Loads one module via `load_modules` and returns its bound module.

  Keyword arguments are forwarded unchanged to `load_modules`.
  """
  bound_modules = load_modules(module, **kwargs)
  return bound_modules[0]