blob: 0429f369c271fc7affa55d4a5d9e1b2e5d6ff1e8 [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import pyiree.compiler2 as compiler
import pyiree.compiler2.xla
from pyiree import rt
try:
import jax
except ModuleNotFoundError as e:
raise ModuleNotFoundError("pyiree.jax requires 'jax' and 'jaxlib' to be "
"installed in your python environment.") from e
# pytype thinks pyiree.jax is jax.
# pytype: disable=module-attr

# Names exported by `from pyiree.jax import *`.
__all__ = ["aot", "is_available", "jit"]

# User-facing backend name -> IREE runtime driver used to run modules compiled
# for that backend (see `_JittedFunction`'s `rt.Config(driver)`).
_BACKEND_TO_DRIVER = {
    "vmla": "vmla",
    "llvmaot": "dylib",
    "vulkan": "vulkan",
}

# User-facing backend name -> default `target_backends` compiler option.
_BACKEND_TO_TARGETS = {
    "vmla": ("vmla",),
    "llvmaot": ("dylib-llvm-aot",),
    "vulkan": ("vulkan-*",),
}

# Accepted backend names, in the order they appear in error messages. Derived
# from the driver map's (insertion-ordered) keys so the two cannot drift apart.
_BACKENDS = list(_BACKEND_TO_DRIVER)
def is_available():
  """Reports whether the IREE XLA compiler frontend is usable from JAX.

  Returns:
    Whatever `pyiree.compiler2.xla.is_available()` reports.
  """
  xla_frontend = compiler.xla
  return xla_frontend.is_available()
def aot(function, *args, **options):
  """Traces and compiles a function, flattening the input args.

  This is intended to be a lower-level interface for compiling a JAX function to
  IREE without setting up the runtime bindings to use it within Python. A common
  use case for this is compiling to Android (and similar targets).

  Args:
    function: The function to compile.
    *args: The example inputs to trace and compile the function for. They are
      passed positionally to `function` by `jax.xla_computation` during tracing.
    **options: Keyword args corresponding to xla.ImportOptions or
      CompilerOptions; forwarded unchanged to `compiler.xla.compile_str`.

  Returns:
    The compiled binary, as returned by `compiler.xla.compile_str`.
  """
  xla_comp = jax.xla_computation(function)(*args)
  hlo_proto = xla_comp.as_serialized_hlo_module_proto()
  return compiler.xla.compile_str(hlo_proto, **options)
# A more JAX-native approach to jitting would be desirable here; however,
# implementing that reasonably would require using JAX internals, particularly
# jax.linear_util.WrappedFun and helpers. The following is sufficient for many
# use cases for the time being.
class _JittedFunction:
  """Callable that runs `function` through IREE, compiling per input signature.

  A signature is the (shape, dtype) of every flattened input plus the input
  pytree structure. The first call with a new signature compiles and loads an
  IREE module; later calls with the same signature reuse it. If any input is a
  `jax.core.Tracer` (i.e. this is being traced by JAX), the original Python
  function is called instead.
  """

  def __init__(self, function, driver: str, **options):
    """Creates the wrapper.

    Args:
      function: The function to compile and run.
      driver: IREE runtime driver name used to build the `rt.Config`.
      **options: Compiler options forwarded to `aot`.
    """
    self._function = function
    self._driver_config = rt.Config(driver)
    self._options = options
    # signature -> (binary, loaded rt module, output pytree structure).
    self._memoized_signatures = {}

  def _get_signature(self, args_flat, in_tree):
    """Returns a hashable cache key for the flattened inputs and their tree."""
    # NOTE(review): assumes rt.normalize_value yields objects exposing
    # `.shape` and `.dtype` (e.g. numpy arrays) - confirm in pyiree.rt.
    args_flat = [rt.normalize_value(arg) for arg in args_flat]
    return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)

  def _wrap_and_compile(self, signature, args_flat, in_tree):
    """Compiles the function for the given signature and memoizes the result."""

    def wrapped_function(*args_flat):
      # IREE consumes flat positional inputs; rebuild the caller's structure.
      args, kwargs = jax.tree_unflatten(in_tree, args_flat)
      return self._function(*args, **kwargs)

    # Compile the wrapped_function to IREE.
    binary = aot(wrapped_function, *args_flat, **self._options)
    cpp_vm_module = rt.VmModule.from_flatbuffer(binary)
    module = rt.load_module(cpp_vm_module, config=self._driver_config)

    # Get the output tree so it can be reconstructed from the outputs of the
    # compiled module. `jax.eval_shape` only traces the function abstractly,
    # so the function is not executed a second time on real data. Any function
    # that `aot` could trace above can be traced here as well.
    out_shapes = jax.eval_shape(wrapped_function, *args_flat)
    _, out_tree = jax.tree_flatten(out_shapes)

    self._memoized_signatures[signature] = (binary, module, out_tree)

  def _get_compiled_artifacts_flat(self, args_flat, in_tree):
    """Returns (binary, module, out_tree) for already-flattened inputs."""
    signature = self._get_signature(args_flat, in_tree)
    if signature not in self._memoized_signatures:
      self._wrap_and_compile(signature, args_flat, in_tree)
    return self._memoized_signatures[signature]

  def _get_compiled_artifacts(self, args, kwargs):
    """Returns the binary, loaded rt module and out_tree."""
    args_flat, in_tree = jax.tree_flatten((args, kwargs))
    return self._get_compiled_artifacts_flat(args_flat, in_tree)

  def __call__(self, *args, **kwargs):
    """Executes the function on the provided inputs, compiling if necessary."""
    args_flat, in_tree = jax.tree_flatten((args, kwargs))
    # Use the uncompiled function if the inputs are being traced.
    if any(isinstance(arg, jax.core.Tracer) for arg in args_flat):
      return self._function(*args, **kwargs)
    _, module, out_tree = self._get_compiled_artifacts_flat(args_flat, in_tree)
    results = module.main(*args_flat)
    # A single output comes back unwrapped; normalize to a tuple of leaves.
    if not isinstance(results, tuple):
      results = (results,)
    return jax.tree_unflatten(out_tree, results)

  def get_binary(self, *args, **kwargs):
    """Gets the IREE-compiled binary for the given inputs."""
    binary, _, _ = self._get_compiled_artifacts(args, kwargs)
    return binary
def jit(function=None, *, backend: str = "llvmaot", **options):
  """Compiles a function to the specified IREE backend.

  Usable either bare (`@jit`) or with arguments (`@jit(backend=...,
  **options)`).

  Args:
    function: The function to compile. None when `jit` is called with
      parentheses to configure it; a decorator is returned in that case.
    backend: Backend name; must be one of `_BACKENDS`.
    **options: Compiler options forwarded to `aot`. If `target_backends` is
      not given, it defaults to the targets associated with `backend`.

  Returns:
    A `_JittedFunction` wrapping `function` (with `function`'s metadata copied
    via `functools.wraps`), or a decorator if `function` is None.

  Raises:
    ValueError: If `backend` is not a supported backend.
  """
  # Validate before handling the decorator-factory form so that
  # `@jit(backend="typo")` fails at the faulty call rather than later, when the
  # returned decorator is applied.
  if backend not in _BACKENDS:
    raise ValueError(
        f"Expected backend to be one of {_BACKENDS}, but got '{backend}'")

  if function is None:
    # 'function' will be None if @jit() is called with parens (e.g. to specify a
    # backend or **options). We return a partial function capturing these
    # options, which python will then apply as a decorator, and execution will
    # continue below.
    return functools.partial(jit, backend=backend, **options)

  # Parse the backend to more concrete compiler and runtime settings.
  driver = _BACKEND_TO_DRIVER[backend]
  options.setdefault("target_backends", _BACKEND_TO_TARGETS[backend])
  return functools.wraps(function)(_JittedFunction(function, driver, **options))