
IREE PJRT Plugin

This directory contains an experimental PJRT plugin library that bridges JAX (and, in the future, TensorFlow) to IREE.

Developing

Support for dynamically loaded PJRT plugins is brand new as of 12/21/2022, and there are still sharp edges. The following procedure is currently used for development.

There are multiple development workflows, ranked from easiest to hardest (but most powerful).

Install a compatible version of JAX and the IREE compiler

pip install -r requirements.txt

# Assume that you have the JAX repo checked out at JAX_REPO from
# https://github.com/google/jax (must be paired with nightly jaxlib).
pip install -e "$JAX_REPO"
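Because the editable JAX checkout must be paired with a nightly jaxlib, it can help to confirm which versions actually ended up installed. The sketch below uses the standard-library importlib.metadata module; jax and jaxlib are the usual PyPI distribution names, and the helper name installed_versions is hypothetical.

```python
import importlib.metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # jaxlib should be a nightly build that matches the JAX checkout.
    for name, version in installed_versions(["jax", "jaxlib"]).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```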

Verify that your JAX install is functional:

python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"

Install the plugin of your choice (in this example, ‘cpu’):

pip install -v --no-deps -e python_packages/iree_cpu_plugin

Verify basic functionality

JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
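The same check can be scripted so that it runs under a chosen platform from Python. The sketch below (the helper name run_on_platform is hypothetical) starts a fresh interpreter with JAX_PLATFORMS set, since the variable is picked up when JAX starts; it reuses the one-liner above as its default payload.

```python
import os
import subprocess
import sys

# The same check as the one-liner above.
SMOKE_TEST = (
    "import jax; "
    "a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); "
    "print(a + a)"
)


def run_on_platform(platform, code=SMOKE_TEST):
    """Run `code` in a new interpreter with JAX_PLATFORMS=platform.

    Returns (returncode, stdout, stderr).
    """
    env = dict(os.environ, JAX_PLATFORMS=platform)
    proc = subprocess.run(
        [sys.executable, "-c", code],  # same interpreter as the caller
        env=env,
        capture_output=True,
        text=True,
    )
    return proc.returncode, proc.stdout, proc.stderr


if __name__ == "__main__":
    rc, out, err = run_on_platform("iree_cpu")
    print(out if rc == 0 else err)
```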

Incremental development

If you did an editable install (-e) above, you can make changes incrementally and rebuild the native component without any further installation steps:

cd python_packages/iree_cpu_plugin/build/cmake
ninja

Running the Jax test suite

The JAX test suite can be run with pytest. We recommend pytest-xdist because it runs tests in worker processes, which can be restarted if an individual test case crashes.
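The benefit of worker processes comes from process isolation: a hard crash in native plugin code kills only the process it happens in. The minimal sketch below illustrates that idea with the standard library alone; it is not how pytest-xdist is implemented, and run_isolated is a hypothetical name.

```python
import subprocess
import sys


def run_isolated(code):
    """Run `code` in a child interpreter and classify the outcome.

    A native crash only terminates the child, so the parent survives to
    record it and move on to the next test.
    """
    proc = subprocess.run([sys.executable, "-c", code], capture_output=True)
    if proc.returncode == 0:
        return "passed"
    # On POSIX, a negative return code means the child was killed by a signal.
    if proc.returncode < 0:
        return "crashed"
    return "failed"


if __name__ == "__main__":
    print(run_isolated("pass"))
    print(run_isolated("import os; os.abort()"))  # simulated native crash
```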

Setup:

# Install pytest
pip install pytest pytest-xdist

# Install the ctstools package from this repo (`-e` makes it editable).
pip install -e ctstools

Example of running tests:

JAX_PLATFORMS=iree_cuda pytest -n4 --max-worker-restart=9999 \
  -p openxla_pjrt_artifacts --openxla-pjrt-artifact-dir=/tmp/foobar \
  ~/src/jax/tests/nn_test.py

Note that you will typically want a small number of workers (-n4 above) for CUDA, whereas CPU can tolerate a larger number.
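If you run the suite repeatedly, a small helper can assemble the command above. This is a hypothetical sketch: the flags are copied from the example, the default of 4 workers for iree_cuda follows the note above, and using one worker per CPU core for other platforms is an assumption rather than a documented recommendation.

```python
import os
import shlex


def jax_test_command(platform, test_path, artifact_dir, workers=None):
    """Return (argv, env) for running JAX tests against an IREE platform."""
    if workers is None:
        # Few workers for CUDA; assume one per core is tolerable otherwise.
        workers = 4 if platform == "iree_cuda" else (os.cpu_count() or 1)
    argv = [
        "pytest",
        f"-n{workers}",
        "--max-worker-restart=9999",
        "-p", "openxla_pjrt_artifacts",
        f"--openxla-pjrt-artifact-dir={artifact_dir}",
        test_path,
    ]
    env = dict(os.environ, JAX_PLATFORMS=platform)
    return argv, env


if __name__ == "__main__":
    argv, _ = jax_test_command(
        "iree_cuda", os.path.expanduser("~/src/jax/tests/nn_test.py"), "/tmp/foobar"
    )
    print(shlex.join(argv))
```

The returned pair can be passed directly to subprocess.run(argv, env=env).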

The openxla_pjrt_artifacts pytest plugin lives in the ctstools directory. It performs additional manipulation of the environment in order to save compilation artifacts, reproducers, etc. into the directory passed via --openxla-pjrt-artifact-dir.
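As a hypothetical sketch of how such a pytest plugin can be structured, the code below registers the --openxla-pjrt-artifact-dir flag, creates one subdirectory per pytest-xdist worker (using the PYTEST_XDIST_WORKER variable that pytest-xdist sets in each worker process), and exports that path through an environment variable. The variable name EXAMPLE_PJRT_ARTIFACT_DIR is invented for illustration; the real ctstools plugin may work quite differently.

```python
import os
from pathlib import Path

# Hypothetical variable name; not taken from the real ctstools plugin.
ARTIFACT_DIR_ENV = "EXAMPLE_PJRT_ARTIFACT_DIR"


def pytest_addoption(parser):
    """Register the command-line flag used in the example above."""
    parser.addoption(
        "--openxla-pjrt-artifact-dir",
        default=None,
        help="Directory in which to save compilation artifacts and reproducers.",
    )


def pytest_configure(config):
    """Create a per-worker artifact directory and export it to the environment."""
    root = config.getoption("--openxla-pjrt-artifact-dir")
    if not root:
        return
    # Set by pytest-xdist in workers (e.g. "gw0"); absent without -n.
    worker = os.environ.get("PYTEST_XDIST_WORKER", "main")
    worker_dir = Path(root) / worker
    worker_dir.mkdir(parents=True, exist_ok=True)
    os.environ[ARTIFACT_DIR_ENV] = str(worker_dir)
```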

Contacts

  • GitHub issues: Feature requests, bugs, and other work tracking
  • OpenXLA discord: Daily development discussions with the core team and collaborators

License

The OpenXLA PJRT plugin is licensed under the terms of the Apache 2.0 License with LLVM Exceptions. See LICENSE for more information.