blob: 66b21351d8cc87c60a29f84a43e7de407f779a04 [file] [log] [blame]
#!/bin/bash
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
set -xeo pipefail

# Usage: run_jax_tests.sh <pjrt_platform>
# The platform name is required; it selects the IREE PJRT backend under test.
pjrt_platform=${1:-}
if [[ -z "${pjrt_platform}" ]]; then
  # Stop tracing so the usage text is not interleaved with `set -x` output.
  set +x
  # Usage errors are diagnostics, so send them to stderr.
  {
    echo "Usage: run_jax_tests.sh <pjrt_platform>"
    echo " <pjrt_platform> can be 'cpu', 'cuda', 'rocm' or 'vulkan'"
  } >&2
  exit 1
fi

# cd into the PJRT plugin dir. ROOT_DIR may be provided by the caller;
# otherwise it defaults to the top level of the enclosing git checkout.
ROOT_DIR="${ROOT_DIR:-$(git rev-parse --show-toplevel)}"
cd "${ROOT_DIR}/integrations/pjrt"

# Perform some differential testing: each test runs once on the reference
# JAX platform and once on the IREE PJRT platform for the requested backend.
actual_jax_platform=iree_${pjrt_platform}
expected_jax_platform=cpu
#######################################
# Run a Python test script under the reference JAX platform and under the
# IREE PJRT platform, then diff the stdout captured from the two runs.
# Globals:
#   expected_jax_platform (read) - JAX_PLATFORMS value for the reference run
#   actual_jax_platform   (read) - JAX_PLATFORMS value for the IREE run
#   TMPDIR                (read) - directory for the capture files
#                                  (defaults to /tmp)
# Arguments:
#   $1 - path to the Python test script
# Outputs:
#   Progress messages on stdout; a unified diff on stdout if outputs differ.
# Returns:
#   0 if both runs succeed and their stdout is identical; otherwise the
#   non-zero status of the failing step (mktemp, python or diff). Failures are
#   propagated explicitly rather than relying on the caller's `set -e`.
#   On success the capture files are removed; on failure they are left in
#   place (their paths are printed) so the outputs can be inspected.
#######################################
diff_jax_test() {
  local test_py_file=$1
  local tmp_dir=${TMPDIR:-/tmp}
  # Declared separately from assignment so mktemp failures are not masked.
  local expected_tmp_out actual_tmp_out

  echo "executing ${test_py_file} in ${expected_jax_platform}.."
  expected_tmp_out=$(mktemp "${tmp_dir}/jax_test_result_expected.XXXXXX") || return
  JAX_PLATFORMS=$expected_jax_platform python "$test_py_file" >"$expected_tmp_out" || return

  echo "executing ${test_py_file} in ${actual_jax_platform}.."
  actual_tmp_out=$(mktemp "${tmp_dir}/jax_test_result_actual.XXXXXX") || return
  JAX_PLATFORMS=$actual_jax_platform python "$test_py_file" >"$actual_tmp_out" || return

  echo "comparing ${expected_tmp_out} and ${actual_tmp_out}.."
  diff --unified -- "$expected_tmp_out" "$actual_tmp_out" || return
  echo "no difference found"
  rm -f -- "$expected_tmp_out" "$actual_tmp_out"
}
# Differential test scripts, relative to the PJRT plugin directory.
jax_test_files=(
  test/test_add.py
  test/test_degenerate.py
  test/test_simple.py
)
for jax_test_file in "${jax_test_files[@]}"; do
  diff_jax_test "${jax_test_file}"
done
# FIXME: we could also use the native test cases from JAX,
# e.g. `tests/nn_test.py` from the JAX repo, as shown below,
# but some test cases in that file currently fail.
# NOTE that `absl-py` is required to run these tests.
# jax_nn_test_file=$(mktemp /tmp/jax_nn_test.XXXXXX.py)
# wget https://raw.githubusercontent.com/jax-ml/jax/jax-v0.4.20/tests/nn_test.py -O "$jax_nn_test_file"
# JAX_PLATFORMS=$actual_jax_platform python "$jax_nn_test_file"