tests/e2e/models/mnist_train_test/README.md

This test checks that IREE produces the correct model and optimizer state, both after parameter initialization and after one training step. The ground truth is extracted from a JAX model, and the MLIR model is generated with IREE JAX.
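As a rough illustration of what comparing a state against the ground truth can look like, the sketch below checks two dictionaries of arrays for equality within a tolerance. It is not taken from `mnist_train_test.py`: the function name `assert_state_matches`, the array names, and the tolerances are all hypothetical.

```python
import numpy as np


def assert_state_matches(actual, expected, rtol=1e-5, atol=1e-5):
    """Check that two name -> array mappings agree within a tolerance.

    Hypothetical helper: the tolerances are illustrative defaults, not
    the values used by the actual test.
    """
    # Both states must contain exactly the same set of arrays.
    if set(actual) != set(expected):
        raise AssertionError(
            f"array names differ: {sorted(set(actual) ^ set(expected))}"
        )
    # Compare each array element-wise against the ground truth.
    for name, want in expected.items():
        np.testing.assert_allclose(
            actual[name], want, rtol=rtol, atol=atol,
            err_msg=f"mismatch in array '{name}'",
        )


# Usage with made-up arrays standing in for model and optimizer state.
expected_state = {"w": np.ones((2, 3), dtype=np.float32),
                  "b": np.zeros(3, dtype=np.float32)}
actual_state = {"w": np.ones((2, 3), dtype=np.float32) + 1e-7,
                "b": np.zeros(3, dtype=np.float32)}
assert_state_matches(actual_state, expected_state)
```

Comparing with a tolerance rather than exact equality allows for small floating-point differences between the two implementations.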

To regenerate the model together with the test data, run:

python -m venv generate_mnist.venv
source generate_mnist.venv/bin/activate
# Add IREE Python to your PYTHONPATH, following
# https://openxla.github.io/iree/building-from-source/getting-started/#python-bindings
pip install -r generate_test_data_requirements.txt
python ./generate_test_data.py

To upload the artifacts to GCS and update the artifact URL in the test, run:

tar --remove-files -v -c -f mnist_train.tar *.npz *.mlirbc
DIGEST="$(sha256sum mnist_train.tar | awk '{print $1}')"
gcloud storage mv mnist_train.tar "gs://iree-model-artifacts/mnist_train.${DIGEST}.tar"
sed -i \
  "s|MODEL_ARTIFACTS_URL =.*|MODEL_ARTIFACTS_URL = \"https://storage.googleapis.com/iree-model-artifacts/mnist_train.${DIGEST}.tar\"|" \
  mnist_train_test.py
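
Because the archive is named after its own SHA-256 digest, a downloaded copy can be checked against the digest embedded in its URL. The following is a minimal sketch of that idea using only the Python standard library. The helper names (`sha256_of_file`, `artifact_url`, `digest_from_url`, `verify_archive`) are hypothetical and not part of `mnist_train_test.py`; the URL prefix is copied from the commands above.

```python
import hashlib
from pathlib import Path

# Copied from the upload commands above.
ARTIFACTS_URL_PREFIX = "https://storage.googleapis.com/iree-model-artifacts"


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def artifact_url(digest):
    """Build the URL of an archive that is named after its digest."""
    return f"{ARTIFACTS_URL_PREFIX}/mnist_train.{digest}.tar"


def digest_from_url(url):
    """Extract <digest> from a URL ending in mnist_train.<digest>.tar."""
    name = url.rsplit("/", 1)[-1]
    prefix, suffix = "mnist_train.", ".tar"
    if not (name.startswith(prefix) and name.endswith(suffix)):
        raise ValueError(f"unexpected artifact name: {name}")
    return name[len(prefix):-len(suffix)]


def verify_archive(path, url):
    """Return True if the file's digest matches the one in the URL."""
    return sha256_of_file(Path(path)) == digest_from_url(url)
```

For example, `verify_archive("mnist_train.tar", url)` could be called on a freshly downloaded archive, where `url` is the value of `MODEL_ARTIFACTS_URL`, before the archive is unpacked.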