|  | # Copyright 2020 The IREE Authors | 
|  | # | 
|  | # Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | # See https://llvm.org/LICENSE.txt for license information. | 
|  | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | import logging | 
|  | import os | 
|  | import sys | 
|  | import tempfile | 
|  | import unittest | 
|  |  | 
|  | # TODO: No idea why pytype cannot find names from this module. | 
|  | # pytype: disable=name-error | 
|  | import iree.compiler.tf | 
|  |  | 
# This test needs the IREE TensorFlow compiler. When it is not available,
# report that the test is skipped and exit with status 0 before TensorFlow
# itself gets imported.
if not iree.compiler.tf.is_available():
    skip_notice = (
        f"Skipping test {__file__} because the IREE TensorFlow compiler "
        f"is not installed")
    print(skip_notice)
    sys.exit(0)
|  |  | 
|  | import tensorflow as tf | 
|  |  | 
|  |  | 
class SimpleArithmeticModule(tf.Module):
    """TensorFlow module with two fixed-signature float32 functions.

    The function and parameter names are part of the exported interface that
    the compiler tests look for, so they must stay as they are.
    """

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[4], dtype=tf.float32),
        tf.TensorSpec(shape=[4], dtype=tf.float32),
    ])
    def simple_mul(self, a, b):
        """Return `a * b` for two float32 tensors of shape [4]."""
        product = a * b
        return product

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[128, 3072], dtype=tf.float32),
        tf.TensorSpec(shape=[3072, 256], dtype=tf.float32),
    ])
    def simple_matmul(self, a, b):
        """Return `tf.matmul(a, b)` for a [128, 3072] and a [3072, 256] input."""
        result = tf.matmul(a, b)
        return result
|  |  | 
|  |  | 
|  | # TODO(laurenzo): More test cases needed (may need additional files). | 
|  | # Specifically, figure out how to test v1 models. | 
class TfCompilerTest(tf.test.TestCase):
    """Tests for the `iree.compiler.tf` entry points.

    A single `SimpleArithmeticModule` instance is created once per class and
    exported as a SavedModel into a temporary directory; the individual tests
    then compile either the in-memory module (`cls.m`) or the exported
    directory (`cls.smdir`).
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared module and export it as a SavedModel."""
        # Let the base classes perform their own class-level setup first.
        super().setUpClass()
        cls.m = SimpleArithmeticModule()
        cls.tempdir = tempfile.TemporaryDirectory()
        cls.smdir = os.path.join(cls.tempdir.name, "arith.sm")
        try:
            tf.saved_model.save(
                cls.m,
                cls.smdir,
                options=tf.saved_model.SaveOptions(save_debug_info=True))
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the temporary directory here before propagating.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary SavedModel directory."""
        cls.tempdir.cleanup()
        # Mirror setUpClass: give the base classes their teardown as well.
        super().tearDownClass()

    def testImportSavedModel(self):
        """Import-only output of the SavedModel mentions `simple_matmul`."""
        import_mlir = iree.compiler.tf.compile_saved_model(
            self.smdir,
            import_only=True,
            output_generic_mlir=True).decode("utf-8")
        self.assertIn("sym_name = \"simple_matmul\"", import_mlir)

    def testCompileSavedModel(self):
        """Compiling the SavedModel yields a binary naming both functions."""
        binary = iree.compiler.tf.compile_saved_model(
            self.smdir,
            target_backends=iree.compiler.tf.DEFAULT_TESTING_BACKENDS)
        logging.info("Compiled len: %d", len(binary))
        self.assertIn(b"simple_matmul", binary)
        self.assertIn(b"simple_mul", binary)

    def testCompileModule(self):
        """Compiling the in-memory module yields a binary naming both functions."""
        binary = iree.compiler.tf.compile_module(
            self.m,
            target_backends=iree.compiler.tf.DEFAULT_TESTING_BACKENDS)
        logging.info("Compiled len: %d", len(binary))
        self.assertIn(b"simple_matmul", binary)
        self.assertIn(b"simple_mul", binary)
|  |  | 
|  |  | 
if __name__ == "__main__":
    # Script entry point: configure logging at DEBUG level, then hand control
    # to the TensorFlow test runner.
    logging.basicConfig(level=logging.DEBUG)
    tf.test.main()