| #!/usr/bin/env python3 |
| # Copyright 2021 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import glob |
| import logging |
| import os |
| import subprocess |
| import unittest |
| |
# Notebooks that are not run at all; the test generated for each of these is
# marked as skipped.
NOTEBOOKS_TO_SKIP = [
    # Fails under test with a matplotlib style lookup error:
    #   FileNotFoundError: [Errno 2] No such file or directory: 'seaborn-whitegrid'
    # The support level for TF code and samples is also low.
    "tensorflow_mnist_training.ipynb",
    # Requires 'transformers' (and possibly other packages) to be preinstalled.
    # TODO: add them to run_python_notebook.sh, colab/requirements.txt, or
    # samples.yml?
    "pytorch_huggingface_whisper.ipynb",
]

# Notebooks that are run but whose failure is expected; currently none.
NOTEBOOKS_EXPECTED_TO_FAIL = []
| |
| |
class ColabNotebookTests(unittest.TestCase):
    """Tests running all Colab notebooks in the repository's samples/colab/.

    Test methods are not defined statically: ``generateTests()`` must be called
    before the class is handed to a test runner, so that one
    ``test_<notebook file name>`` method exists per notebook.
    """

    @staticmethod
    def _find_notebooks(notebooks_dir):
        """Return the sorted paths of all ``*.ipynb`` files directly in notebooks_dir.

        The directory part is passed through ``glob.escape`` so that a checkout
        path containing glob metacharacters (``[``, ``]``, ``*``, ``?``) is
        matched literally instead of silently matching nothing (which would
        generate zero tests and make the run look green). The result is sorted
        because ``glob`` returns entries in filesystem-dependent order.
        """
        pattern = os.path.join(glob.escape(notebooks_dir), "*.ipynb")
        return sorted(glob.glob(pattern))

    @classmethod
    def generateTests(cls):
        """Attach one test method per Colab notebook to this class.

        Each generated test runs build_tools/testing/run_python_notebook.sh on
        its notebook and passes when the script exits with status 0. Notebooks
        named in NOTEBOOKS_TO_SKIP are marked as skipped, and those named in
        NOTEBOOKS_EXPECTED_TO_FAIL are marked as expected failures.
        """
        # Three dirname() calls: this file is assumed to sit two directory
        # levels below the repository root (e.g. samples/colab/).
        repo_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )
        script_path = os.path.join(
            repo_root, "build_tools/testing/run_python_notebook.sh"
        )

        # Create a test case for each notebook in the samples/colab folder.
        notebooks_dir = os.path.join(repo_root, "samples/colab")
        for notebook_path in cls._find_notebooks(notebooks_dir):
            notebook_name = os.path.basename(notebook_path)

            # Bind notebook_path as a default argument so each generated test
            # keeps its own notebook (avoids the late-binding closure pitfall).
            def unit_test(self, notebook_path=notebook_path):
                completed_process = subprocess.run([script_path, notebook_path])
                self.assertEqual(
                    completed_process.returncode,
                    0,
                    f"{script_path} exited with status "
                    f"{completed_process.returncode} for {notebook_path}",
                )

            if notebook_name in NOTEBOOKS_TO_SKIP:
                unit_test = unittest.skip("Skip requested")(unit_test)
            elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL:
                unit_test = unittest.expectedFailure(unit_test)

            # Add 'unit_test' to this class, so the test runner runs it.
            unit_test.__name__ = f"test_{notebook_name}"
            setattr(cls, unit_test.__name__, unit_test)
| |
| |
if __name__ == "__main__":
    # Turn on verbose logging, register one test per notebook on the test
    # class, then hand control to the unittest runner.
    logging.basicConfig(level=logging.DEBUG)
    ColabNotebookTests.generateTests()
    unittest.main()