# IREE PJRT Plugin
This directory contains an experimental PJRT plugin library that bridges
JAX (and, in the future, TensorFlow) to IREE.
# Developing
Support for dynamically loaded PJRT plugins is new (as of 12/21/2022) and
still has sharp edges. The procedure below is the one currently used for
development. There are multiple development workflows, listed from easiest to
hardest (the harder ones are also the most powerful).
## Install a compatible version of JAX and the IREE compiler
```shell
pip install -r requirements.txt
# Assumes the JAX repo is checked out at $JAX_REPO from
# https://github.com/google/jax (must be paired with a nightly jaxlib).
pip install -e "$JAX_REPO"
```
Verify that your JAX install is functional:
```shell
python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
```
## Install the plugin of your choice (in this example 'cpu')
```shell
pip install -v --no-deps -e python_packages/iree_cpu_plugin
```
## Verify basic functionality
```shell
JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
```
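The same check can be written as a small script, which is handy when you also want to see which backend and devices JAX picked up. This is only a sketch: the helper names `configure_platform` and `smoke_test` are made up for illustration, and the platform name `iree_cpu` is taken from the command above. The one point it relies on is that `JAX_PLATFORMS` should be set before `jax` is imported.

```python
import os


def configure_platform(platform="iree_cpu", env=None):
    """Set JAX_PLATFORMS in `env` (defaults to os.environ).

    Hypothetical helper; call it before `import jax` so the setting is
    picked up when JAX reads its configuration.
    """
    env = os.environ if env is None else env
    env["JAX_PLATFORMS"] = platform
    return env


def smoke_test():
    """Run the same array addition as the one-liner above."""
    import jax  # imported late so JAX_PLATFORMS is already in place

    a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
    print("backend:", jax.default_backend())
    print("devices:", jax.devices())
    print(a + a)
```

To use it, call `configure_platform("iree_cpu")` followed by `smoke_test()` in a fresh Python process. The exact backend and device strings printed depend on the installed plugin, so none are shown here.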
## Incrementally developing
If you did an editable install (`-e`) above, you can make changes incrementally
and rebuild the native component with no further setup:
```shell
cd python_packages/iree_cpu_plugin/build/cmake
ninja
```
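For a tighter edit/build/check loop, the rebuild and the basic functionality check can be chained from one script. The following is a rough sketch, not part of this repository: the function names are hypothetical, the build directory and platform name are copied from the steps above, and it assumes `ninja` is on your `PATH` and that you run it from this directory.

```python
import os
import subprocess
import sys
from pathlib import Path

# Build directory and platform name are taken from the steps above.
BUILD_DIR = Path("python_packages/iree_cpu_plugin/build/cmake")
CHECK = "import jax; a = jax.numpy.asarray([1, 2, 3]); print(a + a)"


def rebuild_and_check_steps(build_dir=BUILD_DIR, platform="iree_cpu"):
    """Return (argv, cwd, env_overrides) steps without running them."""
    return [
        # Step 1: rebuild the native component in the CMake build directory.
        (["ninja"], build_dir, {}),
        # Step 2: re-run a small JAX computation on the plugin platform.
        ([sys.executable, "-c", CHECK], None, {"JAX_PLATFORMS": platform}),
    ]


def run_steps(steps):
    """Run each step in order, stopping at the first failure."""
    for argv, cwd, overrides in steps:
        env = {**os.environ, **overrides}
        subprocess.run(argv, cwd=cwd, env=env, check=True)
```

Calling `run_steps(rebuild_and_check_steps())` rebuilds and then runs the check in a fresh interpreter, so the newly built plugin is what gets loaded.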
## Running the JAX test suite
The JAX test suite can be run with pytest. We recommend using `pytest-xdist`
as it spawns tests in workers which can be restarted in the event of individual
test case crashes.
Setup:
```shell
# Install pytest
pip install pytest pytest-xdist
# Install the ctstools package from this repo (`-e` makes it editable).
pip install -e ctstools
```
Example of running tests:
```shell
JAX_PLATFORMS=iree_cuda pytest -n4 --max-worker-restart=9999 \
-p openxla_pjrt_artifacts --openxla-pjrt-artifact-dir=/tmp/foobar \
~/src/jax/tests/nn_test.py
```
Note that you will typically want a small number of workers (`-n4` above) for
CUDA; CPU can tolerate a larger number.
The plugin `openxla_pjrt_artifacts` is in the `ctstools` directory and
performs additional manipulation of the environment in order to save
compilation artifacts, reproducers, etc.
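If you run the suite against several platforms or test files, it may help to build the command line programmatically. The sketch below only reassembles the flags from the example above. The helper names are hypothetical, and the default worker count and artifact directory are the example values, not recommendations.

```python
import os
import shlex


def build_test_command(test_path, workers=4, artifact_dir="/tmp/foobar"):
    """Assemble the pytest argv used in the example above."""
    return [
        "pytest",
        f"-n{workers}",               # pytest-xdist worker count
        "--max-worker-restart=9999",  # keep going after worker crashes
        "-p", "openxla_pjrt_artifacts",
        f"--openxla-pjrt-artifact-dir={artifact_dir}",
        os.path.expanduser(test_path),
    ]


def build_test_env(platform, base_env=None):
    """Copy the environment and select the JAX platform for the run."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_PLATFORMS"] = platform
    return env


cmd = build_test_command("~/src/jax/tests/nn_test.py", workers=4)
print(shlex.join(cmd))
# To actually run the tests (requires `import subprocess`):
#   subprocess.run(cmd, env=build_test_env("iree_cuda"), check=False)
```

Keeping the environment separate from the argv makes it easy to reuse one command with a different `JAX_PLATFORMS` value and worker count, for example more workers for `iree_cpu`.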
## Contacts
* [GitHub issues](https://github.com/openxla/openxla-pjrt-plugin/issues):
Feature requests, bugs, and other work tracking
* [OpenXLA Discord](https://discord.gg/pvuUmVQa): Daily development discussions
with the core team and collaborators
## License
OpenXLA PJRT plugin is licensed under the terms of the Apache 2.0 License with
LLVM Exceptions. See [LICENSE](LICENSE) for more information.