Scott Todd | b5b9026 | 2021-06-28 16:29:43 -0700 | [diff] [blame] | 1 | # Copyright 2021 The IREE Authors |
| 2 | # |
| 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | # See https://llvm.org/LICENSE.txt for license information. |
| 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
| 7 | import glob |
Stella Laurenzo | 95b3150 | 2022-05-09 09:15:50 -0700 | [diff] [blame^] | 8 | import logging |
Scott Todd | b5b9026 | 2021-06-28 16:29:43 -0700 | [diff] [blame] | 9 | import os |
| 10 | import subprocess |
| 11 | import unittest |
| 12 | |
# Notebook base file names (e.g. "example.ipynb") whose generated tests are
# marked as skipped.
NOTEBOOKS_TO_SKIP = []

# Notebook base file names whose generated tests are wrapped with
# unittest.expectedFailure. Currently none.
NOTEBOOKS_EXPECTED_TO_FAIL = []
Scott Todd | b5b9026 | 2021-06-28 16:29:43 -0700 | [diff] [blame] | 18 | |
| 19 | |
class ColabNotebookTests(unittest.TestCase):
  """Tests running all Colab notebooks in this directory.

  No test methods are defined statically. `generateTests()` must be called
  before the test runner collects this class; it attaches one
  `test_<notebook file name>` method per `*.ipynb` file in samples/colab/.
  """

  @classmethod
  def generateTests(cls):
    """Attaches one test method to `cls` per notebook in samples/colab/.

    Each generated test runs build_tools/testing/run_python_notebook.sh on
    its notebook and passes when the script exits with status 0. Notebooks
    whose base name is in NOTEBOOKS_TO_SKIP are skipped; those in
    NOTEBOOKS_EXPECTED_TO_FAIL are wrapped with unittest.expectedFailure.
    """
    # This file sits in <repo_root>/samples/colab/, so strip the file name
    # and two directory levels to reach the repository root.
    repo_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    script_path = os.path.join(repo_root,
                               "build_tools/testing/run_python_notebook.sh")

    # Create a test case for each notebook in samples/colab/. glob returns
    # matches in filesystem-dependent order, so sort for deterministic
    # test generation.
    notebooks_dir = os.path.join(repo_root, "samples", "colab")
    notebook_paths = sorted(glob.glob(os.path.join(notebooks_dir, "*.ipynb")))
    if not notebook_paths:
      # Without this, a wrong path would silently yield an empty suite.
      # A module logger is used so root logging configuration is untouched.
      logging.getLogger(__name__).warning("No notebooks found in %s",
                                          notebooks_dir)

    for notebook_path in notebook_paths:
      notebook_name = os.path.basename(notebook_path)

      # `notebook_path` is bound as a default argument so each generated
      # test keeps its own notebook instead of the loop's final value.
      def unit_test(self, notebook_path=notebook_path):
        completed_process = subprocess.run([script_path, notebook_path])
        self.assertEqual(
            completed_process.returncode, 0,
            f"{script_path} exited with status "
            f"{completed_process.returncode} for {notebook_path}")

      if notebook_name in NOTEBOOKS_TO_SKIP:
        unit_test = unittest.skip("Skip requested")(unit_test)
      elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL:
        unit_test = unittest.expectedFailure(unit_test)

      # Add 'unit_test' to this class, so the test runner runs it. Keep
      # __qualname__ in sync with the new name for clearer tracebacks.
      unit_test.__name__ = f"test_{notebook_name}"
      unit_test.__qualname__ = f"{cls.__qualname__}.{unit_test.__name__}"
      setattr(cls, unit_test.__name__, unit_test)
| 48 | |
| 49 | |
if __name__ == "__main__":
  # Set up root logging at DEBUG level first, then attach the per-notebook
  # test methods, then hand control to the unittest command-line runner.
  logging.basicConfig(level=logging.DEBUG)
  ColabNotebookTests.generateTests()
  unittest.main()