# Installs prebuilt PJRT plugin binaries and runs the JAX testsuite against
# the IREE CPU plugin. The list of passing tests is compared with the "golden"
# list uploaded by the last successful run of this workflow on main.

name: Run CPU Jax Testsuite

on:
  workflow_dispatch:
  schedule:
    # Run daily at 21:30 UTC (14:30 PDT). GitHub evaluates cron in UTC.
    - cron: '30 21 * * *'

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This workflow
  # is only triggered by workflow_dispatch and schedule, so in practice the
  # group is keyed on the commit hash: queued and in-progress runs for the
  # same commit are cancelled.
  group: run_jax_testsuite_${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  # Single job: fetch the prebuilt plugin, run the testsuite, upload results.
  build:
    runs-on: ubuntu-20.04-64core
    steps:
      - name: "Checking out repository"
        uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791  # v2.5.0

      - name: "Setting up Python"
        uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa  # v2.3.3
        with:
          # Quoted so YAML does not parse the version as the float 3.1.
          python-version: "3.10"

      # Fetches the artifact named `artifact` from build_packages.yml on main.
      # NOTE(review): unlike the other actions in this file, this one is
      # referenced by a mutable tag; consider pinning it to a commit SHA.
      - name: Download last PJRT build
        id: download-pjrt-build
        uses: dawidd6/action-download-artifact@v2
        with:
          workflow: build_packages.yml
          branch: main
          name: artifact

      # Fetches the artifact previously uploaded by this workflow (expected to
      # contain the golden list). Only warns when no earlier artifact exists.
      - name: Download Last Golden result
        id: download-last-golden
        uses: dawidd6/action-download-artifact@v2
        with:
          workflow: run_jaxtests_cpu.yml
          branch: main
          name: artifact
          if_no_artifact_found: warn

      - name: Sync and install versions
        run: |
          # TODO: https://github.com/openxla/openxla-pjrt-plugin/issues/30
          # Refresh the package index first so the install does not fail on
          # stale metadata; apt-get is the script-stable front end.
          sudo apt-get update
          sudo apt-get install -y lld
          # Since only building the runtime, exclude compiler deps (expensive).
          python ./sync_deps.py --depth 1 --submodules-depth 1 --exclude-dep xla --exclude-submodule "iree:third_party/(.*)"
          pip install -r requirements.txt
          python -m pip install absl-py pytest
          python -m pip install -e ctstools

      - name: "Configure"
        run: |
          python ./configure.py --cc=clang --cxx=clang++

      - name: "Run JAX Testsuite"
        id: jax-testsuite
        env:
          PASSING_ARTIFACT: jaxsuite_passing.txt
          FAILING_ARTIFACT: jaxsuite_failing.txt
          GOLDEN_ARTIFACT: jaxsuite_golden.txt
        run: |
          source .env.sh
          # Publish the result file names so the upload step can find them.
          echo "testsuite-passing-artifact=${PASSING_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          echo "testsuite-failing-artifact=${FAILING_ARTIFACT}" >> "${GITHUB_OUTPUT}"
          echo "testsuite-golden-artifact=${GOLDEN_ARTIFACT}" >> "${GITHUB_OUTPUT}"

          # The JAX checkout is a sibling of the workspace directory. Deriving
          # it from GITHUB_WORKSPACE yields the same absolute paths as before
          # on the upstream repository and keeps working in forks/renames.
          JAX_TESTS_DIR="$(dirname "${GITHUB_WORKSPACE}")/jax/tests"

          # Point to the downloaded PJRT library.
          export PJRT_NAMES_AND_LIBRARY_PATHS="iree_cpu:${GITHUB_WORKSPACE}/pjrt_plugins/pjrt_plugin_iree_cpu.so"
          JAX_PLATFORMS=iree_cpu python test/test_simple.py

          JAX_PLATFORMS=iree_cpu python test/test_jax.py \
            "${JAX_TESTS_DIR}"/lax_numpy_*_test.py \
            "${JAX_TESTS_DIR}/nn_test.py" \
            --passing "${PASSING_ARTIFACT}" \
            --failing "${FAILING_ARTIFACT}" \
            --expected "${GOLDEN_ARTIFACT}" \
            --logdir jax_testsuite

          # If we passed we can update the golden.
          # NOTE(review): GitHub's default shell is `bash -e`, so a failing
          # run normally aborts the step before this point and the check is
          # only a safeguard (e.g. if .env.sh alters shell options).
          if [ $? -eq 0 ]; then
            cp "${PASSING_ARTIFACT}" "${GOLDEN_ARTIFACT}"
          fi

      - name: "Run JAX Testsuite Triage"
        run: |
          python test/triage_jaxtest.py --logdir jax_testsuite

      # Runs even when an earlier step failed so the result lists are kept.
      - name: "Upload testsuite results"
        uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce  # v3.1.2
        if: always()
        with:
          # Explicitly the action's default name; the golden download step
          # above looks this artifact up by the same name.
          name: artifact
          path: |
            ${{ steps.jax-testsuite.outputs.testsuite-passing-artifact }}
            ${{ steps.jax-testsuite.outputs.testsuite-failing-artifact }}
            ${{ steps.jax-testsuite.outputs.testsuite-golden-artifact }}
          retention-days: 5