| # Copyright 2020 Google LLC | 
 | # | 
 | # Licensed under the Apache License, Version 2.0 (the "License"); | 
 | # you may not use this file except in compliance with the License. | 
 | # You may obtain a copy of the License at | 
 | # | 
 | #      https://www.apache.org/licenses/LICENSE-2.0 | 
 | # | 
 | # Unless required by applicable law or agreed to in writing, software | 
 | # distributed under the License is distributed on an "AS IS" BASIS, | 
 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | # See the License for the specific language governing permissions and | 
 | # limitations under the License. | 
 |  | 
 | import collections | 
 | import os | 
 | from typing import Any, Dict, Sequence, Type, Union | 
 |  | 
 | from absl import app | 
 | from absl import flags | 
 | import numpy as np | 
 | from pyiree.tf.support import tf_test_utils | 
 | from pyiree.tf.support import tf_utils | 
 | import tensorflow.compat.v2 as tf | 
 |  | 
# The absl flag registry; used below to read --functions, --dynamic_dims,
# --test_complex, --target_backends and the other command-line flags.
FLAGS = flags.FLAGS
 |  | 
 | # As high as tf goes without breaking. | 
 | RANK_7_SHAPE = [2] * 7 | 
 | UNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE]] | 
 | BINARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 2] | 
 | TERNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 3] | 
 |  | 
 | # Reused UnitTestSpecs. | 
 | SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args( | 
 |     names_to_input_args={ | 
 |         "tf_doc_example": [ | 
 |             tf.constant([ | 
 |                 [1, 2, 3, 4], | 
 |                 [4, 3, 2, 1], | 
 |                 [5, 6, 7, 8], | 
 |             ], np.float32), | 
 |             np.array([0, 0, 1], np.int32), | 
 |         ] | 
 |     }) | 
 | UNSORTED_SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args( | 
 |     names_to_input_args={ | 
 |         "tf_doc_example": [ | 
 |             tf.constant([ | 
 |                 [1, 2, 3, 4], | 
 |                 [4, 3, 2, 1], | 
 |                 [5, 6, 7, 8], | 
 |             ], np.float32), | 
 |             np.array([0, 0, 1], np.int32), | 
 |             2, | 
 |         ] | 
 |     }) | 
 |  | 
 | REDUCE_KWARGS_TO_VALUES = { | 
 |     "axis": [None, 1], | 
 |     "keepdims": [False, True], | 
 | } | 
 |  | 
 | # A dictionary mapping tf.math function names to lists of UnitTestSpecs. | 
 | # Each unit_test_name will have the tf.math function name prepended to it. | 
 | FUNCTIONS_TO_UNIT_TEST_SPECS = { | 
 |     "abs": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "accumulate_n": [ | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name='f32', | 
 |             input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]), | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name='i32', | 
 |             input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]), | 
 |     ], | 
 |     "acos": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "acosh": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32], | 
 |             input_generators=[tf_utils.ndarange]), | 
 |     "add": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "add_n": [ | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name='f32', | 
 |             input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]), | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name='i32', | 
 |             input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]), | 
 |     ], | 
 |     "angle": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "argmax": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "argmin": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "asin": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "asinh": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "atan": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "atan2": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "atanh": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "bessel_i0": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "bessel_i0e": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "bessel_i1": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "bessel_i1e": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "betainc": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=TERNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "bincount": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.int32], | 
 |             input_generators=[tf_utils.ndarange]), | 
 |     "ceil": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "confusion_matrix": | 
 |         tf_test_utils.unit_test_specs_from_args(names_to_input_args={ | 
 |             "five_classes": [tf.constant([1, 2, 4]), | 
 |                              tf.constant([2, 2, 4])] | 
 |         }), | 
 |     "conj": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "cos": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "cosh": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "count_nonzero": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64], | 
 |             input_generators=[tf_utils.ndarange]), | 
 |     "cumprod": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "cumsum": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "cumulative_logsumexp": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "digamma": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "divide": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "divide_no_nan": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "equal": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "erf": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "erfc": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "erfinv": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "exp": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "expm1": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "floor": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "floordiv": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             # Avoid integer division by 0. | 
 |             input_generators={ | 
 |                 "uniform_1_3": | 
 |                     lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0) | 
 |             }), | 
 |     "floormod": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             # Avoid integer division by 0. | 
 |             input_generators={ | 
 |                 "uniform_1_3": | 
 |                     lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0) | 
 |             }), | 
 |     "greater": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "greater_equal": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "igamma": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "igammac": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "imag": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "in_top_k": [ | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name="k_3", | 
 |             input_signature=[ | 
 |                 tf.TensorSpec([8], tf.int32), | 
 |                 tf.TensorSpec([8, 3]) | 
 |             ], | 
 |             input_generator=tf_utils.ndarange, | 
 |             kwargs=dict(k=3), | 
 |         ) | 
 |     ], | 
 |     "invert_permutation": [ | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name="random", | 
 |             input_signature=[tf.TensorSpec([8], tf.int32)], | 
 |             input_generator=tf_utils.random_permutation, | 
 |         ) | 
 |     ], | 
 |     "is_finite": | 
 |         tf_test_utils.unit_test_specs_from_args(names_to_input_args={ | 
 |             "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])] | 
 |         }), | 
 |     "is_inf": | 
 |         tf_test_utils.unit_test_specs_from_args(names_to_input_args={ | 
 |             "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])] | 
 |         }), | 
 |     "is_nan": | 
 |         tf_test_utils.unit_test_specs_from_args(names_to_input_args={ | 
 |             "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])] | 
 |         }), | 
 |     "is_non_decreasing": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "is_strictly_increasing": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "l2_normalize": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "lbeta": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "less": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "less_equal": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "lgamma": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "log": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "log1p": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "log_sigmoid": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "log_softmax": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "logical_and": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.bool]), | 
 |     "logical_not": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.bool]), | 
 |     "logical_or": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.bool]), | 
 |     "logical_xor": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.bool]), | 
 |     "maximum": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "minimum": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "mod": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             input_generators={ | 
 |                 "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1 | 
 |             }), | 
 |     "multiply": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "multiply_no_nan": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "ndtri": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "negative": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "nextafter": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES), | 
 |     "not_equal": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32]), | 
 |     "polygamma": | 
 |         tf_test_utils.unit_test_specs_from_args(names_to_input_args={ | 
 |             "nan_and_inf": [tf.ones(16), tf.linspace(0.5, 4, 16)] | 
 |         }), | 
 |     "polyval": [ | 
 |         tf_test_utils.UnitTestSpec( | 
 |             unit_test_name="three_coeffs", | 
 |             input_signature=[[tf.TensorSpec(RANK_7_SHAPE)] * 3, | 
 |                              tf.TensorSpec([])], | 
 |         ) | 
 |     ], | 
 |     "pow": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64], | 
 |             input_generators={ | 
 |                 "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1 | 
 |             }), | 
 |     "real": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "reciprocal": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "reciprocal_no_nan": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "reduce_all": [ | 
 |         # Explicitly test all True inputs to be absolutely sure that some | 
 |         # reduction axes return True. | 
 |         *tf_test_utils.unit_test_specs_from_args( | 
 |             names_to_input_args={ | 
 |                 "all_true": [np.ones(RANK_7_SHAPE, np.bool)], | 
 |             }, | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |         *tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.bool], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     ], | 
 |     "reduce_any": [ | 
 |         # Explicitly test all False inputs to be absolutely sure that some | 
 |         # reduction axes return False. | 
 |         *tf_test_utils.unit_test_specs_from_args( | 
 |             names_to_input_args={ | 
 |                 "all_false": [np.zeros(RANK_7_SHAPE, np.bool)], | 
 |             }, | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |         *tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.bool], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     ], | 
 |     "reduce_euclidean_norm": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_logsumexp": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_max": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_mean": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_min": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_prod": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_std": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_sum": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "reduce_variance": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64], | 
 |             kwargs_to_values=REDUCE_KWARGS_TO_VALUES), | 
 |     "rint": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "round": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "rsqrt": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "scalar_mul": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=[[[], [8]]]), | 
 |     "segment_max": | 
 |         SEGMENT_UNIT_TEST_SPECS, | 
 |     "segment_mean": | 
 |         SEGMENT_UNIT_TEST_SPECS, | 
 |     "segment_min": | 
 |         SEGMENT_UNIT_TEST_SPECS, | 
 |     "segment_prod": | 
 |         SEGMENT_UNIT_TEST_SPECS, | 
 |     "segment_sum": | 
 |         SEGMENT_UNIT_TEST_SPECS, | 
 |     "sigmoid": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "sign": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "sin": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "sinh": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "sobol_sample": | 
 |         tf_test_utils.unit_test_specs_from_args( | 
 |             names_to_input_args={"simple": [4, 3]}), | 
 |     "softmax": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "softplus": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "softsign": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32]), | 
 |     "sqrt": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "square": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "squared_difference": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "subtract": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "tan": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "tanh": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "top_k": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32], | 
 |             kwargs_to_values={"k": [1, 2]}), | 
 |     "truediv": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "unsorted_segment_max": | 
 |         UNSORTED_SEGMENT_UNIT_TEST_SPECS, | 
 |     "unsorted_segment_mean": | 
 |         UNSORTED_SEGMENT_UNIT_TEST_SPECS, | 
 |     "unsorted_segment_min": | 
 |         UNSORTED_SEGMENT_UNIT_TEST_SPECS, | 
 |     "unsorted_segment_prod": | 
 |         UNSORTED_SEGMENT_UNIT_TEST_SPECS, | 
 |     "unsorted_segment_sqrt_n": | 
 |         UNSORTED_SEGMENT_UNIT_TEST_SPECS, | 
 |     "unsorted_segment_sum": | 
 |         UNSORTED_SEGMENT_UNIT_TEST_SPECS, | 
 |     "xdivy": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "xlog1py": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "xlogy": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.complex64]), | 
 |     "zero_fraction": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=UNARY_SIGNATURE_SHAPES, | 
 |             signature_dtypes=[tf.float32, tf.int32, tf.complex64]), | 
 |     "zeta": | 
 |         tf_test_utils.unit_test_specs_from_signatures( | 
 |             signature_shapes=BINARY_SIGNATURE_SHAPES, | 
 |             # The function is poorly behaved near zero, so we test this range | 
 |             # to avoid outputing all nans. | 
 |             input_generators={ | 
 |                 "uniform_3_4": | 
 |                     lambda *args: tf_utils.uniform(*args, low=3.0, high=4.0) | 
 |             }, | 
 |         ) | 
 | } | 
 |  | 
 | for function, specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items(): | 
 |   # Update using 'with_name' to avoid updating shared UnitTestSpecs. | 
 |   specs = [ | 
 |       spec.with_name(f"{function}__{spec.unit_test_name}") for spec in specs | 
 |   ] | 
 |   FUNCTIONS_TO_UNIT_TEST_SPECS[function] = specs | 
 |  | 
 |   # Validate that there are not multiple UnitTestSpecs with the same name. | 
 |   seen_unit_test_names = set() | 
 |   for spec in specs: | 
 |     if spec.unit_test_name in seen_unit_test_names: | 
 |       raise ValueError( | 
 |           f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'") | 
 |     seen_unit_test_names.add(spec.unit_test_name) | 
 |  | 
 | flags.DEFINE_list( | 
 |     "functions", None, | 
 |     f"Any of {list(FUNCTIONS_TO_UNIT_TEST_SPECS.keys())}. If more than one " | 
 |     "function is provided then len(--target_backends) must be one.") | 
 | flags.DEFINE_bool( | 
 |     "dynamic_dims", False, | 
 |     "Whether or not to compile the layer with dynamic dimensions.") | 
 | flags.DEFINE_bool( | 
 |     "test_complex", False, | 
 |     "Whether or not to test or ignore function signatures with complex types.") | 
 | flags.DEFINE_bool( | 
 |     'list_functions_with_complex_tests', False, | 
 |     'Whether or not to print out all functions with complex inputs ' | 
 |     '(and skip running the tests).') | 
 |  | 
 |  | 
 | def create_function_unit_test( | 
 |     function_name: str, | 
 |     unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function: | 
 |   """Creates a tf_function_unit_test from the provided UnitTestSpec.""" | 
 |   function = getattr(tf.math, function_name) | 
 |   signature = unit_test_spec.input_signature | 
 |  | 
 |   if tf_utils.is_complex(signature): | 
 |     function, signature = tf_utils.rewrite_complex_signature( | 
 |         function, signature) | 
 |   wrapped_function = lambda *args: function(*args, **unit_test_spec.kwargs) | 
 |  | 
 |   if FLAGS.dynamic_dims: | 
 |     signature = tf_utils.apply_function(signature, tf_utils.make_dims_dynamic) | 
 |  | 
 |   return tf_test_utils.tf_function_unit_test( | 
 |       input_signature=signature, | 
 |       input_generator=unit_test_spec.input_generator, | 
 |       input_args=unit_test_spec.input_args, | 
 |       name=unit_test_spec.unit_test_name, | 
 |       rtol=1e-5, | 
 |       atol=1e-5)(wrapped_function) | 
 |  | 
 |  | 
 | class TfMathModule(tf_test_utils.TestModule): | 
 |  | 
 |   def __init__(self): | 
 |     super().__init__() | 
 |     for function in FLAGS.functions: | 
 |       for unit_test_spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function]: | 
 |         if not FLAGS.test_complex and tf_utils.is_complex( | 
 |             unit_test_spec.input_signature): | 
 |           continue | 
 |         function_unit_test = create_function_unit_test(function, unit_test_spec) | 
 |         setattr(self, unit_test_spec.unit_test_name, function_unit_test) | 
 |  | 
 |  | 
 | def get_relative_artifacts_dir() -> str: | 
 |   if len(FLAGS.functions) > 1: | 
 |     # We only allow testing multiple functions with a single target backend | 
 |     # so that we can store the artifacts under: | 
 |     #   'artifacts_dir/multiple_functions__backend/...' | 
 |     # We specialize the 'multiple_functions' dir by backend to avoid overwriting | 
 |     # tf_input.mlir and iree_input.mlir. These are typically identical across | 
 |     # backends, but are not when the functions to compile change per-backend. | 
 |     if len(FLAGS.target_backends) != 1: | 
 |       raise flags.IllegalFlagValueError( | 
 |           "Expected len(target_backends) == 1 when len(functions) > 1, but got " | 
 |           f"the following values for target_backends: {FLAGS.target_backends}.") | 
 |     function_str = f"multiple_functions__{FLAGS.target_backends[0]}" | 
 |   else: | 
 |     function_str = FLAGS.functions[0] | 
 |   dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims" | 
 |   complex_str = "complex" if FLAGS.test_complex else "non_complex" | 
 |   return os.path.join("tf", "math", function_str, f"{dim_str}_{complex_str}") | 
 |  | 
 |  | 
 | class TfMathTest(tf_test_utils.TracedModuleTestCase): | 
 |  | 
 |   def __init__(self, *args, **kwargs): | 
 |     super().__init__(*args, **kwargs) | 
 |     self._modules = tf_test_utils.compile_tf_module( | 
 |         TfMathModule, | 
 |         exported_names=TfMathModule.get_tf_function_unit_tests(), | 
 |         relative_artifacts_dir=get_relative_artifacts_dir()) | 
 |  | 
 |  | 
 | def main(argv): | 
 |   del argv  # Unused. | 
 |   if hasattr(tf, "enable_v2_behavior"): | 
 |     tf.enable_v2_behavior() | 
 |  | 
 |   if FLAGS.list_functions_with_complex_tests: | 
 |     for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items(): | 
 |       for spec in unit_test_specs: | 
 |         if tf_utils.is_complex(spec.input_signature): | 
 |           print(f'    "{function_name}",') | 
 |     return | 
 |  | 
 |   if FLAGS.functions is None: | 
 |     raise flags.IllegalFlagValueError( | 
 |         "'--functions' must be specified if " | 
 |         "'--list_functions_with_complex_tests' isn't") | 
 |  | 
 |   TfMathTest.generate_unit_tests(TfMathModule) | 
 |   tf.test.main() | 
 |  | 
 |  | 
 | if __name__ == "__main__": | 
 |   app.run(main) |