# blob: 6ec57b4d442604c8512ab23eda037c260bcc6039 [file] [log] [blame]
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
from typing import Any, Dict, Sequence, Type, Union
from absl import app
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
# Global absl flag registry; the functions below read their configuration
# (e.g. FLAGS.functions, FLAGS.dynamic_dims, FLAGS.test_complex) from it.
FLAGS = flags.FLAGS
# As high as tf goes without breaking.
RANK_7_SHAPE = [2 for _ in range(7)]

# Each constant is a list holding a single signature; a signature lists one
# shape per positional input of the function under test.
UNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE]]
BINARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE, RANK_7_SHAPE]]
TERNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE, RANK_7_SHAPE, RANK_7_SHAPE]]
# Reused UnitTestSpecs.
# Shared by every "segment_*" entry below: a 3x4 float32 matrix followed by an
# int32 vector with one entry per row.
SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
    names_to_input_args=dict(tf_doc_example=[
        tf.constant(
            [[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32),
        np.array([0, 0, 1], dtype=np.int32),
    ]))
# Shared by every "unsorted_segment_*" entry below: the same inputs as the
# sorted segment specs plus a trailing integer argument (2).
UNSORTED_SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
    names_to_input_args=dict(tf_doc_example=[
        tf.constant(
            [[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]], dtype=np.float32),
        np.array([0, 0, 1], dtype=np.int32),
        2,
    ]))
# Keyword arguments (and the values to try for each) shared by all of the
# "reduce_*" entries below.
REDUCE_KWARGS_TO_VALUES = dict(
    axis=[None, 1],
    keepdims=[False, True],
)
# A dictionary mapping tf.math function names to lists of UnitTestSpecs.
# Each unit_test_name will have the tf.math function name prepended to it.
#
# Entries are built in one of three ways:
#   * unit_test_specs_from_signatures: from input shapes/dtypes (optionally
#     with named input generators and kwargs to sweep over),
#   * unit_test_specs_from_args: from concrete, named input values,
#   * explicit UnitTestSpec lists, for inputs the helpers above can't express
#     (e.g. a list of tensors passed as a single argument).
FUNCTIONS_TO_UNIT_TEST_SPECS = {
    "abs":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "accumulate_n": [
        tf_test_utils.UnitTestSpec(
            unit_test_name='f32',
            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
        tf_test_utils.UnitTestSpec(
            unit_test_name='i32',
            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
    ],
    "acos":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "acosh":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32],
            input_generators=[tf_utils.ndarange]),
    "add":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "add_n": [
        tf_test_utils.UnitTestSpec(
            unit_test_name='f32',
            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
        tf_test_utils.UnitTestSpec(
            unit_test_name='i32',
            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
    ],
    "angle":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "argmax":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "argmin":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "asin":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "asinh":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "atan":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "atan2":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "atanh":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "bessel_i0":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "bessel_i0e":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "bessel_i1":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "bessel_i1e":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "betainc":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=TERNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "bincount":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.int32],
            input_generators=[tf_utils.ndarange]),
    "ceil":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "confusion_matrix":
        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
            "five_classes": [tf.constant([1, 2, 4]),
                             tf.constant([2, 2, 4])]
        }),
    "conj":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "cos":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "cosh":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "count_nonzero":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64],
            input_generators=[tf_utils.ndarange]),
    "cumprod":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "cumsum":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "cumulative_logsumexp":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "digamma":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "divide":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "divide_no_nan":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "equal":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "erf":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "erfc":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "erfinv":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "exp":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "expm1":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "floor":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "floordiv":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            # Avoid integer division by 0.
            input_generators={
                "uniform_1_3":
                    lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
            }),
    "floormod":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            # Avoid integer division by 0.
            input_generators={
                "uniform_1_3":
                    lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
            }),
    "greater":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "greater_equal":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "igamma":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "igammac":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "imag":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "in_top_k": [
        tf_test_utils.UnitTestSpec(
            unit_test_name="k_3",
            input_signature=[
                tf.TensorSpec([8], tf.int32),
                tf.TensorSpec([8, 3])
            ],
            input_generator=tf_utils.ndarange,
            kwargs=dict(k=3),
        )
    ],
    "invert_permutation": [
        tf_test_utils.UnitTestSpec(
            unit_test_name="random",
            input_signature=[tf.TensorSpec([8], tf.int32)],
            input_generator=tf_utils.random_permutation,
        )
    ],
    "is_finite":
        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
            "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
        }),
    "is_inf":
        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
            "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
        }),
    "is_nan":
        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
            "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
        }),
    "is_non_decreasing":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "is_strictly_increasing":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "l2_normalize":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "lbeta":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "less":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "less_equal":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "lgamma":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "log":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "log1p":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "log_sigmoid":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "log_softmax":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "logical_and":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.bool]),
    "logical_not":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.bool]),
    "logical_or":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.bool]),
    "logical_xor":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.bool]),
    "maximum":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "minimum":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "mod":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            # Shift the generated inputs up by one.
            input_generators={
                "positive_ndarange":
                    lambda *args: tf_utils.ndarange(*args) + 1
            }),
    "multiply":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "multiply_no_nan":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "ndtri":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "negative":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "nextafter":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES),
    "not_equal":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32]),
    "polygamma":
        # NOTE(review): the spec name "nan_and_inf" looks copied from the is_*
        # specs above (the inputs contain neither); it is kept as is because
        # it is part of the generated unit test name.
        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
            "nan_and_inf": [tf.ones(16), tf.linspace(0.5, 4, 16)]
        }),
    "polyval": [
        tf_test_utils.UnitTestSpec(
            unit_test_name="three_coeffs",
            input_signature=[[tf.TensorSpec(RANK_7_SHAPE)] * 3,
                             tf.TensorSpec([])],
        )
    ],
    "pow":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64],
            # Shift the generated inputs up by one.
            input_generators={
                "positive_ndarange":
                    lambda *args: tf_utils.ndarange(*args) + 1
            }),
    "real":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "reciprocal":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "reciprocal_no_nan":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "reduce_all": [
        # Explicitly test all True inputs to be absolutely sure that some
        # reduction axes return True.
        *tf_test_utils.unit_test_specs_from_args(
            names_to_input_args={
                # `np.bool_` instead of the `np.bool` alias, which was
                # deprecated in NumPy 1.20 and removed in NumPy 1.24.
                "all_true": [np.ones(RANK_7_SHAPE, np.bool_)],
            },
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
        *tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.bool],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    ],
    "reduce_any": [
        # Explicitly test all False inputs to be absolutely sure that some
        # reduction axes return False.
        *tf_test_utils.unit_test_specs_from_args(
            names_to_input_args={
                # `np.bool_` instead of the `np.bool` alias, which was
                # deprecated in NumPy 1.20 and removed in NumPy 1.24.
                "all_false": [np.zeros(RANK_7_SHAPE, np.bool_)],
            },
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
        *tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.bool],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    ],
    "reduce_euclidean_norm":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_logsumexp":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_max":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_mean":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_min":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_prod":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_std":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_sum":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "reduce_variance":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64],
            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
    "rint":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "round":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "rsqrt":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "scalar_mul":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=[[[], [8]]]),
    "segment_max":
        SEGMENT_UNIT_TEST_SPECS,
    "segment_mean":
        SEGMENT_UNIT_TEST_SPECS,
    "segment_min":
        SEGMENT_UNIT_TEST_SPECS,
    "segment_prod":
        SEGMENT_UNIT_TEST_SPECS,
    "segment_sum":
        SEGMENT_UNIT_TEST_SPECS,
    "sigmoid":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "sign":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "sin":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "sinh":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "sobol_sample":
        tf_test_utils.unit_test_specs_from_args(
            names_to_input_args={"simple": [4, 3]}),
    "softmax":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "softplus":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "softsign":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32]),
    "sqrt":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "square":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "squared_difference":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "subtract":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "tan":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "tanh":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "top_k":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32],
            kwargs_to_values={"k": [1, 2]}),
    "truediv":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "unsorted_segment_max":
        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
    "unsorted_segment_mean":
        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
    "unsorted_segment_min":
        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
    "unsorted_segment_prod":
        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
    "unsorted_segment_sqrt_n":
        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
    "unsorted_segment_sum":
        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
    "xdivy":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "xlog1py":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "xlogy":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.complex64]),
    "zero_fraction":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
    "zeta":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            # The function is poorly behaved near zero, so we test this range
            # to avoid outputting all nans.
            input_generators={
                "uniform_3_4":
                    lambda *args: tf_utils.uniform(*args, low=3.0, high=4.0)
            },
        )
}
# Prefix every unit test name with the name of the function it tests, then
# check that the resulting names are unique within each function.
for function in list(FUNCTIONS_TO_UNIT_TEST_SPECS):
  # Update using 'with_name' to avoid updating shared UnitTestSpecs (several
  # functions map to the very same list of specs).
  specs = []
  for spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function]:
    specs.append(spec.with_name(f"{function}__{spec.unit_test_name}"))
  FUNCTIONS_TO_UNIT_TEST_SPECS[function] = specs

  # Validate that there are not multiple UnitTestSpecs with the same name.
  seen_unit_test_names = set()
  for spec in specs:
    if spec.unit_test_name not in seen_unit_test_names:
      seen_unit_test_names.add(spec.unit_test_name)
      continue
    raise ValueError(
        f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")
# Command-line flags controlling which tf.math functions are tested and how.
flags.DEFINE_list(
    name="functions",
    default=None,
    help=(
        f"Any of {list(FUNCTIONS_TO_UNIT_TEST_SPECS.keys())}. If more than one "
        "function is provided then len(--target_backends) must be one."))
flags.DEFINE_bool(
    name="dynamic_dims",
    default=False,
    help="Whether or not to compile the layer with dynamic dimensions.")
flags.DEFINE_bool(
    name="test_complex",
    default=False,
    help=("Whether or not to test or ignore function signatures with complex "
          "types."))
flags.DEFINE_bool(
    name='list_functions_with_complex_tests',
    default=False,
    help=('Whether or not to print out all functions with complex inputs '
          '(and skip running the tests).'))
def create_function_unit_test(
    function_name: str,
    unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
  """Builds the tf_function_unit_test described by `unit_test_spec`.

  Args:
    function_name: name of the `tf.math` attribute to test.
    unit_test_spec: supplies the input signature, the input generator or input
      args, the kwargs forwarded to the function and the unit test's name.

  Returns:
    The result of applying `tf_test_utils.tf_function_unit_test` to a wrapper
    that calls the `tf.math` function with the spec's kwargs.
  """
  function = getattr(tf.math, function_name)
  signature = unit_test_spec.input_signature

  # Complex signatures (and the function consuming them) are rewritten by
  # tf_utils before the test is created.
  if tf_utils.is_complex(signature):
    rewritten = tf_utils.rewrite_complex_signature(function, signature)
    function, signature = rewritten

  # Only positional inputs are exposed; the spec's kwargs are forwarded as is.
  wrapped_function = lambda *args: function(*args, **unit_test_spec.kwargs)

  if FLAGS.dynamic_dims:
    signature = tf_utils.apply_function(signature, tf_utils.make_dims_dynamic)

  as_unit_test = tf_test_utils.tf_function_unit_test(
      input_signature=signature,
      input_generator=unit_test_spec.input_generator,
      input_args=unit_test_spec.input_args,
      name=unit_test_spec.unit_test_name,
      rtol=1e-5,
      atol=1e-5)
  return as_unit_test(wrapped_function)
class TfMathModule(tf_test_utils.TestModule):
  """Module with one unit test attribute per selected UnitTestSpec.

  The functions come from `--functions`; specs with complex signatures are
  only included when `--test_complex` is set.
  """

  def __init__(self):
    super().__init__()
    for function in FLAGS.functions:
      for unit_test_spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function]:
        # Complex signatures are opt-in via --test_complex.
        if FLAGS.test_complex or not tf_utils.is_complex(
            unit_test_spec.input_signature):
          unit_test = create_function_unit_test(function, unit_test_spec)
          setattr(self, unit_test_spec.unit_test_name, unit_test)
def get_relative_artifacts_dir() -> str:
  """Returns the relative directory under which to store this run's artifacts.

  The path is 'tf/math/<function_str>/<dims>_<complex>', where the last
  component reflects the `--dynamic_dims` and `--test_complex` flags.

  Raises:
    flags.IllegalFlagValueError: if more than one function is requested and
      the number of target backends is not exactly one.
  """
  if len(FLAGS.functions) <= 1:
    function_str = FLAGS.functions[0]
  elif len(FLAGS.target_backends) == 1:
    # We only allow testing multiple functions with a single target backend
    # so that we can store the artifacts under:
    #   'artifacts_dir/multiple_functions__backend/...'
    # We specialize the 'multiple_functions' dir by backend to avoid
    # overwriting tf_input.mlir and iree_input.mlir. These are typically
    # identical across backends, but are not when the functions to compile
    # change per-backend.
    function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
  else:
    raise flags.IllegalFlagValueError(
        "Expected len(target_backends) == 1 when len(functions) > 1, but got "
        f"the following values for target_backends: {FLAGS.target_backends}.")

  if FLAGS.dynamic_dims:
    dim_str = "dynamic_dims"
  else:
    dim_str = "static_dims"

  if FLAGS.test_complex:
    complex_str = "complex"
  else:
    complex_str = "non_complex"

  leaf_dir = f"{dim_str}_{complex_str}"
  return os.path.join("tf", "math", function_str, leaf_dir)
class TfMathTest(tf_test_utils.TracedModuleTestCase):
  """Test case that compiles TfMathModule when it is constructed."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    exported_names = TfMathModule.get_tf_function_unit_tests()
    artifacts_dir = get_relative_artifacts_dir()
    self._modules = tf_test_utils.compile_tf_module(
        TfMathModule,
        exported_names=exported_names,
        relative_artifacts_dir=artifacts_dir)
def main(argv):
  """Lists the functions that have complex tests, or runs the selected tests.

  With `--list_functions_with_complex_tests` set, prints the name of every
  function that has at least one UnitTestSpec with a complex input signature
  and returns without running any test. Otherwise generates the unit tests
  for the functions named in `--functions` and runs them.

  Args:
    argv: command-line arguments left over after flag parsing; unused.

  Raises:
    flags.IllegalFlagValueError: if tests are to be run and `--functions` is
      missing, empty, or names a function without any UnitTestSpecs.
  """
  del argv  # Unused.
  if hasattr(tf, "enable_v2_behavior"):
    tf.enable_v2_behavior()

  if FLAGS.list_functions_with_complex_tests:
    for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
      # Print each function at most once, however many of its specs are
      # complex.
      if any(
          tf_utils.is_complex(spec.input_signature)
          for spec in unit_test_specs):
        print(f' "{function_name}",')
    return

  # An empty list (e.g. '--functions=') is as unusable as a missing flag.
  if not FLAGS.functions:
    raise flags.IllegalFlagValueError(
        "'--functions' must be specified if "
        "'--list_functions_with_complex_tests' isn't")

  # Fail early with a readable message instead of a KeyError raised while the
  # test module is being built.
  unknown_functions = [
      function for function in FLAGS.functions
      if function not in FUNCTIONS_TO_UNIT_TEST_SPECS
  ]
  if unknown_functions:
    raise flags.IllegalFlagValueError(
        f"Unknown values for '--functions': {unknown_functions}. Expected any "
        f"of {list(FUNCTIONS_TO_UNIT_TEST_SPECS)}.")

  TfMathTest.generate_unit_tests(TfMathModule)
  tf.test.main()


if __name__ == "__main__":
  app.run(main)