# IREE PJRT Plugin

This directory contains an experimental PJRT plugin library that can bridge
JAX (and, in the future, TensorFlow) to IREE.

# Developing

Support for dynamically loaded PJRT plugins was brand new as of 12/21/2022, and
there are still sharp edges. The following procedure is used for development.

There are multiple development workflows, ranked from easiest to hardest (but
most powerful).

## Install a compatible version of JAX and the IREE compiler

```shell
pip install -r requirements.txt

# A recent version of JAX is highly recommended, e.g. 0.4.36.
pip install jax==0.4.36
```
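
To confirm which JAX version is actually active in your environment, a small standard-library check like the one below can help. It is only a sketch: it reads installed package metadata with `importlib.metadata`, and the `0.4.36` threshold mirrors the example version above rather than a documented minimum. The helper names (`parse_release`, `installed_version`, `check_jax`) are hypothetical.

```python
from __future__ import annotations

from importlib import metadata


def parse_release(version: str) -> tuple[int, ...]:
    """Parse the leading numeric part of a version, e.g. "0.4.36" -> (0, 4, 36).

    Simplified on purpose: parsing stops at the first component that is not
    purely numeric, so "0.4.36rc1" -> (0, 4). Use `packaging.version` for
    full PEP 440 handling.
    """
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def installed_version(dist: str) -> str | None:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_jax(minimum: str = "0.4.36") -> str:
    """Return a short status message about the installed jax version."""
    found = installed_version("jax")
    if found is None:
        return "jax is not installed"
    if parse_release(found) < parse_release(minimum):
        return f"jax {found} is older than the suggested {minimum}"
    return f"jax {found} is OK"
```

For example, call `print(check_jax())` after running the installation commands above.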

Verify that your JAX installation works:

```shell
python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a)"
```

## Install the plugin of your choice (in this example, `cpu`)

```shell
pip install -v --no-deps -e python_packages/iree_cpu_plugin
```

## Verify basic functionality

```shell
JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a)"
```
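
If you want to run the same smoke test from a Python script (for example, to repeat it after each rebuild), the sketch below launches it in a fresh interpreter with `JAX_PLATFORMS` set. Running it in a subprocess ensures the platform selection is in place when `jax` is first imported. `run_smoke_test` is a hypothetical helper, not part of the plugin.

```python
from __future__ import annotations

import os
import subprocess
import sys

# The same check as the shell one-liner: add a small array to itself.
SMOKE_TEST = (
    "import jax; "
    "a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); "
    "print(a + a)"
)


def run_smoke_test(platform: str | None = "iree_cpu",
                   extra_env: dict[str, str] | None = None) -> subprocess.CompletedProcess:
    """Run the smoke test in a fresh Python interpreter.

    `platform` is written to JAX_PLATFORMS; passing None leaves any
    inherited JAX_PLATFORMS value untouched. `extra_env` entries are
    added on top of the current environment.
    """
    env = dict(os.environ)
    if platform is not None:
        env["JAX_PLATFORMS"] = platform
    if extra_env:
        env.update(extra_env)
    return subprocess.run(
        [sys.executable, "-c", SMOKE_TEST],
        env=env,
        capture_output=True,
        text=True,
    )
```

For example, `result = run_smoke_test("iree_cpu")` followed by `print(result.stdout or result.stderr)` shows either the doubled array or the error that stopped the plugin from loading; a non-zero `result.returncode` means the check failed.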

## Advanced settings

To pass additional compile options to IREE during JIT compilation, use the
`IREE_PJRT_IREE_COMPILER_OPTIONS` environment variable. Set it to a
space-delimited list of flags, written as you would pass them to the
`iree-compile` command-line tool.

For example:

```shell
export IREE_PJRT_IREE_COMPILER_OPTIONS=--iree-scheduling-dump-statistics-format=csv
JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a)"
```
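
When driving JAX from a script, it can be convenient to build this environment programmatically. The sketch below assumes only what is stated above: the variable holds space-delimited flags. Because it is not documented here whether quoting is supported, the hypothetical helper `env_with_compiler_options` conservatively rejects flags that contain whitespace.

```python
from __future__ import annotations

import os

OPTIONS_VAR = "IREE_PJRT_IREE_COMPILER_OPTIONS"


def env_with_compiler_options(flags: list[str],
                              platform: str = "iree_cpu",
                              base: dict[str, str] | None = None) -> dict[str, str]:
    """Return a copy of `base` (default: os.environ) with IREE compile flags set.

    The flags are joined with single spaces, matching the space-delimited
    format of the variable. Flags containing whitespace are rejected because
    they could not be represented unambiguously in that format.
    """
    for flag in flags:
        if any(ch.isspace() for ch in flag):
            raise ValueError(f"flag contains whitespace: {flag!r}")
    env = dict(os.environ if base is None else base)
    env[OPTIONS_VAR] = " ".join(flags)
    env["JAX_PLATFORMS"] = platform
    return env
```

Pass the returned dictionary as `env=` to `subprocess.run` when launching a Python process that imports `jax`, so the options are already set when the plugin starts compiling.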

## Incrementally developing

If you did an editable install (`-e`) above, you can make incremental changes
and rebuild the native component without reinstalling the package:

```shell
cd python_packages/iree_cpu_plugin/build/cmake
ninja
```
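
The same rebuild step can be scripted. This sketch assumes, as the example above implies, that the editable install created `python_packages/iree_cpu_plugin/build/cmake`; it checks for that directory and for `ninja` before invoking the build, so a missing prerequisite produces a readable error. `rebuild` is a hypothetical helper.

```python
from __future__ import annotations

import os
import shutil
import subprocess

# Build directory from the example above; adjust it for other plugins.
BUILD_DIR = os.path.join("python_packages", "iree_cpu_plugin", "build", "cmake")


def rebuild(build_dir: str = BUILD_DIR) -> None:
    """Re-run ninja in an existing CMake build directory.

    Raises FileNotFoundError if the build directory is missing (for example,
    if the editable install has not been done yet) or if ninja is not on
    PATH. Raises subprocess.CalledProcessError if the build itself fails.
    """
    if not os.path.isdir(build_dir):
        raise FileNotFoundError(
            f"{build_dir} does not exist; did you run the editable (-e) install?")
    if shutil.which("ninja") is None:
        raise FileNotFoundError("ninja was not found on PATH")
    subprocess.run(["ninja"], cwd=build_dir, check=True)
```

Run it from the directory containing `python_packages`, or pass an absolute `build_dir`.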

## Running the JAX test suite

The JAX test suite can be run with pytest. We recommend using `pytest-xdist`
because it runs tests in worker processes that can be restarted if an
individual test case crashes.

Setup:

```shell
# Install pytest and pytest-xdist
pip install pytest pytest-xdist

# Install the ctstools package from this repo (`-e` makes it editable).
pip install -e ctstools
```

Example of running tests:

```shell
JAX_PLATFORMS=iree_cuda pytest -n4 --max-worker-restart=9999 \
  -p openxla_pjrt_artifacts --openxla-pjrt-artifact-dir=/tmp/foobar \
  ~/src/jax/tests/nn_test.py
```

Here `~/src/jax` is a local checkout of the JAX repository; adjust the path to
match yours. You will typically want a small number of workers (`-n4` above)
for CUDA, while CPU can tolerate a larger number.

The `openxla_pjrt_artifacts` pytest plugin lives in the `ctstools` directory and
additionally manipulates the environment in order to save compilation
artifacts, reproducers, etc.
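
If you run the suite often, a small wrapper can assemble the command shown above. This sketch reuses only the flags from that example; the default worker counts (4 for CUDA platforms, the machine's CPU count otherwise) are an assumption based on the guidance above, not a project recommendation. `jax_test_command` is a hypothetical helper.

```python
from __future__ import annotations

import os


def jax_test_command(test_files: list[str],
                     platform: str = "iree_cpu",
                     artifact_dir: str = "/tmp/foobar",
                     workers: int | None = None) -> tuple[list[str], dict[str, str]]:
    """Return (argv, env) for running JAX tests against an IREE PJRT plugin.

    If `workers` is None, use 4 workers for CUDA platforms and the CPU
    count (falling back to 4) for everything else.
    """
    if workers is None:
        workers = 4 if "cuda" in platform else (os.cpu_count() or 4)
    argv = [
        "pytest",
        f"-n{workers}",
        "--max-worker-restart=9999",
        "-p", "openxla_pjrt_artifacts",
        f"--openxla-pjrt-artifact-dir={artifact_dir}",
        *test_files,
    ]
    # Copy the current environment and select the plugin platform.
    env = dict(os.environ, JAX_PLATFORMS=platform)
    return argv, env
```

For example, `argv, env = jax_test_command([os.path.expanduser("~/src/jax/tests/nn_test.py")], platform="iree_cuda")` followed by `subprocess.run(argv, env=env)`. The path is expanded in Python because, unlike the shell, `subprocess.run` does not expand `~`.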

## Communication channels

* Please submit feature requests and bug reports about the plugin in [GitHub Issues](https://github.com/iree-org/iree/issues).
* Discuss the development of the plugin in the `#jax` or `#pjrt-plugin` channels of the [IREE Discord server](https://discord.gg/wEWh6Z9nMU).
* Check the [OpenXLA/XLA](https://github.com/openxla/xla) repo and [its communication channels](https://github.com/openxla/community?tab=readme-ov-file#communication-channels) for PJRT APIs and clients.

## License

The IREE PJRT plugin is licensed under the terms of the Apache 2.0 License with
LLVM Exceptions. See [LICENSE](../../LICENSE) for more information.

The [PJRT C API](./third_party/pjrt_c_api) comes from
[OpenXLA/XLA](https://github.com/openxla/xla) and is licensed under
the Apache 2.0 License. See its own [LICENSE](./third_party/pjrt_c_api/LICENSE) for more information.