Merge google -> main (#8027)
* 580800efa Synchronize submodules with LLVM at llvm/llvm-project@c5965a411c63
* c554bc99e Merge pull request #8025 from not-jenni:main-to-google
* 631e8e3cc Synchronize submodules with LLVM at llvm/llvm-project@c5965a411c63
* 28cd81b73 Integrate LLVM at llvm/llvm-project@c5965a411c63
* 01fe2ffdf Integrate LLVM at llvm/llvm-project@564bcf9d0243
* 723fe7351 Synchronize submodules with LLVM at llvm/llvm-project@b5149f4e66a4
* 82f2f2107 Merge pull request #8016 from not-jenni:main-to-google
diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml
index 724c4c3..8c096ba 100644
--- a/.github/workflows/build_package.yml
+++ b/.github/workflows/build_package.yml
@@ -211,18 +211,6 @@
export CIBW_BEFORE_BUILD="python -m pip install --upgrade pip==21.3dev0 -f https://github.com/stellaraccident/pip/releases/tag/21.3dev20210925"
python -m cibuildwheel --output-dir bindist ./main_checkout/llvm-external-projects/iree-compiler-api
- # Pure python packages only need to do a minimal CMake configure and no
- # actual building, aside from setting up sources. If there are multiples,
- # it is fairly cheap to build them serially. Using cibuildwheel is a bit
- # overkill for this, but we keep it the same as the others for maintenance
- # value.
- - name: Build pure python wheels
- if: "matrix.build_package == 'py-pure-pkgs'"
- shell: bash
- run: |
- python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs
- python -m pip wheel -w bindist --no-deps ./iree-install/python_packages/iree_jax
-
# Compiler tools wheels are not python version specific, so just build
# for one examplar python version.
- name: Build XLA Compiler Tools wheels
diff --git a/.github/workflows/validate_and_publish_release.yml b/.github/workflows/validate_and_publish_release.yml
index 1236c48..a51386b 100644
--- a/.github/workflows/validate_and_publish_release.yml
+++ b/.github/workflows/validate_and_publish_release.yml
@@ -41,7 +41,7 @@
- name: Install python packages
id: install_python_packages
run: |
- python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-jax-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot
+ python -m pip install -f file://$PWD/artifact/ iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot iree-tools-tf-snapshot iree-tools-xla-snapshot
- name: Run iree-benchmark-module
id: run_iree_benchmark_module
run: ./bin/iree-benchmark-module --help
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3e67822..ac03a37 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -42,7 +42,6 @@
option(IREE_BUILD_SAMPLES "Builds IREE sample projects." ON)
option(IREE_BUILD_TRACY "Builds tracy server tools." OFF)
-option(IREE_BUILD_LEGACY_JAX "Builds the legacy JAX Python API" ON)
option(IREE_BUILD_TENSORFLOW_ALL "Builds all TensorFlow compiler frontends." OFF)
option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
option(IREE_BUILD_TFLITE_COMPILER "Builds the TFLite compiler frontend." "${IREE_BUILD_TENSORFLOW_ALL}")
diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt
index 77bf1ce..545a800 100644
--- a/bindings/python/CMakeLists.txt
+++ b/bindings/python/CMakeLists.txt
@@ -11,8 +11,3 @@
# Namespace packages.
add_subdirectory(iree/runtime)
-
-if(IREE_BUILD_LEGACY_JAX)
- message(STATUS "Building legacy JAX API")
- add_subdirectory(iree/jax)
-endif()
diff --git a/bindings/python/iree/jax/CMakeLists.txt b/bindings/python/iree/jax/CMakeLists.txt
deleted file mode 100644
index 08bd99a..0000000
--- a/bindings/python/iree/jax/CMakeLists.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-iree_py_library(
- NAME
- jax
- SRCS
- "__init__.py"
- "frontend.py"
-)
-
-# Only enable the tests if the XLA compiler is built.
-if(${IREE_BUILD_XLA_COMPILER})
-iree_py_test(
- NAME
- frontend_test
- SRCS
- "frontend_test.py"
-)
-endif()
-
-iree_py_install_package(
- COMPONENT IreePythonPackage-jax
- PACKAGE_NAME iree_jax
- MODULE_PATH iree/jax
- ADDL_PACKAGE_FILES
- ${CMAKE_CURRENT_SOURCE_DIR}/README.md
-)
diff --git a/bindings/python/iree/jax/README.md b/bindings/python/iree/jax/README.md
deleted file mode 100644
index cba2485..0000000
--- a/bindings/python/iree/jax/README.md
+++ /dev/null
@@ -1,106 +0,0 @@
-# IREE–JAX Frontend
-
-## Requirements
-
-A local JAX installation is necessary in addition to IREE's Python requirements:
-
-```shell
-python -m pip install jax jaxlib
-```
-
-## Just In Time Compilation with Runtime Bindings
-
-A just-in-time compilation decorator similar to `jax.jit` is provided by
-`iree.jax.jit`:
-
-```python
-import iree.jax
-
-import jax
-import jax.numpy as jnp
-
-# 'backend' is one of 'vmvx', 'llvmaot' and 'vulkan' and defaults to 'llvmaot'.
-@iree.jax.jit(backend="llvmaot")
-def linear_relu_layer(params, x):
- w, b = params
- return jnp.max(jnp.matmul(x, w) + b, 0)
-
-w = jnp.zeros((784, 128))
-b = jnp.zeros(128)
-x = jnp.zeros((1, 784))
-
-linear_relu_layer([w, b], x)
-```
-
-## Ahead of Time Compilation
-
-An ahead-of-time compilation function provides a lower-level API for compiling a
-function with a specific input signature without creating the runtime bindings
-for execution within Python. This is primarily useful for targeting other
-runtime environments like Android.
-
-### Example: Compile a MLP and run it on Android
-
-Install the Android NDK according to the
-[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake)
-doc, and then ensure the following environment variable is set:
-
-```shell
-export ANDROID_NDK=# NDK install location
-```
-
-The code below assumes that you have `flax` installed.
-
-```python
-import iree.jax
-
-import jax
-import jax.numpy as jnp
-import flax
-from flax import linen as nn
-
-
-class MLP(nn.Module):
-
- @nn.compact
- def __call__(self, x):
- x = x.reshape((x.shape[0], -1)) # Flatten.
- x = nn.Dense(128)(x)
- x = nn.relu(x)
- x = nn.Dense(10)(x)
- x = nn.log_softmax(x)
- return x
-
-
-image = jnp.zeros((1, 28, 28, 1))
-params = MLP().init(jax.random.PRNGKey(0), image)["params"]
-
-apply_args = [{"params": params}, image]
-options = dict(target_backends=["dylib-llvm-aot"],
- extra_args=["--iree-llvm-target-triple=aarch64-linux-android"])
-
-compiled_binary = iree.jax.aot(MLP().apply, *apply_args, **options)
-
-with open("/tmp/mlp_apply.vmfb", "wb") as f:
- f.write(compiled_binary)
-```
-
-IREE doesn't provide installable tools for Android at this time, so they'll need
-to be built according to the
-[Android Getting Started](https://google.github.io/iree/get-started/getting-started-android-cmake).
-Afterward, the compiled `.vmfb` can be pushed to an Android device and executed
-using `iree-run-module`:
-
-```shell
-adb push /tmp/mlp_apply.vmfb /data/local/tmp/
-adb push ../iree-build-android/iree/tools/iree-run-module /data/local/tmp/
-adb shell /data/local/tmp/iree-run-module \
- --driver=dylib \
- --module_file=/data/local/tmp/mlp_apply.vmfb \
- --entry_function=main \
- --function_input=128xf32 \
- --function_input=784x128xf32 \
- --function_input=10xf32 \
- --function_input=128x10xf32 \
- --function_input=1x28x28x1xf32
-```
diff --git a/bindings/python/iree/jax/__init__.py b/bindings/python/iree/jax/__init__.py
deleted file mode 100644
index c724179..0000000
--- a/bindings/python/iree/jax/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .frontend import *
diff --git a/bindings/python/iree/jax/frontend.py b/bindings/python/iree/jax/frontend.py
deleted file mode 100644
index 8b225ee..0000000
--- a/bindings/python/iree/jax/frontend.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import functools
-
-import iree.compiler.xla
-import iree.runtime
-
-try:
- import jax
-except ModuleNotFoundError as e:
- raise ModuleNotFoundError("iree.jax requires 'jax' and 'jaxlib' to be "
- "installed in your python environment.") from e
-
-# pytype thinks iree.jax is jax.
-# pytype: disable=module-attr
-
-__all__ = [
- "aot",
- "is_available",
- "jit",
-]
-
-_BACKEND_TO_TARGETS = {
- "vmvx": "vmvx",
- "llvmaot": "dylib-llvm-aot",
- "vulkan": "vulkan-spirv",
-}
-_BACKENDS = tuple(_BACKEND_TO_TARGETS.keys())
-
-
-def is_available():
- """Determine if the IREE–XLA compiler are available for JAX."""
- return iree.compiler.xla.is_available()
-
-
-def aot(function, *args, **options):
- """Traces and compiles a function, flattening the input args.
-
- This is intended to be a lower-level interface for compiling a JAX function to
- IREE without setting up the runtime bindings to use it within Python. A common
- usecase for this is compiling to Android (and similar targets).
-
- Args:
- function: The function to compile.
- args: The inputs to trace and compile the function for.
- **kwargs: Keyword args corresponding to xla.ImportOptions or CompilerOptions
- """
- xla_comp = jax.xla_computation(function)(*args)
- hlo_proto = xla_comp.as_serialized_hlo_module_proto()
- return iree.compiler.xla.compile_str(hlo_proto, **options)
-
-
-# A more JAX-native approach to jitting would be desireable here, however
-# implementing that reasonably would require using JAX internals, particularly
-# jax.linear_util.WrappedFun and helpers. The following is sufficient for many
-# usecases for the time being.
-
-
-class _JittedFunction:
-
- def __init__(self, function, driver: str, **options):
- self._function = function
- self._driver_config = iree.runtime.Config(driver)
- self._options = options
- self._memoized_signatures = {}
-
- def _get_signature(self, args_flat, in_tree):
- args_flat = [iree.runtime.normalize_value(arg) for arg in args_flat]
- return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)
-
- def _wrap_and_compile(self, signature, args_flat, in_tree):
- """Compiles the function for the given signature."""
-
- def wrapped_function(*args_flat):
- args, kwargs = jax.tree_unflatten(in_tree, args_flat)
- return self._function(*args, **kwargs)
-
- # Compile the wrapped_function to IREE.
- vm_flatbuffer = aot(wrapped_function, *args_flat, **self._options)
- vm_module = iree.runtime.VmModule.from_flatbuffer(vm_flatbuffer)
- module = iree.runtime.load_vm_module(vm_module, config=self._driver_config)
-
- # Get the output tree so it can be reconstructed from the outputs of the
- # compiled module. Duplicating execution here isn't ideal, and could
- # probably be avoided using internal APIs.
- args, kwargs = jax.tree_unflatten(in_tree, args_flat)
- _, out_tree = jax.tree_flatten(self._function(*args, **kwargs))
-
- self._memoized_signatures[signature] = (module, out_tree)
-
- def _get_compiled_artifacts(self, args, kwargs):
- """Returns the binary, loaded runtime module and out_tree."""
- args_flat, in_tree = jax.tree_flatten((args, kwargs))
- signature = self._get_signature(args_flat, in_tree)
-
- if signature not in self._memoized_signatures:
- self._wrap_and_compile(signature, args_flat, in_tree)
- return self._memoized_signatures[signature]
-
- def __call__(self, *args, **kwargs):
- """Executes the function on the provided inputs, compiling if necessary."""
- args_flat, _ = jax.tree_flatten((args, kwargs))
- # Use the uncompiled function if the inputs are being traced.
- if any(issubclass(type(arg), jax.core.Tracer) for arg in args_flat):
- return self._function(*args, **kwargs)
-
- module, out_tree = self._get_compiled_artifacts(args, kwargs)
- results = module.main(*args_flat)
- if results is not None:
- if not isinstance(results, tuple):
- results = (results,)
- return jax.tree_unflatten(out_tree, results)
- else:
- # Address IREE returning None instead of empty sequences.
- if out_tree == jax.tree_flatten([])[-1]:
- return []
- elif out_tree == jax.tree_flatten(())[-1]:
- return ()
- else:
- return results
-
-
-def jit(function=None, *, backend: str = "llvmaot", **options):
- """Compiles a function to the specified IREE backend."""
- if function is None:
- # 'function' will be None if @jit() is called with parens (e.g. to specify a
- # backend or **options). We return a partial function capturing these
- # options, which python will then apply as a decorator, and execution will
- # continue below.
- return functools.partial(jit, backend=backend, **options)
-
- # Parse the backend to more concrete compiler and runtime settings.
- if backend not in _BACKENDS:
- raise ValueError(
- f"Expected backend to be one of {_BACKENDS}, but got '{backend}'")
- target_backend = _BACKEND_TO_TARGETS[backend]
- driver = iree.runtime.TARGET_BACKEND_TO_DRIVER[target_backend]
- if "target_backends" not in options:
- options["target_backends"] = (target_backend,)
-
- return functools.wraps(function)(_JittedFunction(function, driver, **options))
diff --git a/bindings/python/iree/jax/frontend_test.py b/bindings/python/iree/jax/frontend_test.py
deleted file mode 100644
index 3bff19d..0000000
--- a/bindings/python/iree/jax/frontend_test.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from absl.testing import absltest
-import iree.jax
-import iree.runtime
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-# pytype thinks iree.jax is jax.
-# pytype: disable=module-attr
-
-TOLERANCE = {"rtol": 1e-6, "atol": 1e-6}
-
-
-def normal(shape):
- return np.random.normal(0, 1, shape).astype(np.float32)
-
-
-class SqrtNode:
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def apply(self, z):
- return self.x * jnp.sqrt(self.y * z)
-
- def tree_flatten(self):
- return ((self.x, self.y), None)
-
- @classmethod
- def tree_unflatten(cls, aux_data, children):
- return cls(*children)
-
-
-class SquareNode:
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def apply(self, z):
- return self.x * (self.y * z)**2
-
- def tree_flatten(self):
- return ((self.x, self.y), None)
-
- @classmethod
- def tree_unflatten(cls, aux_data, children):
- return cls(*children)
-
-
-class JAXFrontendTest(absltest.TestCase):
-
- def test_aot_pytree(self):
-
- def pytree_func(params, x):
- return jnp.max(jnp.matmul(x, params["w"]) + params["b"], 0)
-
- trace_args = [
- {
- "w": jnp.zeros((32, 8)),
- "b": jnp.zeros((8,))
- },
- jnp.zeros((1, 32)),
- ]
- binary = iree.jax.aot(pytree_func, *trace_args, target_backends=["vmvx"])
-
- def test_jit_pytree_return(self):
-
- @iree.jax.jit
- def apply_sqrt(pytree):
- return jax.tree_map(jnp.sqrt, pytree)
-
- np.random.seed(0)
- input_tree = {
- "a": [
- normal((2, 3)),
- {
- "b": normal(3)
- },
- ],
- "c": (
- {
- "d": [normal(2), normal(3)]
- },
- (normal(1), normal(4)),
- )
- }
-
- expected = jax.tree_map(jnp.sqrt, input_tree)
- expected_arrays, expected_tree = jax.tree_flatten(expected)
- result = apply_sqrt(input_tree)
- result_arrays, result_tree = jax.tree_flatten(result)
-
- self.assertEqual(expected_tree, result_tree)
- for expected_array, result_array in zip(expected_arrays, result_arrays):
- np.testing.assert_allclose(expected_array, result_array, **TOLERANCE)
-
- def test_iree_jit_of_iree_jit(self):
-
- @iree.jax.jit
- def add(a, b):
- return a + b
-
- @iree.jax.jit
- def mul_two(a):
- return add(a, a)
-
- self.assertEqual(mul_two(3), 6)
-
- def test_jax_jit_of_iree_jit(self):
-
- @iree.jax.jit
- def add(a, b):
- return a + b
-
- @jax.jit
- def mul_two(a):
- return add(a, a)
-
- self.assertEqual(mul_two(3), 6)
-
- def test_iree_jit_of_jax_jit(self):
-
- @jax.jit
- def add(a, b):
- return a + b
-
- @iree.jax.jit
- def mul_two(a):
- return add(a, a)
-
- self.assertEqual(mul_two(3), 6)
-
- def test_iree_jit_of_empty_iree_jit(self):
-
- @iree.jax.jit
- def sqrt_four():
- return jnp.sqrt(4)
-
- @iree.jax.jit
- def add_sqrt_four(a):
- return a + sqrt_four()
-
- self.assertEqual(add_sqrt_four(2), 4)
-
- def test_jit_pytree_method(self):
-
- @iree.jax.jit
- def apply_node(node, z):
- return node.apply(z)
-
- expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
- compiled_sqrt = apply_node(SqrtNode(2, 3), 4)
- np.testing.assert_allclose(compiled_sqrt, expected_sqrt, **TOLERANCE)
-
- expected_square = apply_node._function(SquareNode(2, 3), 4)
- compiled_square = apply_node(SquareNode(2, 3), 4)
- np.testing.assert_allclose(compiled_square, expected_square, **TOLERANCE)
-
-
-if __name__ == "__main__":
- jax.tree_util.register_pytree_node_class(SqrtNode)
- jax.tree_util.register_pytree_node_class(SquareNode)
- absltest.main()
diff --git a/bindings/python/iree/jax/setup.py.in b/bindings/python/iree/jax/setup.py.in
deleted file mode 100644
index db31229..0000000
--- a/bindings/python/iree/jax/setup.py.in
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/python3
-
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from distutils.command.install import install
-import os
-import platform
-from setuptools import setup, find_namespace_packages
-
-with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
- README = f.read()
-
-exe_suffix = ".exe" if platform.system() == "Windows" else ""
-
-setup(
- name="iree-jax@IREE_RELEASE_PACKAGE_SUFFIX@",
- version="@IREE_RELEASE_VERSION@",
- author="The IREE Team",
- author_email="iree-discuss@googlegroups.com",
- license="Apache",
- description="IREE JAX API",
- long_description=README,
- long_description_content_type="text/markdown",
- url="https://github.com/google/iree",
- classifiers=[
- "Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache License",
- "Operating System :: OS Independent",
- "Development Status :: 3 - Alpha",
- ],
- python_requires=">=3.7",
- packages=find_namespace_packages(include=["iree.jax"]),
- zip_safe=True,
- install_requires = [
- "jax",
- "jaxlib",
- "iree-compiler@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
- "iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
- "iree-tools-xla@IREE_RELEASE_PACKAGE_SUFFIX@==@IREE_RELEASE_VERSION@",
- ],
-)
diff --git a/bindings/python/iree/jax/version.py.in b/bindings/python/iree/jax/version.py.in
deleted file mode 100644
index ed72ac3..0000000
--- a/bindings/python/iree/jax/version.py.in
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-PACKAGE_SUFFIX = "@IREE_RELEASE_PACKAGE_SUFFIX@"
-VERSION = "@IREE_RELEASE_VERSION@"
-REVISION = "@IREE_RELEASE_REVISION@"
diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py
index f9e9545..ca72085 100644
--- a/build_tools/github_actions/build_dist.py
+++ b/build_tools/github_actions/build_dist.py
@@ -39,7 +39,6 @@
python ./main_checkout/build_tools/github_actions/build_dist.py main-dist
python ./main_checkout/build_tools/github_actions/build_dist.py py-runtime-pkg
- python ./main_checkout/build_tools/github_actions/build_dist.py py-pure-pkgs
python ./main_checkout/build_tools/github_actions/build_dist.py py-xla-compiler-tools-pkg
python ./main_checkout/build_tools/github_actions/build_dist.py py-tflite-compiler-tools-pkg
python ./main_checkout/build_tools/github_actions/build_dist.py py-tf-compiler-tools-pkg
@@ -168,48 +167,6 @@
tf.add(os.path.join(INSTALL_DIR, entry), arcname=entry, recursive=True)
-def build_py_pure_pkgs():
- """Performs a minimal build sufficient to produce pure python packages.
-
- This installs the following packages:
- - iree-install/python_packages/iree_jax
-
- Since these are pure python packages, it is expected that they will be built
- on a single examplar (i.e. Linux) distribution.
- """
- install_python_requirements()
-
- # Clean up install and build trees.
- shutil.rmtree(INSTALL_DIR, ignore_errors=True)
- remove_cmake_cache()
-
- # CMake configure.
- print("*** Configuring ***")
- subprocess.run([
- sys.executable,
- CMAKE_CI_SCRIPT,
- f"-B{BUILD_DIR}",
- f"-DCMAKE_INSTALL_PREFIX={INSTALL_DIR}",
- f"-DCMAKE_BUILD_TYPE=Release",
- f"-DIREE_BUILD_COMPILER=OFF",
- f"-DIREE_BUILD_PYTHON_BINDINGS=ON",
- f"-DIREE_BUILD_SAMPLES=OFF",
- f"-DIREE_BUILD_TESTS=OFF",
- ],
- check=True)
-
- print("*** Building ***")
- subprocess.run([
- sys.executable,
- CMAKE_CI_SCRIPT,
- "--build",
- BUILD_DIR,
- "--target",
- "install-IreePythonPackage-jax",
- ],
- check=True)
-
-
def build_py_runtime_pkg(instrumented: bool = False):
"""Builds the iree-install/python_packages/iree_runtime package.
@@ -420,8 +377,6 @@
build_py_runtime_pkg()
elif command == "instrumented-py-runtime-pkg":
build_py_runtime_pkg(instrumented=True)
-elif command == "py-pure-pkgs":
- build_py_pure_pkgs()
elif command == "py-xla-compiler-tools-pkg":
build_py_xla_compiler_tools_pkg()
elif command == "py-tflite-compiler-tools-pkg":
diff --git a/experimental/rocm/cts/CMakeLists.txt b/experimental/rocm/cts/CMakeLists.txt
index e5d7d07..e1b6e46 100644
--- a/experimental/rocm/cts/CMakeLists.txt
+++ b/experimental/rocm/cts/CMakeLists.txt
@@ -17,4 +17,13 @@
"\"PTXE\""
DEPS
experimental::rocm::registration
+ EXCLUDED_TESTS
+ # This test calls iree_hal_buffer_view_allocate_buffer, which relies on
+ # iree_hal_rocm_direct_command_buffer_update_buffer (not implemented yet).
+ "command_buffer_dispatch"
+ # Non-push descriptor sets are not implemented in the ROCm backend yet.
+ "descriptor_set"
+ # Semaphores are not implemented in the ROCm backend yet.
+ "semaphore_submission"
+ "semaphore"
)
diff --git a/iree/hal/cts/CMakeLists.txt b/iree/hal/cts/CMakeLists.txt
index dea7cfd..0216567 100644
--- a/iree/hal/cts/CMakeLists.txt
+++ b/iree/hal/cts/CMakeLists.txt
@@ -8,14 +8,15 @@
"allocator"
"buffer_mapping"
"command_buffer"
- "descriptor_set_layout"
+ "command_buffer_dispatch"
"descriptor_set"
+ "descriptor_set_layout"
"driver"
"event"
"executable_cache"
"executable_layout"
- "semaphore_submission"
"semaphore"
+ "semaphore_submission"
PARENT_SCOPE
)
@@ -23,12 +24,14 @@
# If the compiler is disabled or a HAL driver implementation is not yet
# connected to a functional compiler target, these tests can be skipped.
set(IREE_EXECUTABLE_CTS_TESTS
+ "command_buffer_dispatch"
"executable_cache"
PARENT_SCOPE
)
# List of testdata/{name}.mlir source files.
set(IREE_ALL_CTS_EXECUTABLE_SOURCES
+ "command_buffer_dispatch_test"
"executable_cache_test"
PARENT_SCOPE
)
@@ -84,6 +87,18 @@
iree_cc_library(
NAME
+ command_buffer_dispatch_test_library
+ HDRS
+ "command_buffer_dispatch_test.h"
+ DEPS
+ ::cts_test_base
+ iree::base
+ iree::hal
+ iree::testing::gtest
+)
+
+iree_cc_library(
+ NAME
descriptor_set_test_library
HDRS
"descriptor_set_test.h"
diff --git a/iree/hal/cts/command_buffer_dispatch_test.h b/iree/hal/cts/command_buffer_dispatch_test.h
new file mode 100644
index 0000000..0ec2a85
--- /dev/null
+++ b/iree/hal/cts/command_buffer_dispatch_test.h
@@ -0,0 +1,162 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_CTS_COMMAND_BUFFER_DISPATCH_TEST_H_
+#define IREE_HAL_CTS_COMMAND_BUFFER_DISPATCH_TEST_H_
+
+#include "iree/base/api.h"
+#include "iree/base/string_view.h"
+#include "iree/hal/api.h"
+#include "iree/hal/cts/cts_test_base.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace iree {
+namespace hal {
+namespace cts {
+
+// Conformance test that records a single dispatch into a command buffer,
+// submits it, and reads the result back on the host.
+//
+// The executable is loaded as "command_buffer_dispatch_test.bin" via
+// get_test_executable_data(); it is expected to be compiled from
+// testdata/command_buffer_dispatch_test.mlir, which computes `abs` of a
+// single f32 read from binding 0 and writes it to binding 1.
+class command_buffer_dispatch_test : public CtsTestBase {
+ protected:
+  // Creates the executable cache, a push-only descriptor set layout with two
+  // storage buffer bindings, an executable layout with no push constants, and
+  // prepares the `abs` executable against that layout.
+  void PrepareAbsExecutable() {
+    IREE_ASSERT_OK(iree_hal_executable_cache_create(
+        device_, iree_make_cstring_view("default"), &executable_cache_));
+
+    iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] =
+        {
+            {0, IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER},
+            {1, IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER},
+        };
+    IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create(
+        device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY,
+        IREE_ARRAYSIZE(descriptor_set_layout_bindings),
+        descriptor_set_layout_bindings, &descriptor_set_layout_));
+    IREE_ASSERT_OK(iree_hal_executable_layout_create(
+        device_, /*push_constants=*/0, /*set_layout_count=*/1,
+        &descriptor_set_layout_, &executable_layout_));
+
+    iree_hal_executable_spec_t executable_spec;
+    executable_spec.caching_mode =
+        IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA;
+    executable_spec.executable_format =
+        iree_make_cstring_view(get_test_executable_format());
+    executable_spec.executable_data = get_test_executable_data(
+        iree_make_cstring_view("command_buffer_dispatch_test.bin"));
+    executable_spec.executable_layout_count = 1;
+    executable_spec.executable_layouts = &executable_layout_;
+
+    IREE_ASSERT_OK(iree_hal_executable_cache_prepare_executable(
+        executable_cache_, &executable_spec, &executable_));
+  }
+
+  // Releases the objects created by PrepareAbsExecutable(), in reverse order
+  // of creation.
+  void CleanupExecutable() {
+    iree_hal_executable_release(executable_);
+    iree_hal_executable_layout_release(executable_layout_);
+    iree_hal_descriptor_set_layout_release(descriptor_set_layout_);
+    iree_hal_executable_cache_release(executable_cache_);
+  }
+
+  iree_hal_executable_cache_t* executable_cache_ = NULL;
+  iree_hal_descriptor_set_layout_t* descriptor_set_layout_ = NULL;
+  iree_hal_executable_layout_t* executable_layout_ = NULL;
+  iree_hal_executable_t* executable_ = NULL;
+};
+
+// Dispatches `abs` over a single f32 (-2.5f) and expects 2.5f in the output.
+TEST_P(command_buffer_dispatch_test, DispatchAbs) {
+  PrepareAbsExecutable();
+
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_,
+      IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+          IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+      &command_buffer));
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+
+  // Create input and output buffers. The input is a rank-0 f32 view
+  // initialized from input_data; the output is a host-mappable float.
+  iree_hal_buffer_view_t* input_buffer_view = NULL;
+  float input_data[1] = {-2.5f};
+  IREE_ASSERT_OK(iree_hal_buffer_view_allocate_buffer(
+      device_allocator_, /*shape=*/NULL,
+      /*shape_rank=*/0, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER,
+      iree_make_const_byte_span((void*)input_data, sizeof(input_data)),
+      &input_buffer_view));
+  iree_hal_buffer_t* output_buffer = NULL;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_,
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_MAPPING,
+      sizeof(float), iree_const_byte_span_empty(), &output_buffer));
+
+  // Both bindings cover their entire buffer from offset 0. Do not pass
+  // iree_hal_buffer_byte_offset() here: it reports where the buffer sits
+  // within its underlying allocation, not an offset into the binding.
+  iree_hal_descriptor_set_binding_t descriptor_set_bindings[] = {
+      {/*binding=*/0, iree_hal_buffer_view_buffer(input_buffer_view),
+       /*offset=*/0, iree_hal_buffer_view_byte_length(input_buffer_view)},
+      {/*binding=*/1, output_buffer, /*offset=*/0,
+       iree_hal_buffer_byte_length(output_buffer)},
+  };
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_push_descriptor_set(
+      command_buffer, executable_layout_, /*set=*/0,
+      IREE_ARRAYSIZE(descriptor_set_bindings), descriptor_set_bindings));
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
+      command_buffer, executable_, /*entry_point=*/0,
+      /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1));
+  IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
+      command_buffer,
+      /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
+          IREE_HAL_EXECUTION_STAGE_TRANSFER |
+          IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE,
+      /*target_stage_mask=*/IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE |
+          IREE_HAL_EXECUTION_STAGE_DISPATCH | IREE_HAL_EXECUTION_STAGE_TRANSFER,
+      IREE_HAL_EXECUTION_BARRIER_FLAG_NONE, /*memory_barrier_count=*/0,
+      /*memory_barriers=*/NULL,
+      /*buffer_barrier_count=*/0, /*buffer_barriers=*/NULL));
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_DISPATCH,
+                                            command_buffer));
+
+  // Read the result back on the host; the output buffer was allocated
+  // HOST_VISIBLE with MAPPING usage, which allows this.
+  float out_value = 0.0f;
+  IREE_ASSERT_OK(iree_hal_buffer_read_data(output_buffer, /*source_offset=*/0,
+                                           &out_value, sizeof(out_value)));
+  EXPECT_EQ(2.5f, out_value);
+
+  iree_hal_command_buffer_release(command_buffer);
+  iree_hal_buffer_release(output_buffer);
+  iree_hal_buffer_view_release(input_buffer_view);
+  CleanupExecutable();
+}
+
+}  // namespace cts
+}  // namespace hal
+}  // namespace iree
+
+#endif  // IREE_HAL_CTS_COMMAND_BUFFER_DISPATCH_TEST_H_
diff --git a/iree/hal/cts/command_buffer_test.h b/iree/hal/cts/command_buffer_test.h
index 4bd0d7b..1dc9f49 100644
--- a/iree/hal/cts/command_buffer_test.h
+++ b/iree/hal/cts/command_buffer_test.h
@@ -16,12 +16,6 @@
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-// TODO(scotttodd): split into several tests, for example:
-// command_buffer_recording_test (recording/lifetime)
-// command_buffer_dispatch_test
-// command_buffer_fill_test (filling buffers)
-// command_buffer_e2e_test (barriers, dispatches)
-
namespace iree {
namespace hal {
namespace cts {
diff --git a/iree/hal/cts/cts_test_template.cc.in b/iree/hal/cts/cts_test_template.cc.in
index be02a15..7783f5e 100644
--- a/iree/hal/cts/cts_test_template.cc.in
+++ b/iree/hal/cts/cts_test_template.cc.in
@@ -51,7 +51,7 @@
}
// TODO(scotttodd): error handling / reporting? This a sharp edge.
#endif
- return {NULL, 0};
+ return iree_const_byte_span_empty();
}
INSTANTIATE_TEST_SUITE_P(CTS, IREE_CTS_TEST_CLASS_NAME,
diff --git a/iree/hal/cts/descriptor_set_test.h b/iree/hal/cts/descriptor_set_test.h
index 4cb811a..323f533 100644
--- a/iree/hal/cts/descriptor_set_test.h
+++ b/iree/hal/cts/descriptor_set_test.h
@@ -19,10 +19,7 @@
class descriptor_set_test : public CtsTestBase {};
-// TODO(scotttodd): enable once any driver implements non-push descriptor sets
-// * also test with buffers in the bindings
-// * also test usage in iree_hal_command_buffer_bind_descriptor_set
-TEST_P(descriptor_set_test, DISABLED_CreateWithTwoBindings) {
+TEST_P(descriptor_set_test, CreateWithTwoBindings) {
iree_hal_descriptor_set_layout_t* descriptor_set_layout;
iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] = {
{/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER},
diff --git a/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir b/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir
new file mode 100644
index 0000000..0e0e73e
--- /dev/null
+++ b/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir
@@ -0,0 +1,37 @@
+// Bootstrapped from this source IR:
+//
+// func @abs(%input : tensor<f32>) -> (tensor<f32>) {
+// %result = math.abs %input : tensor<f32>
+// return %result : tensor<f32>
+// }
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+
+hal.executable.source public @executable {
+ hal.executable.entry_point public @abs layout(#executable_layout)
+
+ builtin.module {
+ func @abs() {
+ %c0 = arith.constant 0 : index
+
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:f32>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:f32>
+
+ %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:f32> -> tensor<f32>
+ %3 = linalg.init_tensor [] : tensor<f32>
+ %4 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2 : tensor<f32>) outs(%3 : tensor<f32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %5 = math.abs %arg0 : f32
+ linalg.yield %5 : f32
+ } -> tensor<f32>
+ flow.dispatch.tensor.store %4, %1, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
+
+ return
+ }
+ }
+}
diff --git a/iree/hal/cuda/cts/CMakeLists.txt b/iree/hal/cuda/cts/CMakeLists.txt
index 520cd96..e2d6e72 100644
--- a/iree/hal/cuda/cts/CMakeLists.txt
+++ b/iree/hal/cuda/cts/CMakeLists.txt
@@ -18,6 +18,11 @@
DEPS
iree::hal::cuda::registration
EXCLUDED_TESTS
+ # This test calls iree_hal_buffer_view_allocate_buffer, which relies on
+ # iree_hal_cuda_stream_command_buffer_update_buffer (not implemented yet).
+ "command_buffer_dispatch"
+ # Non-push descriptor sets are not implemented in the CUDA backend yet.
+ "descriptor_set"
# Semaphores are not implemented in the CUDA backend yet.
"semaphore_submission"
"semaphore"
diff --git a/iree/hal/vulkan/cts/CMakeLists.txt b/iree/hal/vulkan/cts/CMakeLists.txt
index 0ca8b65..67ef4df 100644
--- a/iree/hal/vulkan/cts/CMakeLists.txt
+++ b/iree/hal/vulkan/cts/CMakeLists.txt
@@ -17,4 +17,7 @@
"\"SPVE\""
DEPS
iree::hal::vulkan::registration
+ EXCLUDED_TESTS
+ # Non-push descriptor sets are not implemented in the Vulkan backend yet.
+ "descriptor_set"
)