| # OpenXLA PJRT Plugin |
| |
| This repository contains an experimental PJRT plugin library that bridges |
| JAX (and, in the future, TensorFlow) to OpenXLA/IREE. |
| |
| # Developing |
| |
| Support for dynamically loaded PJRT plugins is brand new as of 12/21/2022, and |
| there are still sharp edges. The procedures below are what is currently used |
| for development. |
| |
| There are multiple development workflows, listed from easiest to hardest (the |
| hardest being the most powerful). |
| |
| ## Setup options |
| |
| The steps below presume that you have a compatible JAX/jaxlib installed. |
| Because PJRT plugin support is moving fast, released versions are rarely |
| compatible. **See ["Building Jax from Source"](#building-jax-from-source) |
| below.** |
| |
| If you are building without CUDA, you may still need to install IREE's CUDA deps |
| for the `bazel` build below: |
| |
| ```shell |
| export IREE_CUDA_DEPS_DIR=${HOME?}/.iree_cuda_deps |
| ../iree/build_tools/docker/context/fetch_cuda_deps.sh ${IREE_CUDA_DEPS_DIR?} |
| ``` |
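If you drive setup from Python, the same step can be made idempotent. The sketch below is a hypothetical helper, not part of this repository. It reuses the script path and default directory from the commands above, and it skips the fetch when the target directory already exists and is non-empty. That check is a heuristic, not a real completeness check.

```python
import os
import subprocess
from pathlib import Path

# Same script as in the shell commands above (assumes ../iree is checked out).
FETCH_SCRIPT = Path("../iree/build_tools/docker/context/fetch_cuda_deps.sh")


def cuda_deps_dir(environ=None):
    """Return IREE_CUDA_DEPS_DIR if set, else $HOME/.iree_cuda_deps."""
    environ = os.environ if environ is None else environ
    home = environ.get("HOME", str(Path.home()))
    default = os.path.join(home, ".iree_cuda_deps")
    return Path(environ.get("IREE_CUDA_DEPS_DIR", default))


def needs_fetch(deps_dir):
    """Heuristic: fetch only when the directory is missing or empty."""
    return not deps_dir.is_dir() or not any(deps_dir.iterdir())


def fetch_cuda_deps():
    """Run the fetch script if needed, then export IREE_CUDA_DEPS_DIR."""
    deps_dir = cuda_deps_dir()
    if needs_fetch(deps_dir):
        subprocess.run([str(FETCH_SCRIPT), str(deps_dir)], check=True)
    # Like `export` in the shell: child processes such as `bazel` inherit it.
    os.environ["IREE_CUDA_DEPS_DIR"] = str(deps_dir)
    return deps_dir
```

Call `fetch_cuda_deps()` from the `openxla-pjrt-plugin` checkout before launching `bazel` from the same Python process.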
| |
| ### Option 0: Pip install (non-dev) |
| |
| ```shell |
| pip install jax openxla_pjrt_plugin_cpu \ |
| -f https://openxla.github.io/openxla-pjrt-plugin/pip-release-links.html \ |
| -f https://openxla.github.io/iree/pip-release-links.html |
| ``` |
| |
| You can then verify the installation with: |
| |
| ```shell |
| $ python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);" |
| ``` |
| ``` |
| Platform 'iree_cpu' is experimental and not all JAX functionality may be correctly supported! |
| [IREE-PJRT] DEBUG: Using IREE compiler binary: /tmp/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so |
| [IREE-PJRT] DEBUG: Compiler Version: 20230813.612 @ b56ac23bd85f0b9f4a9939c9e87fe83e629f8566 (API version 1.4) |
| [IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var) |
| [IREE-PJRT] DEBUG: CPU driver created |
| [ 2 4 6 8 10 12 14 16 18] |
| ``` |
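To script the same check (for example in CI), a small Python sketch can first confirm that the packages from the `pip install` line are present. It then runs the one-liner in a fresh interpreter. The sketch pins `JAX_PLATFORMS=iree_cpu`, the variable used in the options below, so the run should fail loudly instead of quietly using a different backend. The helper names are illustrative, and the distribution names are copied from the pip command above.

```python
import importlib.metadata
import os
import subprocess
import sys

# Distribution names taken from the `pip install` command above.
PACKAGES = ["jax", "openxla_pjrt_plugin_cpu"]

# Same computation as the one-liner above.
SMOKE_TEST = (
    "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a)"
)


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def run_smoke_test(platform="iree_cpu"):
    """Run SMOKE_TEST in a fresh interpreter with JAX_PLATFORMS pinned."""
    env = dict(os.environ, JAX_PLATFORMS=platform)
    return subprocess.run(
        [sys.executable, "-c", SMOKE_TEST],
        env=env,
        capture_output=True,
        text=True,
    )


def main():
    missing = [n for n, v in installed_versions(PACKAGES).items() if v is None]
    if missing:
        print("Not installed:", ", ".join(missing))
        return
    result = run_smoke_test()
    print(result.stdout)
    # The plugin's diagnostic lines may go to either stream; show both.
    print(result.stderr, file=sys.stderr)


if __name__ == "__main__":
    main()
```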
| |
| ### Option 1: Synchronize to a nightly IREE release |
| |
| ```shell |
| python ./sync_deps.py |
| python -m pip install -U -r requirements.txt |
| python ./configure.py --cc=clang --cxx=clang++ --cuda-sdk-dir=$CUDA_SDK_DIR |
| |
| # configure.py generates .env and .env.sh files with key setup vars. |
| # Source .env.sh to use them in an interactive shell. |
| source .env.sh |
| |
| # Build. |
| bazel build iree/integrations/pjrt/... |
| |
| # Run a sample. |
| JAX_PLATFORMS=iree_cpu python test/test_simple.py |
| JAX_PLATFORMS=iree_cuda python test/test_simple.py |
| # When multiple CUDA devices are installed, pick one by setting CUDA_VISIBLE_DEVICES=<n>. |
| CUDA_VISIBLE_DEVICES=0 JAX_PLATFORMS=iree_cuda python test/test_simple.py |
| |
| ``` |
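To run the same samples from Python instead of an interactive shell, you can load the generated `.env` file into a subprocess environment. This sketch rests on an assumption: it treats `.env` as simple `KEY=VALUE` lines, optionally quoted, with `#` comments. Check the file that `configure.py` actually writes before relying on it. The helper names are hypothetical.

```python
import os
import subprocess
import sys


def parse_env_lines(lines):
    """Parse KEY=VALUE lines (assumed .env format), skipping blanks and comments."""
    env = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        value = value.strip()
        # Drop one layer of matching quotes, if present.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        env[key.strip()] = value
    return env


def sample_env(platform, cuda_device=None, env_file=".env"):
    """Build the environment for test/test_simple.py on one PJRT platform."""
    env = dict(os.environ)
    with open(env_file) as f:
        env.update(parse_env_lines(f))
    env["JAX_PLATFORMS"] = platform
    if cuda_device is not None:
        # Select a single GPU, as CUDA_VISIBLE_DEVICES=<n> does above.
        env["CUDA_VISIBLE_DEVICES"] = str(cuda_device)
    return env


def run_sample(platform, cuda_device=None):
    subprocess.run(
        [sys.executable, "test/test_simple.py"],
        env=sample_env(platform, cuda_device),
        check=True,
    )
```

For example, `run_sample("iree_cpu")` or `run_sample("iree_cuda", cuda_device=0)` from the repository root.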
| |
| ### Option 2: Set up for a full at-head dev rig |
| |
| ```shell |
| mkdir openxla |
| cd openxla |
| python -m venv .env |
| source .env/bin/activate || echo "Could not activate venv" >&2 |
| |
| pip install git+https://github.com/openxla/openxla-devtools.git |
| openxla-workspace init |
| openxla-workspace checkout --sync openxla-pjrt-plugin |
| |
| cd jax |
| pip install build numpy wheel |
| python build/build.py \ |
| --bazel_options=--override_repository=xla=$PWD/../xla \ |
| && pip3 install dist/*.whl --force-reinstall |
| pip install -e . |
| |
| cd ../iree |
| cmake -GNinja -B ../iree-build/ -S . \ |
| -DCMAKE_BUILD_TYPE=RelWithDebInfo \ |
| -DIREE_ENABLE_ASSERTIONS=ON \ |
| -DCMAKE_C_COMPILER=clang \ |
| -DCMAKE_CXX_COMPILER=clang++ \ |
| -DIREE_ENABLE_LLD=ON -DIREE_ENABLE_CCACHE=ON |
| cd ../iree-build |
| ninja libIREECompiler.so |
| export DYLIB_PATH=$PWD |
| |
| cd ../openxla-pjrt-plugin |
| python ./configure.py --cc=clang --cxx=clang++ --iree-compiler-dylib=$DYLIB_PATH/lib/libIREECompiler.so |
| source .env.sh |
| bazel build iree/integrations/pjrt/cpu/... |
| |
| # Run a simple smoke test. |
| JAX_PLATFORMS=iree_cpu python test/test_simple.py |
| ``` |
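After the steps above, you can confirm that the workspace is complete by checking for the sibling checkouts and for the compiler library passed to `configure.py`. This sketch assumes that `openxla-workspace` lays the repositories out as siblings under `openxla/`, as the `cd` commands imply. `missing_paths` is a hypothetical helper.

```python
from pathlib import Path

# Directories the steps above `cd` into, relative to the `openxla` root.
EXPECTED_DIRS = ["jax", "xla", "iree", "iree-build", "openxla-pjrt-plugin"]
# Library passed to configure.py via --iree-compiler-dylib.
COMPILER_DYLIB = Path("iree-build") / "lib" / "libIREECompiler.so"


def missing_paths(root):
    """Return the expected workspace paths that do not exist under root."""
    root = Path(root)
    wanted = [root / d for d in EXPECTED_DIRS] + [root / COMPILER_DYLIB]
    return [str(p) for p in wanted if not p.exists()]
```

Run it from the `openxla` directory, for example `print(missing_paths(".") or "workspace looks complete")`.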
| |
| ## Building JAX from Source |
| |
| Install JAX from source as an editable Python package: |
| |
| ```shell |
| # Starting in the openxla-pjrt-plugin repo, download JAX and sync to a |
| # compatible commit. |
| python ./sync_deps.py |
| python -m pip install -e ../jax |
| ``` |
| |
| Build a compatible jaxlib: |
| |
| ```shell |
| cd ../jax |
| # NOTE: Try running `bazel clean --expunge` if you run into undeclared inclusion |
| # error(s). |
| python build/build.py \ |
| --bazel_options=--override_repository=xla=$PWD/../xla |
| # Install the version of jaxlib you just built. |
| python -m pip install dist/*.whl --force-reinstall |
| ``` |
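`dist/*.whl` expands to every wheel in `dist/`, so stale jaxlib wheels from earlier builds can end up in the same `pip install` call. Here is a sketch that installs only the newest jaxlib wheel. It assumes the wheels are named `jaxlib-*.whl` and that the most recently modified file is the one you just built. If your build produces other wheels you also need, adjust `pattern`.

```python
import glob
import os
import subprocess
import sys


def newest_wheel(dist_dir="dist", pattern="jaxlib-*.whl"):
    """Return the most recently modified wheel matching pattern, or None."""
    wheels = glob.glob(os.path.join(dist_dir, pattern))
    return max(wheels, key=os.path.getmtime) if wheels else None


def install_command(wheel):
    """pip command equivalent to the force-reinstall step above, for one wheel."""
    return [sys.executable, "-m", "pip", "install", wheel, "--force-reinstall"]


def install_newest_jaxlib(dist_dir="dist"):
    wheel = newest_wheel(dist_dir)
    if wheel is None:
        raise FileNotFoundError(f"no jaxlib wheel found in {dist_dir!r}")
    subprocess.run(install_command(wheel), check=True)
```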
| |
| ## Generating runtime traces |
| |
| The plugins can be built with tracing enabled by adding the bazel build flag |
| `--iree_enable_runtime_tracing`. With this flag, instrumentation is sent to a |
| profiler if one is running. It can be useful to set the environment variable |
| `TRACY_NO_EXIT=1` to block termination of one-shot programs that would |
| otherwise exit before all events have been streamed. |
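The build flag and the environment variable work together as in this sketch. It builds the CPU plugin with tracing, then runs a script with `TRACY_NO_EXIT=1`. The target and script path are example values borrowed from the setup options above. Start the profiler before calling `build_and_trace`, which is a hypothetical helper.

```python
import os
import subprocess
import sys


def traced_build_command(targets):
    """bazel build invocation with runtime tracing enabled."""
    return ["bazel", "build", "--iree_enable_runtime_tracing", *targets]


def traced_run_env(base=None):
    """Copy of the environment with TRACY_NO_EXIT=1 set."""
    env = dict(os.environ if base is None else base)
    # Keep short programs alive until all events have been streamed.
    env["TRACY_NO_EXIT"] = "1"
    return env


def build_and_trace(
    targets=("iree/integrations/pjrt/cpu/...",),
    script="test/test_simple.py",
    platform="iree_cpu",
):
    subprocess.run(traced_build_command(targets), check=True)
    env = traced_run_env()
    env["JAX_PLATFORMS"] = platform
    subprocess.run([sys.executable, script], env=env, check=True)
```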
| |
| ## Generating compile_commands.json |
| |
| `compile_commands.json` can be generated by the following command. |
| |
| ```shell |
| bazel run @hedron_compile_commands//:refresh_all |
| ``` |
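To check whether a particular source file was picked up, you can inspect the generated database. This sketch relies only on the standard compilation database format, a JSON list of entries with `directory` and `file` keys. The helper names are illustrative.

```python
import json
import os


def load_compile_db(path="compile_commands.json"):
    with open(path) as f:
        return json.load(f)


def compiled_files(entries):
    """Normalized absolute paths of every file listed in the database."""
    # os.path.join keeps `file` unchanged when it is already absolute.
    return {
        os.path.normpath(os.path.join(e["directory"], e["file"])) for e in entries
    }
```

For example, test whether a plugin source file's absolute path is in `compiled_files(load_compile_db())`.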
| |
| ## ASAN |
| |
| Developing with ASAN is recommended, but it requires some special steps: the |
| plugin must be allowed to link with undefined symbols, and the ASAN runtime |
| library must be loaded into the process. |
| |
| * Remove `"-Wl,--no-undefined"` from `build_defs.bzl`. |
| * Set env var `LD_PRELOAD=$(clang-12 -print-file-name=libclang_rt.asan-x86_64.so)` |
| (assuming you compile with `clang-12`; see `configured.bazelrc` in the IREE repo). |
| * Set env var `ASAN_OPTIONS=detect_leaks=0` (Python allocates many objects that |
| it never frees. TODO: Make this more fine-grained so we can detect leaks in |
| plugin code). |
| * Build with `--config=asan`. |
| |
| This process could be made more systematic, but it should work as is. |
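The environment steps can be collected into a small helper, sketched below. It asks clang for the runtime path, just as the `$(...)` expression does, and assumes `clang-12` as in the list. The bazel target is an example borrowed from Option 2, and the helper names are hypothetical. Remember that the `build_defs.bzl` edit must still be made by hand.

```python
import os
import subprocess

ASAN_RUNTIME = "libclang_rt.asan-x86_64.so"

# Example ASAN build; the target is borrowed from Option 2.
ASAN_BUILD = ["bazel", "build", "--config=asan", "iree/integrations/pjrt/cpu/..."]


def asan_runtime_path(clang="clang-12"):
    """Ask clang where its ASAN runtime lives (same as the $(...) above)."""
    out = subprocess.run(
        [clang, f"-print-file-name={ASAN_RUNTIME}"],
        capture_output=True,
        text=True,
        check=True,
    )
    path = out.stdout.strip()
    # clang prints the bare name back when it cannot find the library.
    if not os.path.isabs(path):
        raise FileNotFoundError(f"{clang} could not locate {ASAN_RUNTIME}")
    return path


def asan_env(runtime_path, base=None):
    """Environment for running the ASAN-built plugin under Python."""
    env = dict(os.environ if base is None else base)
    env["LD_PRELOAD"] = runtime_path
    # Python never frees some allocations, so leak detection is disabled.
    env["ASAN_OPTIONS"] = "detect_leaks=0"
    return env
```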
| |
| ## Running the JAX test suite |
| |
| The JAX test suite can be run with pytest. We recommend using `pytest-xdist` |
| because it runs tests in worker processes that can be restarted if an |
| individual test case crashes. |
| |
| Setup: |
| |
| ```shell |
| # Install pytest |
| pip install pytest pytest-xdist |
| |
| # Install the ctstools package from this repo (`-e` makes it editable). |
| pip install -e ctstools |
| ``` |
| |
| Example of running tests: |
| |
| ```shell |
| JAX_PLATFORMS=iree_cuda pytest -n4 --max-worker-restart=9999 \ |
| -p openxla_pjrt_artifacts --openxla-pjrt-artifact-dir=/tmp/foobar \ |
| ~/src/jax/tests/nn_test.py |
| ``` |
| |
| You will typically want a small number of workers (`-n4` above) for CUDA; |
| CPU runs can tolerate a larger number. |
| |
| The `openxla_pjrt_artifacts` pytest plugin lives in the `ctstools` directory |
| and adjusts the environment so that compilation artifacts, reproducers, etc. |
| are saved to the directory given by `--openxla-pjrt-artifact-dir`. |
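The worker-count advice can be folded into a command builder, sketched below. It reproduces the flags from the example above. The default of one worker per CPU core for CPU runs is an assumption, not a tested recommendation, and the helper names are hypothetical.

```python
import os
import sys


def pytest_argv(platform, test_files, artifact_dir, workers=None):
    """Build the pytest command shown above for a given platform."""
    if workers is None:
        # Few workers for CUDA; for CPU, assume one per core is tolerable.
        workers = 4 if platform == "iree_cuda" else (os.cpu_count() or 4)
    return [
        sys.executable, "-m", "pytest",
        f"-n{workers}", "--max-worker-restart=9999",
        "-p", "openxla_pjrt_artifacts",
        f"--openxla-pjrt-artifact-dir={artifact_dir}",
        *test_files,
    ]


def pytest_env(platform, base=None):
    """Copy of the environment with JAX_PLATFORMS set."""
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = platform
    return env
```

Pass both to `subprocess.run`, for example `subprocess.run(pytest_argv("iree_cuda", [path_to_nn_test], "/tmp/foobar"), env=pytest_env("iree_cuda"))`.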
| |
| ## Project Maintenance |
| |
| This section is a work in progress describing various project maintenance |
| tasks. |
| |
| ### Pre-requisite: Install openxla-devtools |
| |
| ```shell |
| pip install git+https://github.com/openxla/openxla-devtools.git |
| ``` |
| |
| ### Sync all deps to pinned versions |
| |
| This updates the git repositories and upgrades Python packages. |
| |
| ```shell |
| openxla-workspace sync |
| python -m pip install -U -r requirements.txt |
| ``` |
| |
| ### Update to latest nightlies |
| |
| This updates the pinned revisions to track upstream nightlies. |
| Note that the roll action will upgrade Python packages implicitly. |
| |
| ```shell |
| # Updates the sync_deps.py metadata. |
| openxla-workspace roll nightly |
| # Brings all dependencies to pinned versions. |
| openxla-workspace sync |
| ``` |
| |
| ### Update just IREE to its latest nightly |
| |
| This updates only the IREE compiler and source pins to IREE's latest nightly. |
| It is useful when some issue blocks a JAX/XLA upgrade but progress on IREE is |
| still desired. |
| Note that the roll action will upgrade Python packages implicitly. |
| |
| ```shell |
| # Updates the sync_deps.py metadata. |
| openxla-workspace roll iree_nightly |
| # Brings all dependencies to pinned versions. |
| openxla-workspace sync |
| ``` |
| |
| Alternatively, just the IREE source dep (runtime and APIs) can be pinned |
| to head: |
| |
| ```shell |
| # Updates the sync_deps.py metadata. |
| openxla-workspace roll iree |
| # Brings all dependencies to pinned versions. |
| openxla-workspace sync |
| ``` |
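All three roll variants above follow the same roll-then-sync pattern. Here is a sketch that runs them and stops at the first failure. It only accepts the three targets documented in this section (`nightly`, `iree_nightly`, `iree`); `openxla-workspace` itself may accept others. The helper names are hypothetical.

```python
import subprocess

# Roll targets named in this section.
ROLL_TARGETS = {"nightly", "iree_nightly", "iree"}


def roll_and_sync_commands(target):
    """Update the sync_deps.py metadata, then sync all deps to the new pins."""
    if target not in ROLL_TARGETS:
        raise ValueError(
            f"unknown roll target {target!r}; expected one of {sorted(ROLL_TARGETS)}"
        )
    return [
        ["openxla-workspace", "roll", target],
        ["openxla-workspace", "sync"],
    ]


def roll_and_sync(target):
    for cmd in roll_and_sync_commands(target):
        # check=True stops here if the roll fails, so no sync follows it.
        subprocess.run(cmd, check=True)
```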
| |
| ### Pin current versions of all deps |
| |
| Do this after local, cross-project changes have been made and landed. |
| It snapshots the state of all deps as actually checked out and updates |
| the metadata. |
| |
| ```shell |
| openxla-workspace pin |
| ``` |
| |
| ## Contacts |
| |
| * [GitHub issues](https://github.com/openxla/openxla-pjrt-plugin/issues): |
| Feature requests, bugs, and other work tracking |
| * [OpenXLA discord](https://discord.gg/pvuUmVQa): Daily development discussions |
| with the core team and collaborators |
| |
| ## License |
| |
| OpenXLA PJRT plugin is licensed under the terms of the Apache 2.0 License with |
| LLVM Exceptions. See [LICENSE](LICENSE) for more information. |