blob: 38b19b91fb80a9dc1830fb9cb05dfd08475b9c8e [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
8import glob
Stella Laurenzo95b31502022-05-09 09:15:50 -07009import logging
Scott Toddb5b90262021-06-28 16:29:43 -070010import os
11import subprocess
12import unittest
13
# File names of notebooks whose generated tests are marked as skipped.
NOTEBOOKS_TO_SKIP = [
    # Fails under test with a matplotlib error:
    #   FileNotFoundError: [Errno 2] No such file or directory: 'seaborn-whitegrid'
    # TensorFlow code and samples also have a low support level.
    "tensorflow_mnist_training.ipynb",
    # tflite_runtime needs dependencies that samples.Dockerfile does not
    # currently provide (fails with "version `GLIBC_2.29' not found").
    "tflite_text_classification.ipynb",
    # These PyTorch notebooks use SHARK-Turbine, which needs Python 3.10+
    # inside the Docker image.
    "pytorch_aot_advanced.ipynb",
    "pytorch_aot_simple.ipynb",
    "pytorch_jit.ipynb",
]
Scott Toddb5b90262021-06-28 16:29:43 -070027
# File names of notebooks whose generated tests are marked as expected failures.
NOTEBOOKS_EXPECTED_TO_FAIL = [
    # Fails with:
    #   module 'tensorflow.python.pywrap_mlir' has no attribute
    #   'experimental_convert_saved_model_v1'
    # Possibly convert_saved_model_v1 is broken while convert_saved_model
    # still works (unconfirmed).
    "tensorflow_hub_import.ipynb",
    # Fails with: 'stablehlo.pad' op attribute 'edge_padding_low' failed to
    # satisfy constraint: 64-bit signless integer elements attribute
    "tensorflow_resnet.ipynb",
]
Scott Toddb5b90262021-06-28 16:29:43 -070040
41
class ColabNotebookTests(unittest.TestCase):
    """Tests running all Colab notebooks in this directory."""

    @classmethod
    def generateTests(cls):
        """Attach one generated test method per notebook to ``cls``.

        Each generated ``test_<notebook file name>`` method runs
        ``build_tools/testing/run_python_notebook.sh`` on one notebook from
        ``samples/colab/`` and passes when the script exits with status 0.
        Notebooks named in ``NOTEBOOKS_TO_SKIP`` are wrapped with
        ``unittest.skip``; otherwise, notebooks named in
        ``NOTEBOOKS_EXPECTED_TO_FAIL`` are wrapped with
        ``unittest.expectedFailure``.

        A warning is logged in two cases: when no notebooks are found, and
        when a skip/expected-failure entry matches no notebook on disk.
        """
        # This file is expected to live in samples/colab/, so three dirname()
        # calls on its absolute path reach the repository root.
        repo_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )
        script_path = os.path.join(
            repo_root, "build_tools/testing/run_python_notebook.sh"
        )

        # Create a test case for each notebook in this folder. glob's order is
        # filesystem-dependent, so sort for a deterministic result.
        notebooks_path = os.path.join(repo_root, "samples/colab/")
        notebook_paths = sorted(
            glob.glob(os.path.join(notebooks_path, "*.ipynb"))
        )
        if not notebook_paths:
            # Without this, a wrong path silently yields zero tests and the
            # run may still report success.
            logging.warning("No notebooks found in %s", notebooks_path)
            return

        for notebook_path in notebook_paths:
            notebook_name = os.path.basename(notebook_path)

            # notebook_path is bound as a default argument so each generated
            # test keeps its own notebook instead of the loop's last value.
            def unit_test(self, notebook_path=notebook_path):
                completed_process = subprocess.run([script_path, notebook_path])
                self.assertEqual(
                    completed_process.returncode,
                    0,
                    f"{script_path} failed on {notebook_path}",
                )

            if notebook_name in NOTEBOOKS_TO_SKIP:
                unit_test = unittest.skip("Skip requested")(unit_test)
            elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL:
                unit_test = unittest.expectedFailure(unit_test)

            # Add 'unit_test' to this class, so the test runner runs it.
            unit_test.__name__ = f"test_{notebook_name}"
            setattr(cls, unit_test.__name__, unit_test)

        # Entries that no longer match any notebook have no effect; report
        # them so the lists can be cleaned up.
        found_names = {os.path.basename(path) for path in notebook_paths}
        listed_names = set(NOTEBOOKS_TO_SKIP) | set(NOTEBOOKS_EXPECTED_TO_FAIL)
        for stale_name in sorted(listed_names - found_names):
            logging.warning(
                "%s is listed in NOTEBOOKS_TO_SKIP or "
                "NOTEBOOKS_EXPECTED_TO_FAIL but was not found in %s",
                stale_name,
                notebooks_path,
            )
Scott Toddb5b90262021-06-28 16:29:43 -070071
72
if __name__ == "__main__":
    # Set up DEBUG-level logging first, then register one test method per
    # notebook before handing control to the unittest runner.
    logging.basicConfig(level=logging.DEBUG)
    ColabNotebookTests.generateTests()
    unittest.main()