Remove the legacy Jax API from the iree repo. (#8031) * Remove the legacy Jax API from the iree repo. This is being landed in the new iree-jax repo, alongside a more modern API: https://github.com/google/iree-jax/pull/2
diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 724c4c3..8c096ba 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml
@@ -211,18 +211,6 @@ export CIBW_BEFORE_BUILD="python -m pip install --upgrade pip==21.3dev0 -f https://github.com/stellaraccident/pip/releases/tag/21.3dev20210925" python -m cibuildwheel --output-dir bindist ./main_checkout/llvm-external-projects/iree-compiler-api - # Pure python packages only need to do a minimal CMake configure and no - # actual building, aside from setting up sources. If there are multiples, - # it is fairly cheap to build them serially. Using cibuildwheel is a bit - # overkill for this, but we keep it the same as the others for maintenance - # value. - - name: Build pure python wheels - if: "matrix.build_package == 'py-pure-pkgs'" - shell: bash - run: | - python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs - python -m pip wheel -w bindist --no-deps ./iree-install/python_packages/iree_jax - # Compiler tools wheels are not python version specific, so just build # for one examplar python version. - name: Build XLA Compiler Tools wheels
diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml index 1236c48..a51386b 100644 --- a/.github/workflows/validate_and_publish_release.yml +++ b/.github/workflows/validate_and_publish_release.yml
@@ -41,7 +41,7 @@ - name: Install python packages id: install_python_packages run: | - python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-jax-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot + python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot - name: Run iree-benchmark-module id: run_iree_benchmark_module run: ./bin/iree-benchmark-module --help
diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e67822..ac03a37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -42,7 +42,6 @@ option(IREE_BUILD_SAMPLES "Builds IREE sample projects." ON) option(IREE_BUILD_TRACY "Builds tracy server tools." OFF) -option(IREE_BUILD_LEGACY_JAX "Builds the legacy JAX Python API" ON) option(IREE_BUILD_TENSORFLOW_ALL "Builds all TensorFlow compiler frontends." OFF) option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}") option(IREE_BUILD_TFLITE_COMPILER "Builds the TFLite compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt index 77bf1ce..545a800 100644 --- a/bindings/python/CMakeLists.txt +++ b/bindings/python/CMakeLists.txt
@@ -11,8 +11,3 @@ # Namespace packages. add_subdirectory(iree/runtime) - -if(IREE_BUILD_LEGACY_JAX) - message(STATUS "Building legacy JAX API") - add_subdirectory(iree/jax) -endif()
diff --git a/bindings/python/iree/jax/CMakeLists.txt b/bindings/python/iree/jax/CMakeLists.txt deleted file mode 100644 index 08bd99a..0000000 --- a/bindings/python/iree/jax/CMakeLists.txt +++ /dev/null
@@ -1,31 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -iree_py_library( - NAME - jax - SRCS - "__init__.py" - "frontend.py" -) - -# Only enable the tests if the XLA compiler is built. -if(${IREE_BUILD_XLA_COMPILER}) -iree_py_test( - NAME - frontend_test - SRCS - "frontend_test.py" -) -endif() - -iree_py_install_package( - COMPONENT IreePythonPackage-jax - PACKAGE_NAME iree_jax - MODULE_PATH iree/jax - ADDL_PACKAGE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/README.md -)
diff --git a/bindings/python/iree/jax/README.md b/bindings/python/iree/jax/README.md deleted file mode 100644 index cba2485..0000000 --- a/bindings/python/iree/jax/README.md +++ /dev/null
@@ -1,106 +0,0 @@ -# IREE–JAX Frontend - -## Requirements - -A local JAX installation is necessary in addition to IREE's Python requirements: - -```shell -python -m pip install jax jaxlib -``` - -## Just In Time Compilation with Runtime Bindings - -A just-in-time compilation decorator similar to `jax.jit` is provided by -`iree.jax.jit`: - -```python -import iree.jax - -import jax -import jax.numpy as jnp - -# 'backend' is one of 'vmvx', 'llvmaot' and 'vulkan' and defaults to 'llvmaot'. -@iree.jax.jit(backend="llvmaot") -def linear_relu_layer(params, x): - w, b = params - return jnp.max(jnp.matmul(x, w) + b, 0) - -w = jnp.zeros((784, 128)) -b = jnp.zeros(128) -x = jnp.zeros((1, 784)) - -linear_relu_layer([w, b], x) -``` - -## Ahead of Time Compilation - -An ahead-of-time compilation function provides a lower-level API for compiling a -function with a specific input signature without creating the runtime bindings -for execution within Python. This is primarily useful for targeting other -runtime environments like Android. - -### Example: Compile a MLP and run it on Android - -Install the Android NDK according to the -[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake) -doc, and then ensure the following environment variable is set: - -```shell -export ANDROID_NDK=# NDK install location -``` - -The code below assumes that you have `flax` installed. - -```python -import iree.jax - -import jax -import jax.numpy as jnp -import flax -from flax import linen as nn - - -class MLP(nn.Module): - - @nn.compact - def __call__(self, x): - x = x.reshape((x.shape[0], -1)) # Flatten. - x = nn.Dense(128)(x) - x = nn.relu(x) - x = nn.Dense(10)(x) - x = nn.log_softmax(x) - return x - - -image = jnp.zeros((1, 28, 28, 1)) -params = MLP().init(jax.random.PRNGKey(0), image)["params"] - -apply_args = [{"params": params}, image] -options = dict(target_backends=["dylib-llvm-aot"], - extra_args=["--iree-llvm-target-triple=aarch64-linux-android"]) - -compiled_binary = iree.jax.aot(MLP().apply, *apply_args, **options) - -with open("/tmp/mlp_apply.vmfb", "wb") as f: - f.write(compiled_binary) -``` - -IREE doesn't provide installable tools for Android at this time, so they'll need -to be built according to the -[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake). -Afterward, the compiled `.vmfb` can be pushed to an Android device and executed -using `iree-run-module`: - -```shell -adb push /tmp/mlp_apply.vmfb /data/local/tmp/ -adb push ../iree-build-android/iree/tools/iree-run-module /data/local/tmp/ -adb shell /data/local/tmp/iree-run-module \ - --driver=dylib \ - --module_file=/data/local/tmp/mlp_apply.vmfb \ - --entry_function=main \ - --function_input=128xf32 \ - --function_input=784x128xf32 \ - --function_input=10xf32 \ - --function_input=128x10xf32 \ - --function_input=1x28x28x1xf32 -```
diff --git a/bindings/python/iree/jax/__init__.py b/bindings/python/iree/jax/__init__.py deleted file mode 100644 index c724179..0000000 --- a/bindings/python/iree/jax/__init__.py +++ /dev/null
@@ -1,7 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .frontend import *
diff --git a/bindings/python/iree/jax/frontend.py b/bindings/python/iree/jax/frontend.py deleted file mode 100644 index 8b225ee..0000000 --- a/bindings/python/iree/jax/frontend.py +++ /dev/null
@@ -1,145 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import functools - -import iree.compiler.xla -import iree.runtime - -try: - import jax -except ModuleNotFoundError as e: - raise ModuleNotFoundError("iree.jax requires 'jax' and 'jaxlib' to be " - "installed in your python environment.") from e - -# pytype thinks iree.jax is jax. -# pytype: disable=module-attr - -__all__ = [ - "aot", - "is_available", - "jit", -] - -_BACKEND_TO_TARGETS = { - "vmvx": "vmvx", - "llvmaot": "dylib-llvm-aot", - "vulkan": "vulkan-spirv", -} -_BACKENDS = tuple(_BACKEND_TO_TARGETS.keys()) - - -def is_available(): - """Determine if the IREE–XLA compiler are available for JAX.""" - return iree.compiler.xla.is_available() - - -def aot(function, *args, **options): - """Traces and compiles a function, flattening the input args. - - This is intended to be a lower-level interface for compiling a JAX function to - IREE without setting up the runtime bindings to use it within Python. A common - usecase for this is compiling to Android (and similar targets). - - Args: - function: The function to compile. - args: The inputs to trace and compile the function for. - **kwargs: Keyword args corresponding to xla.ImportOptions or CompilerOptions - """ - xla_comp = jax.xla_computation(function)(*args) - hlo_proto = xla_comp.as_serialized_hlo_module_proto() - return iree.compiler.xla.compile_str(hlo_proto, **options) - - -# A more JAX-native approach to jitting would be desireable here, however -# implementing that reasonably would require using JAX internals, particularly -# jax.linear_util.WrappedFun and helpers. The following is sufficient for many -# usecases for the time being. - - -class _JittedFunction: - - def __init__(self, function, driver: str, **options): - self._function = function - self._driver_config = iree.runtime.Config(driver) - self._options = options - self._memoized_signatures = {} - - def _get_signature(self, args_flat, in_tree): - args_flat = [iree.runtime.normalize_value(arg) for arg in args_flat] - return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,) - - def _wrap_and_compile(self, signature, args_flat, in_tree): - """Compiles the function for the given signature.""" - - def wrapped_function(*args_flat): - args, kwargs = jax.tree_unflatten(in_tree, args_flat) - return self._function(*args, **kwargs) - - # Compile the wrapped_function to IREE. - vm_flatbuffer = aot(wrapped_function, *args_flat, **self._options) - vm_module = iree.runtime.VmModule.from_flatbuffer(vm_flatbuffer) - module = iree.runtime.load_vm_module(vm_module, config=self._driver_config) - - # Get the output tree so it can be reconstructed from the outputs of the - # compiled module. Duplicating execution here isn't ideal, and could - # probably be avoided using internal APIs. - args, kwargs = jax.tree_unflatten(in_tree, args_flat) - _, out_tree = jax.tree_flatten(self._function(*args, **kwargs)) - - self._memoized_signatures[signature] = (module, out_tree) - - def _get_compiled_artifacts(self, args, kwargs): - """Returns the binary, loaded runtime module and out_tree.""" - args_flat, in_tree = jax.tree_flatten((args, kwargs)) - signature = self._get_signature(args_flat, in_tree) - - if signature not in self._memoized_signatures: - self._wrap_and_compile(signature, args_flat, in_tree) - return self._memoized_signatures[signature] - - def __call__(self, *args, **kwargs): - """Executes the function on the provided inputs, compiling if necessary.""" - args_flat, _ = jax.tree_flatten((args, kwargs)) - # Use the uncompiled function if the inputs are being traced. - if any(issubclass(type(arg), jax.core.Tracer) for arg in args_flat): - return self._function(*args, **kwargs) - - module, out_tree = self._get_compiled_artifacts(args, kwargs) - results = module.main(*args_flat) - if results is not None: - if not isinstance(results, tuple): - results = (results,) - return jax.tree_unflatten(out_tree, results) - else: - # Address IREE returning None instead of empty sequences. - if out_tree == jax.tree_flatten([])[-1]: - return [] - elif out_tree == jax.tree_flatten(())[-1]: - return () - else: - return results - - -def jit(function=None, *, backend: str = "llvmaot", **options): - """Compiles a function to the specified IREE backend.""" - if function is None: - # 'function' will be None if @jit() is called with parens (e.g. to specify a - # backend or **options). We return a partial function capturing these - # options, which python will then apply as a decorator, and execution will - # continue below. - return functools.partial(jit, backend=backend, **options) - - # Parse the backend to more concrete compiler and runtime settings. - if backend not in _BACKENDS: - raise ValueError( - f"Expected backend to be one of {_BACKENDS}, but got '{backend}'") - target_backend = _BACKEND_TO_TARGETS[backend] - driver = iree.runtime.TARGET_BACKEND_TO_DRIVER[target_backend] - if "target_backends" not in options: - options["target_backends"] = (target_backend,) - - return functools.wraps(function)(_JittedFunction(function, driver, **options))
diff --git a/bindings/python/iree/jax/frontend_test.py b/bindings/python/iree/jax/frontend_test.py deleted file mode 100644 index 3bff19d..0000000 --- a/bindings/python/iree/jax/frontend_test.py +++ /dev/null
@@ -1,171 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from absl.testing import absltest -import iree.jax -import iree.runtime -import jax -import jax.numpy as jnp -import numpy as np - -# pytype thinks iree.jax is jax. -# pytype: disable=module-attr - -TOLERANCE = {"rtol": 1e-6, "atol": 1e-6} - - -def normal(shape): - return np.random.normal(0, 1, shape).astype(np.float32) - - -class SqrtNode: - - def __init__(self, x, y): - self.x = x - self.y = y - - def apply(self, z): - return self.x * jnp.sqrt(self.y * z) - - def tree_flatten(self): - return ((self.x, self.y), None) - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) - - -class SquareNode: - - def __init__(self, x, y): - self.x = x - self.y = y - - def apply(self, z): - return self.x * (self.y * z)**2 - - def tree_flatten(self): - return ((self.x, self.y), None) - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(*children) - - -class JAXFrontendTest(absltest.TestCase): - - def test_aot_pytree(self): - - def pytree_func(params, x): - return jnp.max(jnp.matmul(x, params["w"]) + params["b"], 0) - - trace_args = [ - { - "w": jnp.zeros((32, 8)), - "b": jnp.zeros((8,)) - }, - jnp.zeros((1, 32)), - ] - binary = iree.jax.aot(pytree_func, *trace_args, target_backends=["vmvx"]) - - def test_jit_pytree_return(self): - - @iree.jax.jit - def apply_sqrt(pytree): - return jax.tree_map(jnp.sqrt, pytree) - - np.random.seed(0) - input_tree = { - "a": [ - normal((2, 3)), - { - "b": normal(3) - }, - ], - "c": ( - { - "d": [normal(2), normal(3)] - }, - (normal(1), normal(4)), - ) - } - - expected = jax.tree_map(jnp.sqrt, input_tree) - expected_arrays, expected_tree = jax.tree_flatten(expected) - result = apply_sqrt(input_tree) - result_arrays, result_tree = jax.tree_flatten(result) - - self.assertEqual(expected_tree, result_tree) - for expected_array, result_array in zip(expected_arrays, result_arrays): - np.testing.assert_allclose(expected_array, result_array, **TOLERANCE) - - def test_iree_jit_of_iree_jit(self): - - @iree.jax.jit - def add(a, b): - return a + b - - @iree.jax.jit - def mul_two(a): - return add(a, a) - - self.assertEqual(mul_two(3), 6) - - def test_jax_jit_of_iree_jit(self): - - @iree.jax.jit - def add(a, b): - return a + b - - @jax.jit - def mul_two(a): - return add(a, a) - - self.assertEqual(mul_two(3), 6) - - def test_iree_jit_of_jax_jit(self): - - @jax.jit - def add(a, b): - return a + b - - @iree.jax.jit - def mul_two(a): - return add(a, a) - - self.assertEqual(mul_two(3), 6) - - def test_iree_jit_of_empty_iree_jit(self): - - @iree.jax.jit - def sqrt_four(): - return jnp.sqrt(4) - - @iree.jax.jit - def add_sqrt_four(a): - return a + sqrt_four() - - self.assertEqual(add_sqrt_four(2), 4) - - def test_jit_pytree_method(self): - - @iree.jax.jit - def apply_node(node, z): - return node.apply(z) - - expected_sqrt = apply_node._function(SqrtNode(2, 3), 4) - compiled_sqrt = apply_node(SqrtNode(2, 3), 4) - np.testing.assert_allclose(compiled_sqrt, expected_sqrt, **TOLERANCE) - - expected_square = apply_node._function(SquareNode(2, 3), 4) - compiled_square = apply_node(SquareNode(2, 3), 4) - np.testing.assert_allclose(compiled_square, expected_square, **TOLERANCE) - - -if __name__ == "__main__": - jax.tree_util.register_pytree_node_class(SqrtNode) - jax.tree_util.register_pytree_node_class(SquareNode) - absltest.main()
diff --git a/bindings/python/iree/jax/setup.py.in b/bindings/python/iree/jax/setup.py.in deleted file mode 100644 index db31229..0000000 --- a/bindings/python/iree/jax/setup.py.in +++ /dev/null
@@ -1,45 +0,0 @@ -#!/usr/bin/python3 - -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from distutils.command.install import install -import os -import platform -from setuptools import setup, find_namespace_packages - -with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f: - README = f.read() - -exe_suffix = ".exe" if platform.system() == "Windows" else "" - -setup( - name="iree-jax@IREE_RELEASE_PACKAGE_SUFFIX@", - version="@IREE_RELEASE_VERSION@", - author="The IREE Team", - author_email="iree-discuss@googlegroups.com", - license="Apache", - description="IREE JAX API", - long_description=README, - long_description_content_type="text/markdown", - url="https://github.com/google/iree", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache License", - "Operating System :: OS Independent", - "Development Status :: 3 - Alpha", - ], - python_requires=">=3.7", - packages=find_namespace_packages(include=["iree.jax"]), - zip_safe=True, - install_requires = [ - "jax", - "jaxlib", - "iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@", - "iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@", - "iree-tools-xla@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@", - ], -)
diff --git a/bindings/python/iree/jax/version.py.in b/bindings/python/iree/jax/version.py.in deleted file mode 100644 index ed72ac3..0000000 --- a/bindings/python/iree/jax/version.py.in +++ /dev/null
@@ -1,9 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -PACKAGE_SUFFIX = "@IREE_RELEASE_PACKAGE_SUFFIX@" -VERSION = "@IREE_RELEASE_VERSION@" -REVISION = "@IREE_RELEASE_REVISION@"
diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py index f9e9545..ca72085 100644 --- a/build_tools/github_actions/build_dist.py +++ b/build_tools/github_actions/build_dist.py
@@ -39,7 +39,6 @@ python ./main_checkout/build_tools/github_actions/build_dist.py main-dist python ./main_checkout/build_tools/github_actions/build_dist.py py-runtime-pkg - python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs python ./main_checkout/build_tools/github_actions/build_dist.py py-xla-compiler-tools-pkg python ./main_checkout/build_tools/github_actions/build_dist.py py-tflite-compiler-tools-pkg python ./main_checkout/build_tools/github_actions/build_dist.py py-tf-compiler-tools-pkg @@ -168,48 +167,6 @@ tf.add(os.path.join(INSTALL_DIR, entry), arcname=entry, recursive=True) -def build_py_pure_pkgs(): - """Performs a minimal build sufficient to produce pure python packages. - - This installs the following packages: - - iree-install/python_packages/iree_jax - - Since these are pure python packages, it is expected that they will be built - on a single examplar (i.e. Linux) distribution. - """ - install_python_requirements() - - # Clean up install and build trees. - shutil.rmtree(INSTALL_DIR, ignore_errors=True) - remove_cmake_cache() - - # CMake configure. - print("*** Configuring ***") - subprocess.run([ - sys.executable, - CMAKE_CI_SCRIPT, - f"-B{BUILD_DIR}", - f"-DCMAKE_INSTALL_PREFIX={INSTALL_DIR}", - f"-DCMAKE_BUILD_TYPE=Release", - f"-DIREE_BUILD_COMPILER=OFF", - f"-DIREE_BUILD_PYTHON_BINDINGS=ON", - f"-DIREE_BUILD_SAMPLES=OFF", - f"-DIREE_BUILD_TESTS=OFF", - ], - check=True) - - print("*** Building ***") - subprocess.run([ - sys.executable, - CMAKE_CI_SCRIPT, - "--build", - BUILD_DIR, - "--target", - "install-IreePythonPackage-jax", - ], - check=True) - - def build_py_runtime_pkg(instrumented: bool = False): """Builds the iree-install/python_packages/iree_runtime package. @@ -420,8 +377,6 @@ build_py_runtime_pkg() elif command == "instrumented-py-runtime-pkg": build_py_runtime_pkg(instrumented=True) -elif command == "py-pure-pkgs": - build_py_pure_pkgs() elif command == "py-xla-compiler-tools-pkg": build_py_xla_compiler_tools_pkg() elif command == "py-tflite-compiler-tools-pkg":