|  | # Copyright 2021 The IREE Authors | 
|  | # | 
|  | # Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | # See https://llvm.org/LICENSE.txt for license information. | 
|  | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | from absl import app | 
|  | from iree.tf.support import tf_test_utils | 
|  | import tensorflow as tf | 
|  |  | 
|  |  | 
|  | # Empty lists and dicts are currently unsupported. IREE also currently cannot | 
|  | # represent multiple sequence types, so we turn all sequences into tuples. | 
class PyTreeModule(tf_test_utils.TestModule):
    """Functions that return, or accept and echo, nested tuple/dict structures.

    Every output below is built from tuples and string-keyed dicts of Python
    ints; the single input-taking function declares scalar (shape ``[]``)
    ``tf.TensorSpec`` leaves arranged in the same kind of nesting.
    """

    # Returns a one-element tuple.
    @tf_test_utils.tf_function_unit_test(input_signature=[])
    def output_tuple_len_1(self):
        return (0,)

    # Returns a two-element tuple.
    @tf_test_utils.tf_function_unit_test(input_signature=[])
    def output_tuple_len_2(self):
        return (0, 1)

    # Returns a three-element tuple.
    @tf_test_utils.tf_function_unit_test(input_signature=[])
    def output_tuple_len_3(self):
        return (0, 1, 2)

    # Returns a dict of tuples where the last tuple itself ends in a dict.
    @tf_test_utils.tf_function_unit_test(input_signature=[])
    def output_nested_pytree(self):
        innermost = {"key_c": (0, 1)}
        return {"key_a": (0, 1, 2), "key_b": (0, 1, innermost)}

    # Takes a structure shaped like the one output_nested_pytree returns
    # (with scalar tensor specs as leaves) and hands it back unchanged.
    @tf_test_utils.tf_function_unit_test(
        input_signature=[
            {
                "key_a": (
                    tf.TensorSpec([]),
                    tf.TensorSpec([]),
                    tf.TensorSpec([]),
                ),
                "key_b": (
                    tf.TensorSpec([]),
                    tf.TensorSpec([]),
                    {"key_c": (tf.TensorSpec([]), tf.TensorSpec([]))},
                ),
            }
        ]
    )
    def input_nested_pytree(self, input_pytree):
        return input_pytree
|  |  | 
|  |  | 
class PyTreeTest(tf_test_utils.TracedModuleTestCase):
    """Test case that holds the compiled form of ``PyTreeModule``."""

    def __init__(self, *args, **kwargs):
        # Forward every constructor argument to the base test case first,
        # then store the result of compiling PyTreeModule on the instance.
        super(PyTreeTest, self).__init__(*args, **kwargs)
        compiled_modules = tf_test_utils.compile_tf_module(PyTreeModule)
        self._modules = compiled_modules
|  |  | 
|  |  | 
def main(argv):
    """Generate the unit tests for ``PyTreeModule`` and run the test runner.

    Args:
        argv: Command-line arguments passed in by the caller; not used.
    """
    del argv  # Unused
    # Test generation has to happen before the runner is started.
    PyTreeTest.generate_unit_tests(PyTreeModule)
    tf.test.main()
|  |  | 
|  |  | 
# Script entry point: hand control to absl, which invokes main().
if __name__ == "__main__":
    app.run(main)