blob: 4e7d307ded37f023710548fb4b9528fa204dab56 [file] [log] [blame]
Scott Toddb5b90262021-06-28 16:29:43 -07001# Copyright 2021 The IREE Authors
2#
3# Licensed under the Apache License v2.0 with LLVM Exceptions.
4# See https://llvm.org/LICENSE.txt for license information.
5# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7import glob
Stella Laurenzo95b31502022-05-09 09:15:50 -07008import logging
Scott Toddb5b90262021-06-28 16:29:43 -07009import os
10import subprocess
11import unittest
12
# Notebook file names (basenames, e.g. "example.ipynb") whose tests are
# generated but marked as skipped.
NOTEBOOKS_TO_SKIP = []

# Notebook file names (basenames) whose tests are wrapped in
# unittest.expectedFailure, i.e. a non-zero exit status is expected.
NOTEBOOKS_EXPECTED_TO_FAIL = [
    # None!
]


class ColabNotebookTests(unittest.TestCase):
    """Tests running all Colab notebooks in this directory."""

    @classmethod
    def generateTests(cls):
        """Attach one generated test method per Colab notebook to this class.

        For every ``samples/colab/*.ipynb`` file under the repository root, a
        method named ``test_<notebook file name>`` is added. It runs
        ``build_tools/testing/run_python_notebook.sh <notebook>`` and passes
        when the script exits with status 0. Names listed in
        ``NOTEBOOKS_TO_SKIP`` are skipped; names listed in
        ``NOTEBOOKS_EXPECTED_TO_FAIL`` are marked as expected failures.

        A ``test_notebooks_found`` method is always added as well. It fails
        when the glob matches no notebooks. Otherwise a moved or empty
        directory would produce a "successful" run with zero notebook tests.
        """
        # repo_root is the grandparent of the directory containing this file.
        repo_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        script_path = os.path.join(repo_root,
                                   "build_tools/testing/run_python_notebook.sh")

        # Create a test case for each notebook in this folder.
        notebooks_path = os.path.join(repo_root, "samples/colab/")
        notebook_paths = glob.glob(notebooks_path + "*.ipynb")

        def test_notebooks_found(self):
            # Guard against silently running zero notebook tests.
            self.assertTrue(
                notebook_paths,
                f"No *.ipynb notebooks found under {notebooks_path}")

        setattr(cls, test_notebooks_found.__name__, test_notebooks_found)

        for notebook_path in notebook_paths:
            notebook_name = os.path.basename(notebook_path)

            # notebook_path is bound as a default argument so each generated
            # test keeps its own path rather than the loop variable's final
            # value (late-binding closure pitfall).
            def unit_test(self, notebook_path=notebook_path):
                completed_process = subprocess.run([script_path, notebook_path],
                                                   check=False)
                self.assertEqual(
                    completed_process.returncode, 0,
                    f"{script_path} exited with status "
                    f"{completed_process.returncode} for {notebook_path}")

            if notebook_name in NOTEBOOKS_TO_SKIP:
                unit_test = unittest.skip("Skip requested")(unit_test)
            elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL:
                unit_test = unittest.expectedFailure(unit_test)

            # Add 'unit_test' to this class, so the test runner runs it.
            unit_test.__name__ = f"test_{notebook_name}"
            setattr(cls, unit_test.__name__, unit_test)
48
49
if __name__ == "__main__":
    # Turn on DEBUG-level logging output before anything else runs.
    logging.basicConfig(level=logging.DEBUG)
    # Populate ColabNotebookTests with one test per notebook, then hand
    # control to the standard unittest command-line runner.
    ColabNotebookTests.generateTests()
    unittest.main()