Centralize tf.math coverage and expand to all functions (#3826)
Adds tests for the functions in the `tf.math` API, with dynamic and complex variants.
To avoid adding another 1332 tests of `bazel` overhead to the CI while still generating coverage tables, I created two versions of the test suites: one that CI will run that detects any new errors, and one that creates a test for each function and backend. The tests that this change adds to the CI are:
```bazel
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__iree_llvmjit PASSED in 26.9s
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__iree_vmla PASSED in 15.3s
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__iree_vulkan PASSED in 72.1s
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__tf PASSED in 38.1s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__iree_llvmjit PASSED in 7.9s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__iree_vmla PASSED in 13.9s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__iree_vulkan PASSED in 4.5s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__tf PASSED in 38.0s
//integrations/tensorflow/e2e/math:math_tests_multiple__iree_llvmjit PASSED in 41.5s
//integrations/tensorflow/e2e/math:math_tests_multiple__iree_vmla PASSED in 19.7s
//integrations/tensorflow/e2e/math:math_tests_multiple__iree_vulkan PASSED in 106.6s
//integrations/tensorflow/e2e/math:math_tests_multiple__tf PASSED in 38.2s
//integrations/tensorflow/e2e/math:math_tests_multiple__tflite PASSED in 21.1s
```
The following tests are deleted:
- `bool_test.py`
- `complex_test.py`
- `finite_test.py`
- `logical_ops_test.py`
`math_test.py` and `math_dyn_test.py` are replaced by `quantization_test.py` and `quantization_dyn_test.py`, since a fake quant test was added to them after the PR was made.
Supporting changes:
- Allow `tf.Tensor`s to be passed to traces. (Numpy's unchangeable `float64` default is cumbersome).
- Add `set_minus` to `bazel`.
- Expand support for input generation (e.g. uniform bools).
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 49eb68a..e18b5ce 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -26,11 +26,13 @@
import copy
import glob
import inspect
+import itertools
import os
import pickle
+import re
import sys
import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Sequence, Set, Tuple, Type, Union
from absl import flags
from absl import logging
@@ -40,8 +42,8 @@
flags.DEFINE_string("reference_backend", "tf",
"The backend to treat as a source of truth.")
-flags.DEFINE_string("target_backends", None,
- "Explicit comma-delimited list of target backends.")
+flags.DEFINE_list("target_backends", None,
+ "Explicit comma-delimited list of target backends.")
flags.DEFINE_string(
"artifacts_dir", None,
"Specifies a directory to dump compilation artifacts and traces to. "
@@ -56,6 +58,7 @@
"Creates and stores a SavedModel for the tf.Module class to be tested.")
FLAGS = flags.FLAGS
NUMPY_LINEWIDTH = 120
+DEFAULT_INPUT_GENERATOR = tf_utils.uniform
def _setup_artifacts_dir(module_name: str) -> str:
@@ -76,7 +79,7 @@
def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
"""Decodes --target_backends and creates unique ids for them."""
- backend_names = FLAGS.target_backends.split(",")
+ backend_names = FLAGS.target_backends
backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
backend_ids = []
@@ -141,11 +144,6 @@
"""Records the details of a call to a CompiledModule."""
self.method = method
- for value in inputs:
- if isinstance(value, tf.Tensor):
- raise TypeError("Expected inputs to be native python types or numpy "
- f"arrays, but got {type(value)}")
-
# Deepcopy to safegard against mutation.
self.inputs = copy.deepcopy(inputs)
if outputs is not None:
@@ -574,6 +572,10 @@
# Only pass these to ModuleCall if they were specified by the user.
tolerances = {k: v for k, v in tolerances.items() if v is not None}
+ # Ensure the inputs are numpy inputs.
+ args = tf_utils.convert_to_numpy(args)
+ kwargs = tf_utils.convert_to_numpy(kwargs)
+
# Run the method and record the details of the call.
outputs = method(*args, **kwargs)
serialized_inputs, serialized_outputs = method.get_serialized_values()
@@ -602,7 +604,7 @@
# We have to use a global variable to store the compiled modules so that we can
# avoid recompilation. This is because the TestCase class resets it's entire
-# state and calls __init__ before each unittest. It also calls __init__ one
+# state and calls __init__ before each unit_test. It also calls __init__ one
# additional time before that for good measure, which means without storing the
# modules somewhere else we would have to compile each of them at least twice.
# We can't store the modules on the class itself via setUpClass because of #2900
@@ -693,13 +695,229 @@
return _global_modules
-def tf_function_unittest(input_generator: tf_utils.InputGeneratorType = None,
- input_args: Sequence[Any] = None,
- atol: float = None,
- rtol: float = None,
- name: str = None,
- **tf_function_kwargs):
- """Creates a tf.function that can be used to generate unittests.
+# We use global variables to store the configuration information for
+# tf_function_unit_tests because tensorflow.python.eager.def_function.Function
+# is not an API that we can subclass, and storing the information directly
+# that class results in it being deleted at tf.Module initialization.
+# _global_unit_test_configs is a dict mapping exported_names to dicts containing
+# a get-function for input args and the tolerance kwargs for the trace.
+global _global_unit_test_configs
+_global_unit_test_configs = dict()
+
+
+class UnitTestSpec:
+
+ def __init__(self,
+ unit_test_name: str,
+ input_signature: Sequence[tf.TensorSpec],
+ input_generator: tf_utils.InputGeneratorType = None,
+ input_args: Union[Sequence[Any], None] = None,
+ kwargs: Dict[str, Any] = None):
+ self.unit_test_name = tf_utils.remove_special_characters(unit_test_name)
+ self.input_signature = input_signature
+ self.input_args = input_args
+ self.kwargs = dict() if kwargs is None else kwargs
+ self.input_generator = input_generator
+
+ def with_name(self, new_name: str) -> "UnitTestSpec":
+ return UnitTestSpec(new_name, self.input_signature, self.input_generator,
+ self.input_args, self.kwargs)
+
+ def __str__(self):
+ return self.unit_test_name
+
+
+def _dictionary_product(dictionary: Dict[Any, Any]) -> List[Dict[Any, Any]]:
+ """Returns a named cartesian product of dictionary's values.
+
+ Converts {'a': [1, 2], 'b': [3, 4]} into
+ [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}]
+ """
+ product = [[]]
+ for values in dictionary.values():
+ # Iteratively grow the elements of the product.
+ product = [element + [value] for element in product for value in values]
+ dicts = [{k: v for k, v in zip(dictionary, element)} for element in product]
+ return dicts
+
+
+def _named_kwargs_product(
+ kwargs_to_values: Dict[str, Sequence[Any]]) -> Dict[str, Dict[str, Any]]:
+ """Splits kwargs_to_values into a Cartesian product of its elements."""
+ # Validate 'kwargs_to_values'
+ if kwargs_to_values is None:
+ kwargs_to_values = dict() # Use only default kwargs.
+ for kwarg_key, kwarg_values in kwargs_to_values.items():
+ if not isinstance(kwarg_values, Sequence):
+ raise TypeError(f"Expected kwargs_to_values[{repr(kwarg_key)}] to be a "
+ f"sequence, but got '{type(kwarg_values)}'")
+
+ # Expand across a Cartesian product.
+ kwargs_product = _dictionary_product(kwargs_to_values)
+ # {'a': 1, 'b': 3} -> "a_1__b_3"
+ dict_to_str = lambda d: "__".join([f"{k}_{v}" for k, v in d.items()])
+ return {dict_to_str(kwargs): kwargs for kwargs in kwargs_product}
+
+
+def unit_test_specs_from_signatures(
+ signature_shapes: Sequence[Sequence[Sequence[int]]],
+ signature_dtypes: Sequence[tf.DType] = [tf.float32],
+ input_generators: Union[Sequence[tf_utils.InputGeneratorType],
+ Dict[str, tf_utils.InputGeneratorType]] = [
+ DEFAULT_INPUT_GENERATOR
+ ],
+ kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
+ """Generates a Cartesian product of UnitTestSpecs from the given arguments.
+
+ Args:
+ signature_shapes:
+ A sequence (representing multiple signatures to test) of sequences
+ (representing the shapes of the args in those signatures) of ints
+ (representing the individual sizes of those shapes).
+ signature_dtypes:
+ A sequence of dtypes to test each signature with.
+ input_generators:
+ Either:
+ 1. a sequence of input generators to test each of the signature-dtype
+ pairs with
+ 2. a dictionary mapping input generator names to input generators to
+ test each of the signature-dtype pairs with. This format must be used
+ if any of the generators are lambda functions.
+ kwargs_to_values:
+ A dict mapping kwarg names to sequences of values that they can take.
+
+ Returns:
+ A list of 'UnitTestSpec's generated from the provided arguments.
+ """
+ # Validate 'signature_shapes'
+ for i, shapes in enumerate(signature_shapes):
+ if not isinstance(shapes, Sequence):
+ raise TypeError(f"Expected signature_shapes[{i}] to be a sequence, but "
+ f"got '{type(shapes)}'")
+ for j, shape in enumerate(shapes):
+ if not isinstance(shape, Sequence):
+ raise TypeError(f"Expected signature_shapes[{i}][{j}] to be a "
+ f"sequence, but got '{type(shape)}'")
+ for k, size in enumerate(shape):
+ if not isinstance(size, int):
+ raise TypeError(f"Expected signature_shapes[{i}][{j}][{k}] to be an "
+ f"int but got '{type(size)}")
+
+ # Parse 'signature_shapes'
+ names_to_shapes = dict()
+ for signature in signature_shapes:
+ # Converts [[1, 2, 3], [4, 5]] into 1x2x3_4x5.
+ signature_key = "_".join(
+ ["x".join(str(size) for size in shape) for shape in signature])
+ names_to_shapes[signature_key] = signature
+
+ # Validate 'signature_dtypes'
+ for i, dtype in enumerate(signature_dtypes):
+ if not isinstance(dtype, tf.DType):
+ raise TypeError(
+ f"Expected dtypes[{i}] to be a tf.DType, but got '{type(dtype)}'")
+
+ # Parse 'signature_dtypes'
+ # 'complex64' -> 'c64'
+ abbreviate = lambda dtype: re.sub(r"([a-z])[a-z]*([0-9]+)", r"\1\2", dtype)
+ names_to_dtypes = {
+ abbreviate(dtype.name): dtype for dtype in signature_dtypes
+ }
+
+ # Validate 'input_generators'
+ if not isinstance(input_generators, (Sequence, Dict)):
+ raise TypeError("Expected 'input_generators' to be a sequence or "
+ f"dictionary, but got '{type(input_generators)}'")
+ if isinstance(input_generators, Sequence):
+ for i, generator in enumerate(input_generators):
+ if generator.__name__ == "<lambda>":
+ raise TypeError(
+ f"'input_generators' was a sequence but input_generators[{i}] was "
+ "lambda function. 'input_generators' must be a dictionary if "
+ "lambda functions are used.")
+
+ # Parse 'input_generators'
+ if isinstance(input_generators, Sequence):
+ names_to_generators = {gen.__name__: gen for gen in input_generators}
+ else:
+ names_to_generators = input_generators
+
+ # Validate and parse 'kwargs_to_values'
+ names_to_kwargs = _named_kwargs_product(kwargs_to_values)
+
+ # Create a Cartesian product through all specifications and their names.
+ specs = [
+ names_to_shapes, names_to_dtypes, names_to_generators, names_to_kwargs
+ ]
+ # pytype: disable=attribute-error
+ key_product = itertools.product(*[list(spec.keys()) for spec in specs])
+ value_product = itertools.product(*[list(spec.values()) for spec in specs])
+ # pytype: enable=attribute-error
+
+ # Generate a UnitTestSpec for each element in the above product.
+ unit_tests = []
+ for keys, (shapes, dtype, generator, kwargs) in zip(key_product,
+ value_product):
+ unit_test_name = "__".join(key for key in keys if key)
+ input_signature = [tf.TensorSpec(shape, dtype) for shape in shapes]
+ unit_tests.append(
+ UnitTestSpec(
+ unit_test_name=unit_test_name,
+ input_signature=input_signature,
+ input_generator=generator,
+ input_args=None,
+ kwargs=kwargs,
+ ))
+ return unit_tests
+
+
+def unit_test_specs_from_args(
+ names_to_input_args: Dict[str, Sequence[Any]],
+ kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
+ """Generates a Cartesian product of UnitTestSpecs from the given arguments.
+
+ Args:
+ signature_shapes:
+ A dict mapping names for input arguments to the arguments themselves.
+ kwargs_to_values:
+ A dict mapping kwarg names to sequences of values that they can take.
+
+ Returns:
+ A list of 'UnitTestSpec's generated from the provided arguments.
+ """
+ # Validate and parse 'kwargs_to_values'
+ names_to_kwargs = _named_kwargs_product(kwargs_to_values)
+
+ # Create a Cartesian product through all specifications and their names.
+ specs = [names_to_input_args, names_to_kwargs]
+ key_product = itertools.product(*[list(spec.keys()) for spec in specs])
+ value_product = itertools.product(*[list(spec.values()) for spec in specs])
+
+ # Generate a UnitTestSpec for each element in the above product.
+ unit_tests = []
+ for keys, (input_args, kwargs) in zip(key_product, value_product):
+ unit_test_name = "__".join(key for key in keys if key)
+ input_signature = tf_utils.apply_function(
+ input_args,
+ lambda x: tf.TensorSpec.from_tensor(tf.convert_to_tensor(x)))
+ unit_tests.append(
+ UnitTestSpec(
+ unit_test_name=unit_test_name,
+ input_signature=input_signature,
+ input_generator=None,
+ input_args=input_args,
+ kwargs=kwargs,
+ ))
+ return unit_tests
+
+
+def tf_function_unit_test(input_generator: tf_utils.InputGeneratorType = None,
+ input_args: Sequence[Any] = None,
+ atol: float = None,
+ rtol: float = None,
+ name: str = None,
+ **tf_function_kwargs):
+ """Creates a tf.function that can be used to generate unit_tests.
If 'input_generator' and 'input_args' are unspecified then the function will
be tested using random uniform data.
@@ -707,7 +925,7 @@
Args:
input_generator:
an optional callable taking a shape and dtype that returns input data for
- the unittest.
+ the unit_test.
input_args:
an optional sequence of values to pass as positional args to the function.
atol:
@@ -729,7 +947,7 @@
__name__ attribute if 'name' was specified.
"""
- def _store_unittest_info(function):
+ def _store_unit_test_info(function):
# Validate arguments.
if input_generator is not None and input_args is not None:
raise ValueError(
@@ -737,58 +955,54 @@
function = tf.function(**tf_function_kwargs)(function)
- # Used to identify that the tf.function was created by this decorator.
- function.is_tf_function_unittest = True
-
- # Set function.get_trace_args.
- if input_generator is not None:
- # Use the user-specificed input_generator.
- function.get_trace_args = lambda: tf_utils.generate_inputs(
- function.input_signature, input_generator)
- elif input_args is not None:
- # Use the user-specified input_args.
- function.get_trace_args = lambda: copy.deepcopy(input_args)
- else:
- # No user data specification – default to using random uniform data.
- function.get_trace_args = lambda: tf_utils.generate_inputs(
- function.input_signature, tf_utils.uniform)
-
- # Set function.trace_kwargs.
- function.trace_kwargs = dict(atol=atol, rtol=rtol)
-
- # Set function.__name__.
+ # Set function.__name__
if name is not None:
function.__name__ = name
elif function.__name__ == "<lambda>":
raise ValueError("The 'name' kwarg must be provided when decorating a "
"lambda function.")
+ global _global_unit_test_configs
+ if function.__name__ not in _global_unit_test_configs:
+
+ if input_generator is not None:
+ # Use the user-specificed input_generator.
+ get_trace_args = lambda: tf_utils.generate_inputs(
+ function.input_signature, input_generator)
+ elif input_args is not None:
+ # Use the user-specified input_args.
+ get_trace_args = lambda: copy.deepcopy(input_args)
+ else:
+ # No user data specification – default to using random uniform data.
+ get_trace_args = lambda: tf_utils.generate_inputs(
+ function.input_signature, DEFAULT_INPUT_GENERATOR)
+
+ _global_unit_test_configs[function.__name__] = dict(
+ get_trace_args=get_trace_args,
+ trace_kwargs=dict(atol=atol, rtol=rtol))
+
return function
- return _store_unittest_info
+ return _store_unit_test_info
class TestModule(tf.Module):
- """Thin wrapper of tf.Module with helper methods for tf_function_unittests."""
+ """Thin tf.Module wrapper with helper methods for tf_function_unit_tests."""
@classmethod
- def get_tf_function_unittests(cls):
- """Get all tf_function_unittest-created tf.functions on the class."""
- tf_function_unittests = []
- for name in dir(cls):
- value = getattr(cls, name)
- if hasattr(value, 'is_tf_function_unittest'):
- tf_function_unittests.append(value)
+ def get_tf_function_unit_tests(cls):
+ """Get all tf_function_unit_test-created tf.functions on the class."""
+ # Initialize the module to ensure that _global_unit_test_configs has the
+ # info for all of the unit_tests. (Only doing this if
+ # _global_unit_test_configs is empty wouldn't address the case where some
+ # unit_tests are defined on the class and some are generated by __init__).
+ cls()
- if not len(tf_function_unittests):
+ tf_function_unit_tests = list(_global_unit_test_configs.keys())
+ if not len(tf_function_unit_tests):
raise ValueError(
- "'get_tf_function_unittests' was called but no unittests were found.")
- return tf_function_unittests
-
- @classmethod
- def get_exported_names(cls):
- """Get the names of all tf_function_unittest-created tf.functions"""
- return [function.__name__ for function in cls.get_tf_function_unittests()]
+ "'get_tf_function_unit_tests' was called but no tests were found.")
+ return tf_function_unit_tests
class TracedModuleTestCase(tf.test.TestCase):
@@ -802,35 +1016,38 @@
module.reinitialize()
@classmethod
- def generate_unittests(cls, module_class: Type[TestModule]):
- """Generates unittests for each 'tf_function_unittest' on 'module_class'."""
- for function in module_class.get_tf_function_unittests():
- # We have to pass the closure argument 'funcion' to 'trace' via a kwarg
- # instead of using it directly in the body because 'function' is
- # overwritten in each iteration of this loop, and python will only use
- # the most recent version of 'function'. If we didn't do this, then we
- # would only test the last function in this loop. The same is true for
- # passing 'trace' to 'unittest'.
+ def generate_unit_tests(cls, module_class: Type[TestModule]):
+ """Generates tests for each 'tf_function_unit_test' on 'module_class'."""
+ for function_name in module_class.get_tf_function_unit_tests():
+ # We have to pass the closure arguments 'function_name', 'get_args' and
+ # 'kwargs' to 'trace' via a kwarg instead of using it directly in the body
+ # because 'function_name' and 'unit_test_config' are overwritten in each
+ # iteration of this loop, and python will only use the most recent version
+ # of each. If we didn't do this, then we would only test the last function
+ # in this loop. The same is true for passing 'trace' to 'unit_test'.
+ unit_test_config = _global_unit_test_configs[function_name]
# Runs the inputs through a (traced) module.
- def trace(module, function=function):
- getattr(module, function.__name__)(*function.get_trace_args(),
- **function.trace_kwargs)
+ def trace(module,
+ function_name=function_name,
+ get_args=unit_test_config["get_trace_args"],
+ kwargs=unit_test_config["trace_kwargs"]):
+ getattr(module, function_name)(*get_args(), **kwargs)
# Give the trace the name of the tf.function that it is testing.
- trace.__name__ = function.__name__
+ trace.__name__ = function_name
# Runs 'trace' on modules compiled to each backend and compares them.
- def unittest(self, trace=trace):
+ def unit_test(self, trace=trace):
self.compare_backends(trace, self._modules)
- # Make 'unittest' a function on the TracedModuleTestCase, which tells
+ # Make 'unit_test' a function on the TracedModuleTestCase, which tells
# the test runner to run it.
- unittest.__name__ = f"test_{function.__name__}"
- if hasattr(cls, unittest.__name__):
- raise ValueError("Tried to generate multiple instances of the unittest "
- f"'{unittest.__name__}'.")
- setattr(cls, unittest.__name__, unittest)
+ unit_test.__name__ = f"test_{function_name}"
+ if hasattr(cls, unit_test.__name__):
+ raise ValueError("Tried to generate multiple instances of the "
+ f"unit_test '{unit_test.__name__}'.")
+ setattr(cls, unit_test.__name__, unit_test)
def compare_backends(self, trace_function: Callable[[TracedModule], None],
modules: Modules) -> None:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 63cfa2f..6069688 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -54,18 +54,18 @@
class TfFunctionUnittestModule(tf_test_utils.TestModule):
- @tf_test_utils.tf_function_unittest(input_signature=[])
+ @tf_test_utils.tf_function_unit_test(input_signature=[])
def no_args(self):
return np.array([True], dtype=np.bool)
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([4]),
tf.TensorSpec([4]),
])
def default_uniform_inputs(self, a, b):
return a + b
- @tf_test_utils.tf_function_unittest(
+ @tf_test_utils.tf_function_unit_test(
input_signature=[
tf.TensorSpec([4]),
tf.TensorSpec([4]),
@@ -75,7 +75,7 @@
def custom_input_generator(self, a, b):
return a + b
- @tf_test_utils.tf_function_unittest(
+ @tf_test_utils.tf_function_unit_test(
input_signature=[
tf.TensorSpec([4]),
tf.TensorSpec([4]),
@@ -89,7 +89,7 @@
return a + b
# This test will fail if atol is not successfully set.
- @tf_test_utils.tf_function_unittest(
+ @tf_test_utils.tf_function_unit_test(
input_signature=[
tf.TensorSpec([128, 3072], tf.float32),
tf.TensorSpec([3072, 256], tf.float32),
@@ -262,7 +262,7 @@
self._modules = tf_test_utils.compile_tf_module(
TfFunctionUnittestModule)
- TfFunctionUnittestTest.generate_unittests(TfFunctionUnittestModule)
+ TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnittestModule)
test_case = TfFunctionUnittestTest()
self.assertTrue(hasattr(test_case, 'test_no_args'))
self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs'))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 591df43..aa94744 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -45,9 +45,15 @@
dtype: Union[tf.DType, np.dtype] = np.float32,
low: float = -1.0,
high: float = 1.0) -> np.ndarray:
- """np.random.uniform with simplified API and dtype control."""
+ """np.random.uniform with simplified API and dtype and bool support."""
dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype
- return np.random.uniform(size=shape, low=low, high=high).astype(dtype)
+ if dtype == np.bool:
+ return np.random.choice(2, shape).astype(np.bool)
+ else:
+ values = np.random.uniform(size=shape, low=low, high=high)
+ if np.issubdtype(dtype, np.integer):
+ values = np.round(values)
+ return values.astype(dtype)
def ndarange(shape: Sequence[int],
@@ -57,22 +63,35 @@
return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
+def random_permutation(
+ shape: Sequence[int],
+ dtype: Union[tf.DType, np.dtype] = np.float32) -> np.ndarray:
+ """Returns a random permutation of [0, np.prod(shape))."""
+ values = ndarange(shape, dtype)
+ np.random.shuffle(values)
+ return values
+
+
+def apply_function(values, function):
+ """Applies 'function' recursively to the inputted values."""
+ if isinstance(values, list):
+ return [apply_function(v, function) for v in values]
+ elif isinstance(values, tuple):
+ return tuple(apply_function(v, function) for v in values)
+ elif isinstance(values, dict):
+ return {k: apply_function(v, function) for k, v in values.items()}
+ else:
+ return function(values)
+
+
def generate_inputs(
spec, # Union[Sequence[tf.TensorSpec], tf.TensorSpec]
input_generator: InputGeneratorType,
) -> Sequence[np.ndarray]:
"""Generates inputs for a given input signature using 'input_generator'."""
- if isinstance(spec, Sequence):
- # 'spec' is a sequence of 'tf.TensorSpec'.
- # Recursively generate inputs.
- return [generate_inputs(s, input_generator) for s in spec]
- elif isinstance(spec, tf.TensorSpec):
- # Handle dynamic shapes (e.g. batches) by substituting an int for None.
- shape = [size if size is not None else 2 for size in spec.shape]
- return input_generator(shape, spec.dtype)
- else:
- raise TypeError("Expected 'spec' to be a sequence of 'tf.TensorSpec' or "
- f"'tf.TensorSpec', but got '{type(spec)}'")
+ make_static = lambda shape: [dim if dim is not None else 2 for dim in shape]
+ generate = lambda spec: input_generator(make_static(spec.shape), spec.dtype)
+ return apply_function(spec, generate)
def to_mlir_type(dtype: np.dtype) -> str:
@@ -116,6 +135,63 @@
return result
+def remove_special_characters(value: str) -> str:
+ """Replaces special characters with '_' while keeping instances of '__'."""
+ normalized_parts = []
+ for part in value.split("__"):
+ part = re.sub(r"[^a-zA-Z0-9_]", "_", part) # Remove special characters.
+ part = re.sub(r"_+", "_", part) # Remove duplicate "_".
+ part = part.strip("_") # Don't end or start in "_".
+ normalized_parts.append(part)
+ return "__".join(normalized_parts)
+
+
+def is_complex(tensors: Union[Sequence[tf.TensorSpec], tf.TensorSpec]) -> bool:
+ if isinstance(tensors, Sequence):
+ for tensor in tensors:
+ if is_complex(tensor):
+ return True
+ return False
+ else:
+ return tensors.dtype.is_complex # pytype: disable=attribute-error
+
+
+def _complex_wrapper(function):
+ """Wraps a tf.function to allow compiling functions of complex numbers."""
+
+ def decorator(*args, **kwargs):
+ inputs = []
+ for real, imag in zip(args[::2], args[1::2]):
+ inputs.append(tf.complex(real, imag))
+ result = function(*inputs, **kwargs)
+ # TODO(meadowlark): Support returning complex numbers.
+ return tf.math.real(result) + tf.math.imag(result)
+
+ return decorator
+
+
+def rewrite_complex_signature(function, signature: Sequence[tf.TensorSpec]):
+ """Compatibility layer for testing complex numbers."""
+ if not all([spec.dtype.is_complex for spec in signature]):
+ raise NotImplementedError("Signatures with mixed complex and non-complex "
+ "tensor specs are not supported.")
+
+ # Rewrite the signature, replacing all complex tensors with pairs of real
+ # and imaginary tensors.
+ real_imag_signature = []
+ for spec in signature:
+ new_dtype = tf.float32 if spec.dtype.size == 8 else tf.float64
+ real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype))
+ real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype))
+
+ return _complex_wrapper(function), real_imag_signature
+
+
+def make_dims_dynamic(spec: tf.TensorSpec) -> tf.TensorSpec:
+ """Gives a tf.TensorSpec dynamic dims."""
+ return tf.TensorSpec([None] * len(spec.shape), spec.dtype)
+
+
def _setup_mlir_crash_reproducer(
function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
artifacts_dir: str,
@@ -557,28 +633,27 @@
return result
+def _convert_to_numpy(tensor: Any) -> Any:
+ if not isinstance(tensor, tf.Tensor):
+ return tensor
+ return _normalize_numpy(tensor.numpy())
+
+
+def convert_to_numpy(values: Any) -> Any:
+ """Converts any tf.Tensor in values to numpy."""
+ return apply_function(values, _convert_to_numpy)
+
+
class _TfFunctionWrapper(_FunctionWrapper):
"""Wraps a TF function, normalizing it to numpy."""
def __init__(self, f: Callable[..., Any]):
self._f = f
- def _convert_to_numpy(self, tensor: Any) -> Any:
- if not isinstance(tensor, tf.Tensor):
- return tensor
- return _normalize_numpy(tensor.numpy())
-
def __call__(self, *args, **kwargs):
# TensorFlow will auto-convert all inbound args.
results = self._f(*args, **kwargs)
- # Then unmarshal them to numpy in the same way that the other backends do.
- # Handle single result (technically ambiguous with return of a tuple,
- # which is sad).
- if not isinstance(results, tuple):
- results = (results,)
- return tf.nest.map_structure(self._convert_to_numpy,
- *results,
- check_types=False)
+ return convert_to_numpy(results)
def _convert_inputs_to_tensors(function):
@@ -738,9 +813,22 @@
tflite_modules = []
methods, method_names, instance = _get_concrete_functions(
module_class, exported_names)
- for method in methods:
- converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
- tflite_modules.append(converter.convert())
+ failed_methods = []
+ for method, method_name in zip(methods, method_names):
+ logging.info("Attempting to convert '%s' to tflite...", method_name)
+ try:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+ logging.info("...converted '%s' to tflite.", method_name)
+ tflite_modules.append(converter.convert())
+ except Exception as e:
+ logging.error("Failed to convert '%s' to tflite.", method_name)
+ logging.error("TFLite excpetion: %s", e)
+ failed_methods.append(method_name)
+
+ if failed_methods:
+ raise RuntimeError(
+ f"Failed to convert the following methods to tflite: {failed_methods}")
+
# Keep variables alive until TFLite has done the conversion; ConcreteFunctions
# themselves only keep weak references to variables.
del instance
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 9dfc5d4..2cdaaf2 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -59,7 +59,6 @@
# keep sorted
TFLITE_FAILING = [
"broadcasting_test.py",
- "complex_test.py",
"concat_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
@@ -67,12 +66,11 @@
"einsum_static_test.py",
"einsum_vector_test.py",
"fft_test.py",
- "finite_test.py",
"gather_test.py",
"image_resize_test.py",
"mandelbrot_test.py",
- "math_dyn_test.py",
"matrix_ops_dynamic_test.py",
+ "quantization_dyn_test.py",
"resource_ops_test.py",
"ring_buffer_test.py",
"scatter_update_test.py",
@@ -105,10 +103,9 @@
"fft_test.py", # TODO(natashaknk): Get this working after kernel is in.
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"linspace_test.py", # TODO(https://github.com/google/iree/issues/1521)
- "logical_ops_test.py",
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
- "math_dyn_test.py",
"matrix_ops_dynamic_test.py",
+ "quantization_dyn_test.py",
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
@@ -117,7 +114,6 @@
# keep sorted
VULKAN_FAILING = [
- "bool_test.py",
"broadcast_to_test.py",
"broadcasting_test.py",
"conv_transpose_test.py",
@@ -129,10 +125,9 @@
"fft_test.py", # TODO(natashaknk): Get this working after kernel is in.
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"linspace_test.py", # TODO(https://github.com/google/iree/issues/1521)
- "logical_ops_test.py",
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
- "math_dyn_test.py",
"matrix_ops_dynamic_test.py",
+ "quantization_dyn_test.py",
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 4b457b5..e806496 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -52,17 +52,17 @@
preferred.
```shell
-# Run math_test on all backends.
-bazel run //integrations/tensorflow/e2e:math_test_manual
+# Run conv_test on all backends.
+bazel run //integrations/tensorflow/e2e:conv_test_manual
-# Run math_test comparing TensorFlow to itself (e.g. to debug randomization).
-bazel run //integrations/tensorflow/e2e:math_test_manual -- --target_backends=tf
+# Run conv_test comparing TensorFlow to itself (e.g. to debug randomization).
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- --target_backends=tf
-# Run math_test comparing the VMLA backend and TensorFlow.
-bazel run //integrations/tensorflow/e2e:math_test_manual -- --target_backends=iree_vmla
+# Run conv_test comparing the VMLA backend and TensorFlow.
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- --target_backends=iree_vmla
-# Run math_test comparing the VMLA backend to itself multiple times.
-bazel run //integrations/tensorflow/e2e:math_test_manual -- \
+# Run conv_test comparing the VMLA backend to itself multiple times.
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- \
--reference_backend=iree_vmla --target_backends=iree_vmla,iree_vmla
```
@@ -72,10 +72,10 @@
## Writing Tests
-There are two ways to write tests – via `tf_test_utils.tf_function_unittest` and
+There are two ways to write tests – via `tf_test_utils.tf_function_unit_test` and
via test methods on a child of `tf_test_utils.TracedModuleTestCase`.
-### Via `tf_test_utils.tf_function_unittest`
+### Via `tf_test_utils.tf_function_unit_test`
This is preferred in the cases where
@@ -86,7 +86,7 @@
Tests are specified by writing modules that inherit from
`tf_test_utils.TestModule` (which is a thin wrapper around `tf.Module`) with
-methods decorated with `@tf_test_utils.tf_function_unittest` (with is a thin
+methods decorated with `@tf_test_utils.tf_function_unit_test` (with is a thin
wrapper around `tf.function`).
#### Basic example
@@ -101,14 +101,14 @@
# function. The 'input_signature' is required. If no other arguments are
# specified then uniform random data is generated from the input signature
# to numerically test the function.
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
def conv2d_1451x1111_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
@@ -130,7 +130,7 @@
```
Finally, in the `main` function, you need to call
-`.generate_unittests(module_class)` on your `TestCase` to actually generate
+`.generate_unit_tests(module_class)` on your `TestCase` to actually generate
the unittests that we specified:
```python
@@ -138,12 +138,12 @@
del argv # Unused
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
- # Generates unittests for all @tf_test_utils.tf_function_unittest decorated
+ # Generates unittests for all @tf_test_utils.tf_function_unit_test decorated
# functions on the module class.
# Note: if you are automatically generating functions to test they need to be
# specified via a `classmethod` prior to this call _as well_ as via `__init__`
# to properly handle stateful `tf.function`s.
- ConvTest.generate_unittests(Conv2dModule)
+ ConvTest.generate_unit_tests(Conv2dModule)
tf.test.main()
@@ -154,9 +154,9 @@
This generates two unittests: `test_conv2d_1451x1111_valid` and
`test_conv2d_2451x1111_valid`.
-#### Configuring `@tf_test_utils.tf_function_unittest`
+#### Configuring `@tf_test_utils.tf_function_unit_test`
-By default `@tf_test_utils.tf_function_unittest` uses uniform random input data
+By default `@tf_test_utils.tf_function_unit_test` uses uniform random input data
to numerically test the function, but you can specify an `input_generator` or
`input_args` to test data-specific behaviors:
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
deleted file mode 100644
index df09ecb..0000000
--- a/integrations/tensorflow/e2e/bool_test.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class BooleanModule(tf_test_utils.TestModule):
-
- @tf_test_utils.tf_function_unittest(input_signature=[])
- def constant(self):
- return np.array([True, False, True], dtype=np.bool)
-
- @tf_test_utils.tf_function_unittest(
- input_signature=[tf.TensorSpec([4], tf.float32)],
- input_args=[np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)])
- def greater_than(self, x):
- return x > 1.0
-
- @tf_test_utils.tf_function_unittest(
- input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ],
- input_args=[
- np.array([True, True, False, False], dtype=np.bool),
- np.array([True, False, False, True], dtype=np.bool)
- ],
- )
- def logical_and(self, x, y):
- return tf.math.logical_and(x, y)
-
-
-class BooleanTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(BooleanModule)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- BooleanTest.generate_unittests(BooleanModule)
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
deleted file mode 100644
index e60f532..0000000
--- a/integrations/tensorflow/e2e/complex_test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class ComplexModule(tf.Module):
-
- def __init__(self):
- pass
-
- @tf.function(input_signature=[
- tf.TensorSpec([2], tf.float32),
- tf.TensorSpec([2], tf.float32)
- ])
- def complex_exp(self, real, imag):
- tensor = tf.complex(real, imag)
- exp = tf.exp(tensor)
- return tf.math.real(exp)
-
-
-class ComplexTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(ComplexModule)
-
- def test_complex(self):
-
- def complex_exp(module):
- real = np.array([2., 3.], dtype=np.float32)
- imag = np.array([-1., 0.4], dtype=np.float32)
- module.complex_exp(real, imag)
-
- self.compare_backends(complex_exp, self._modules)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 11ca8d6..ca0076e 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -22,14 +22,14 @@
class Conv2dModule(tf_test_utils.TestModule):
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
def conv2d_1451x1111_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 2, 1, 1], tf.float32),
])
@@ -40,7 +40,7 @@
dilations=[1, 2, 1, 1],
name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
@@ -51,70 +51,70 @@
dilations=[1, 2, 1, 1],
name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
def conv2d_2451x1111_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 3, 1, 1], tf.float32),
])
def conv2d_1451x2311_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 3, 1, 1], tf.float32),
])
def conv2d_1451x2311_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 3, 1, 1], tf.float32),
])
def conv2d_2451x2311_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([3, 2, 2, 1], tf.float32),
])
def conv2d_1452x3221_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 2], tf.float32),
])
def conv2d_1451x1112_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([1, 1, 2, 2], tf.float32),
])
def conv2d_1452x1122_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_1452x2223_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_1452x2223_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
@@ -133,7 +133,7 @@
del argv # Unused
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
- ConvTest.generate_unittests(Conv2dModule)
+ ConvTest.generate_unit_tests(Conv2dModule)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
deleted file mode 100644
index 761cae4..0000000
--- a/integrations/tensorflow/e2e/finite_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class FiniteModule(tf.Module):
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def finite(self, x):
- return tf.math.is_finite(x)
-
-
-class FiniteTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(FiniteModule)
-
- def test_finite(self):
-
- def finite(module):
- module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
-
- self.compare_backends(finite, self._modules)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 58be0f6..661cc30 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -25,6 +25,13 @@
driver = ""
return driver
+def set_difference(include, exclude):
+ return [
+ value
+ for value in include
+ if value not in exclude
+ ]
+
def iree_e2e_test_suite(
name,
backends_to_srcs,
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
index 148c808..75dbeda 100644
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -287,7 +287,7 @@
"LSTM",
"MaxPool1D",
"MaxPool3D",
- "SeparableConv1D",
+ "SeparableConv1D", # Failing on Kokoro.
"SimpleRNN",
],
"target_backends": "tflite",
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index 0ab5f32..4755450 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -567,8 +567,8 @@
return tf.keras.Model(inputs, outputs)
-def create_tf_function_unittest(config: Config, exported_name: str,
- model: tf.keras.Model) -> tf.function:
+def create_tf_function_unit_test(config: Config, exported_name: str,
+ model: tf.keras.Model) -> tf.function:
"""Wrap the model's __call__ function in a tf.function for testing."""
input_shapes = config.shapes
if FLAGS.dynamic_batch:
@@ -579,19 +579,19 @@
input_signature = [input_signature]
call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
- return tf_test_utils.tf_function_unittest(input_signature=input_signature,
- name=exported_name)(call)
+ return tf_test_utils.tf_function_unit_test(input_signature=input_signature,
+ name=exported_name)(call)
class KerasLayersModule(tf_test_utils.TestModule):
@classmethod
def configure_class(cls):
- """Configure each tf_function_unittest and define it on the cls."""
+ """Configure each tf_function_unit_test and define it on the cls."""
for i, (exported_name, config) in enumerate(get_configs().items()):
model = create_wrapped_keras_layer(config)
setattr(cls, exported_name,
- create_tf_function_unittest(config, exported_name, model))
+ create_tf_function_unit_test(config, exported_name, model))
def __init__(self):
super().__init__()
@@ -600,7 +600,7 @@
model = create_wrapped_keras_layer(config)
self.models.append(model)
setattr(self, exported_name,
- create_tf_function_unittest(config, exported_name, model))
+ create_tf_function_unit_test(config, exported_name, model))
class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
@@ -609,7 +609,7 @@
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
KerasLayersModule,
- exported_names=KerasLayersModule.get_exported_names())
+ exported_names=KerasLayersModule.get_tf_function_unit_tests())
def main(argv):
@@ -638,7 +638,7 @@
# to test to the KerasLayersModule, and then generate unittests for each of
# them.
KerasLayersModule.configure_class()
- KerasLayersTest.generate_unittests(KerasLayersModule)
+ KerasLayersTest.generate_unit_tests(KerasLayersModule)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
deleted file mode 100644
index cb25db2..0000000
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module that specifically handle logical ops."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class LogicalOpsModule(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ])
- def logical_and(self, x, y):
- return tf.math.logical_and(x, y)
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ])
- def logical_or(self, x, y):
- return tf.math.logical_or(x, y)
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ])
- def logical_xor(self, x, y):
- return tf.math.logical_xor(x, y)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.bool)])
- def logical_not(self, x):
- return tf.math.logical_not(x)
-
-
-class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
-
- # yapf: disable
- def test_logical_and(self):
- def logical_and(module):
- module.logical_and(
- np.array([1, 1, 0, 0], dtype=np.bool),
- np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_and, self._modules)
-
- def test_logical_or(self):
- def logical_or(module):
- module.logical_or(
- np.array([1, 1, 0, 0], dtype=np.bool),
- np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_or, self._modules)
-
- def test_logical_xor(self):
- def logical_xor(module):
- module.logical_xor(
- np.array([1, 1, 0, 0], dtype=np.bool),
- np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_xor, self._modules)
-
- def test_logical_not(self):
- def logical_not(module):
- module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_not, self._modules)
- # yapf: enable
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/math/BUILD b/integrations/tensorflow/e2e/math/BUILD
new file mode 100644
index 0000000..242a9c9
--- /dev/null
+++ b/integrations/tensorflow/e2e/math/BUILD
@@ -0,0 +1,971 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. A coverage table generated from this file can be viewed here:
+# https://google.github.io/iree/tf-e2e-coverage
+# Updates made to test suite names should also be reflected here:
+# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+ "iree_py_binary",
+ "iree_py_test",
+)
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
+ "set_difference",
+)
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+[
+ iree_py_binary(
+ name = src.replace(".py", "_manual"),
+ srcs = [src],
+ main = src,
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for src in glob(
+ ["*_test.py"],
+ exclude = ["keyword_spotting_streaming_test.py"],
+ )
+]
+
+# These functions were selected using all of the funcions in the tf.math docs:
+# https://www.tensorflow.org/api_docs/python/tf/math
+TF_MATH_FUNCTIONS = [
+ "abs",
+ "accumulate_n",
+ "acos",
+ "acosh",
+ "add",
+ "add_n",
+ "angle",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "ceil",
+ "confusion_matrix",
+ "cos",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumsum",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide",
+ "divide_no_nan",
+ "equal",
+ "erf",
+ "erfc",
+ "erfinv",
+ "exp",
+ "expm1",
+ "floor",
+ "floordiv",
+ "floormod",
+ "greater",
+ "greater_equal",
+ "igamma",
+ "igammac",
+ "imag",
+ "in_top_k",
+ "invert_permutation",
+ "is_finite",
+ "is_inf",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "lbeta",
+ "less",
+ "less_equal",
+ "lgamma",
+ "log",
+ "log1p",
+ "log_sigmoid",
+ "log_softmax",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "maximum",
+ "minimum",
+ "mod",
+ "multiply",
+ "multiply_no_nan",
+ "ndtri",
+ "negative",
+ "nextafter",
+ "not_equal",
+ "polygamma",
+ "polyval",
+ "pow",
+ "real",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_prod",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "rint",
+ "round",
+ "rsqrt",
+ "scalar_mul",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sigmoid",
+ "sign",
+ "sin",
+ "sinh",
+ "sobol_sample",
+ "softmax",
+ "softplus",
+ "softsign",
+ "sqrt",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "tanh",
+ # "top_k", # TODO(meadowlark): Enable once list outputs are supported.
+ "truediv",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zero_fraction",
+ "zeta",
+]
+
+# keep sorted
+TFLITE_FAILING = [
+ "abs", # Failing for integer inputs.
+ "acos",
+ "acosh",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "conj",
+ "cosh",
+ "cumprod",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_finite",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "l2_normalize",
+ "lbeta",
+ "lgamma",
+ "log1p",
+ "log_sigmoid",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "polyval",
+ "pow", # Failing for integer inputs.
+ "reduce_all",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_mean",
+ "reduce_std",
+ "reduce_variance",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "sign",
+ "sinh",
+ "sobol_sample",
+ "softmax",
+ "softplus",
+ "softsign",
+ "tan",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# Note: The VMLA_FAILING_DYNAMIC specification extends this list. Newly-passing
+# functions removed from this list may need to be added to VMLA_FAILING_DYNAMIC.
+# keep sorted
+VMLA_FAILING = [
+ "acosh",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs because iree doesn't output 'f64'.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "pow", # Failing for integer inputs.
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_prod",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sign",
+ "sobol_sample",
+ "softsign",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# Note: The LLVM_FAILING_DYNAMIC specification extends this list. Newly-passing
+# functions removed from this list may need to be added to LLVM_FAILING_DYNAMIC.
+# keep sorted
+LLVM_FAILING = [
+ "acos",
+ "acosh",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs because iree doesn't output 'f64'.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "l2_normalize",
+ "logical_or",
+ "logical_xor",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "pow",
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_prod",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sign",
+ "sobol_sample",
+ "softsign",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# Note: The VULKAN_FAILING_DYNAMIC specification extends this list.
+# Newly-passing functions removed from this list may need to be added to
+# VULKAN_FAILING_DYNAMIC.
+# keep sorted
+VULKAN_FAILING = [
+ "acos",
+ "acosh",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumsum",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs because iree doesn't output 'f64'.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "l2_normalize",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "pow",
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_prod",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sign",
+ "sobol_sample",
+ "softsign",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# ---- INDIVIDUAL STATIC TESTS ----------------------------------------------- #
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+ name = "math_tests",
+ srcs = ["math_test.py"],
+ failing_configurations = [
+ {
+ # Failing on TFLite.
+ "functions": TFLITE_FAILING,
+ "target_backends": "tflite",
+ },
+ {
+ # Failing on vmla.
+ "functions": VMLA_FAILING,
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm.
+ "functions": LLVM_FAILING,
+ "target_backends": "iree_llvmjit",
+ },
+ {
+ # Failing on vulkan.
+ "functions": VULKAN_FAILING,
+ "target_backends": "iree_vulkan",
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "functions": TF_MATH_FUNCTIONS,
+ "dynamic_dims": False,
+ "test_complex": False,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "math_test.py",
+ tags = [
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# ---- MULTIPLE STATIC TESTS ------------------------------------------------ #
+
+# These tests compile all functions in tf.math at once for testing so that
+# we can run them on the CI with 5 additional targets instead of 640.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_MULTIPLE = VMLA_FAILING + ["multiply"]
+
+# TODO(#3810) Including 'square' causes error: Recieved signal 11.
+LLVM_FAILING_MULTIPLE = LLVM_FAILING + ["square"]
+
+# TODO(#3810) Including 'square' causes error: Recieved signal 11.
+VULKAN_FAILING_MULTIPLE = VULKAN_FAILING + ["square"]
+
+[
+ iree_py_test(
+ name = "math_tests_multiple__{}".format(target_backend),
+ srcs = ["math_test.py"],
+ args = [
+ "--reference_backend=tf",
+ "--target_backends={}".format(target_backend),
+ "--functions={}".format(",".join(functions)),
+ "--dynamic_dims=False",
+ ],
+ main = "math_test.py",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for target_backend, functions in dict(
+ iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_MULTIPLE),
+ iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_MULTIPLE),
+ iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_MULTIPLE),
+ tf = TF_MATH_FUNCTIONS,
+ tflite = set_difference(TF_MATH_FUNCTIONS, TFLITE_FAILING),
+ ).items()
+]
+
+# ---- INDIVIDUAL DYNAMIC TESTS ---------------------------------------------- #
+
+# keep sorted
+VMLA_FAILING_DYNAMIC = VMLA_FAILING + [
+ "angle",
+ "cumsum",
+ "divide_no_nan",
+ "equal",
+ "floormod",
+ "imag",
+ "lbeta",
+ "lgamma",
+ "log_sigmoid",
+ "log1p",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "mod",
+ "floordiv",
+ "multiply_no_nan",
+ "round",
+ "not_equal",
+ "reciprocal_no_nan",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_min",
+ "reduce_sum",
+ "reduce_mean",
+ "reduce_std",
+ "reduce_variance",
+ "softplus",
+ "zero_fraction",
+]
+
+# keep sorted
+LLVM_FAILING_DYNAMIC = LLVM_FAILING + [
+ "accumulate_n",
+ "add",
+ "add_n",
+ "angle",
+ "cumsum",
+ "divide",
+ "divide_no_nan",
+ "equal",
+ "floordiv",
+ "floormod",
+ "greater",
+ "greater_equal",
+ "is_finite",
+ "is_inf",
+ "lbeta",
+ "less",
+ "less_equal",
+ "lgamma",
+ "log_sigmoid",
+ "log_softmax",
+ "log1p",
+ "logical_and",
+ "logical_not",
+ "maximum",
+ "minimum",
+ "mod",
+ "multiply",
+ "multiply_no_nan",
+ "not_equal",
+ "polyval",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_mean",
+ "scalar_mul",
+ "sigmoid",
+ "sinh",
+ "softmax",
+ "softplus",
+ "square",
+ "squared_difference",
+ "subtract",
+ "round",
+ "tan",
+ "truediv",
+ "zero_fraction",
+]
+
+# keep sorted
+VULKAN_FAILING_DYNAMIC = VULKAN_FAILING + [
+ "abs",
+ "accumulate_n",
+ "add",
+ "add_n",
+ "angle",
+ "ceil",
+ "cos",
+ "divide",
+ "divide_no_nan",
+ "equal",
+ "exp",
+ "floor",
+ "floordiv",
+ "floormod",
+ "greater",
+ "greater_equal",
+ "imag",
+ "is_finite",
+ "is_inf",
+ "lbeta",
+ "less",
+ "round",
+ "less_equal",
+ "lgamma",
+ "log",
+ "log_sigmoid",
+ "log_softmax",
+ "log1p",
+ "maximum",
+ "minimum",
+ "mod",
+ "multiply",
+ "multiply_no_nan",
+ "negative",
+ "not_equal",
+ "polyval",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_sum",
+ "rsqrt",
+ "scalar_mul",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "softmax",
+ "softplus",
+ "sqrt",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "tanh",
+ "truediv",
+ "zero_fraction",
+]
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+ name = "math_dynamic_dims_tests",
+ srcs = ["math_test.py"],
+ failing_configurations = [
+ {
+ # TFLite does not support dynamic shapes.
+ "target_backends": "tflite",
+ },
+ {
+ # Failing on vmla.
+ "functions": VMLA_FAILING_DYNAMIC,
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm.
+ "functions": LLVM_FAILING_DYNAMIC,
+ "target_backends": "iree_llvmjit",
+ },
+ {
+ # Failing on vulkan.
+ "functions": VULKAN_FAILING_DYNAMIC,
+ "target_backends": "iree_vulkan",
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "functions": TF_MATH_FUNCTIONS,
+ "dynamic_dims": True,
+ "test_complex": False,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "math_test.py",
+ tags = [
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# ---- MULTIPLE DYNAMIC TESTS ----------------------------------------------- #
+
+# These tests compile all functions in tf.math at once for testing so that
+# we can run them on the CI with 4 additional targets instead of 512.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_DYNAMIC_MULTIPLE = VMLA_FAILING_DYNAMIC + ["multiply"]
+
+[
+ iree_py_test(
+ name = "math_dynamic_dims_tests_multiple__{}".format(target_backend),
+ srcs = ["math_test.py"],
+ args = [
+ "--reference_backend=tf",
+ "--target_backends={}".format(target_backend),
+ "--functions={}".format(",".join(functions)),
+ "--dynamic_dims=False",
+ ],
+ main = "math_test.py",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for target_backend, functions in dict(
+ iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_DYNAMIC),
+ iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_DYNAMIC_MULTIPLE),
+ iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_DYNAMIC),
+ tf = TF_MATH_FUNCTIONS,
+ ).items()
+]
+
+# ---- INDIVIDUAL COMPLEX TESTS ---------------------------------------------- #
+
+# This list was generated by running:
+# bazel run integrations/tensorflow/e2e/math:math_test_manual -- --list_functions_with_complex_tests
+COMPLEX_FUNCTIONS = [
+ "abs",
+ "add",
+ "angle",
+ "asinh",
+ "atanh",
+ "conj",
+ "cos",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumsum",
+ "divide",
+ "divide_no_nan",
+ "exp",
+ "expm1",
+ "imag",
+ "l2_normalize",
+ "log",
+ "log1p",
+ "multiply",
+ "multiply_no_nan",
+ "negative",
+ "pow",
+ "real",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_euclidean_norm",
+ "reduce_std",
+ "reduce_variance",
+ "rsqrt",
+ "sigmoid",
+ "sign",
+ "sin",
+ "sinh",
+ "sqrt",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "tanh",
+ "truediv",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zero_fraction",
+]
+
+# keep sorted
+FAILING_COMPLEX = [
+ "angle",
+ "cos",
+ "cumsum",
+ "divide_no_nan",
+ "log",
+ "log1p",
+ "multiply_no_nan",
+ "negative",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_std",
+ "reduce_variance",
+ "rsqrt",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "sqrt",
+ "tan",
+ "tanh",
+ "zero_fraction",
+]
+
+VMLA_FAILING_COMPLEX = VMLA_FAILING + FAILING_COMPLEX
+
+LLVM_FAILING_COMPLEX = LLVM_FAILING + FAILING_COMPLEX
+
+VULKAN_FAILING_COMPLEX = VULKAN_FAILING + FAILING_COMPLEX
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+ name = "math_complex_tests",
+ srcs = ["math_test.py"],
+ failing_configurations = [
+ {
+ # TFLite does not support complex numbers.
+ "target_backends": "tflite",
+ },
+ {
+ # Failing on vmla.
+ "functions": VMLA_FAILING_COMPLEX,
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm.
+ "functions": LLVM_FAILING_COMPLEX,
+ "target_backends": "iree_llvmjit",
+ },
+ {
+ # Failing on vulkan.
+ "functions": VULKAN_FAILING_COMPLEX,
+ "target_backends": "iree_vulkan",
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "functions": COMPLEX_FUNCTIONS,
+ "dynamic_dims": False,
+ "test_complex": True,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "math_test.py",
+ tags = [
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# ---- MULTIPLE COMPLEX TESTS ----------------------------------------------- #
+
+# These tests compile all functions in tf.math at once for testing so that
+# we can run them on the CI with 4 additional targets instead of 512.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_COMPLEX_MULTIPLE = VMLA_FAILING_COMPLEX + ["multiply"]
+
+# TODO(#3810) Including 'square' causes error: Recieved signal 11.
+LLVM_FAILING_COMPLEX_MULTIPLE = LLVM_FAILING_COMPLEX + ["square"]
+
+# TODO(#3810) Including 'square' causes error: Recieved signal 11.
+VULKAN_FAILING_COMPLEX_MULTIPLE = VULKAN_FAILING_COMPLEX + ["square"]
+
+[
+ iree_py_test(
+ name = "math_complex_tests_multiple__{}".format(target_backend),
+ srcs = ["math_test.py"],
+ args = [
+ "--reference_backend=tf",
+ "--target_backends={}".format(target_backend),
+ "--functions={}".format(",".join(functions)),
+ "--dynamic_dims=False",
+ ],
+ main = "math_test.py",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for target_backend, functions in dict(
+ iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_COMPLEX_MULTIPLE),
+ iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_COMPLEX_MULTIPLE),
+ iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_COMPLEX_MULTIPLE),
+ tf = TF_MATH_FUNCTIONS,
+ ).items()
+]
diff --git a/integrations/tensorflow/e2e/math/math_test.py b/integrations/tensorflow/e2e/math/math_test.py
new file mode 100644
index 0000000..7a1096f
--- /dev/null
+++ b/integrations/tensorflow/e2e/math/math_test.py
@@ -0,0 +1,735 @@
+import collections
+import os
+from typing import Any, Dict, Sequence, Type, Union
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# As high as tf goes without breaking.
+RANK_7_SHAPE = [2] * 7
+UNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE]]
+BINARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 2]
+TERNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 3]
+
+# Reused UnitTestSpecs.
+SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "tf_doc_example": [
+ tf.constant([
+ [1, 2, 3, 4],
+ [4, 3, 2, 1],
+ [5, 6, 7, 8],
+ ], np.float32),
+ np.array([0, 0, 1], np.int32),
+ ]
+ })
+UNSORTED_SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "tf_doc_example": [
+ tf.constant([
+ [1, 2, 3, 4],
+ [4, 3, 2, 1],
+ [5, 6, 7, 8],
+ ], np.float32),
+ np.array([0, 0, 1], np.int32),
+ 2,
+ ]
+ })
+
+REDUCE_KWARGS_TO_VALUES = {
+ "axis": [None, 1],
+ "keepdims": [False, True],
+}
+
+# A dictionary mapping tf.math function names to lists of UnitTestSpecs.
+# Each unit_test_name will have the tf.math function name prepended to it.
+FUNCTIONS_TO_UNIT_TEST_SPECS = {
+ "abs":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "accumulate_n": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='f32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='i32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
+ ],
+ "acos":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "acosh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32],
+ input_generators=[tf_utils.ndarange]),
+ "add":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "add_n": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='f32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='i32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
+ ],
+ "angle":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "argmax":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "argmin":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "asin":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "asinh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "atan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "atan2":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "atanh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "bessel_i0":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bessel_i0e":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bessel_i1":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bessel_i1e":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "betainc":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=TERNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bincount":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.int32],
+ input_generators=[tf_utils.ndarange]),
+ "ceil":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "confusion_matrix":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "five_classes": [tf.constant([1, 2, 4]),
+ tf.constant([2, 2, 4])]
+ }),
+ "conj":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "cos":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "cosh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "count_nonzero":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64],
+ input_generators=[tf_utils.ndarange]),
+ "cumprod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "cumsum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "cumulative_logsumexp":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "digamma":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "divide":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "divide_no_nan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "erf":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "erfc":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "erfinv":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "exp":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "expm1":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "floor":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "floordiv":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ # Avoid integer division by 0.
+ input_generators={
+ "uniform_1_3":
+ lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
+ }),
+ "floormod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ # Avoid integer division by 0.
+ input_generators={
+ "uniform_1_3":
+ lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
+ }),
+ "greater":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "greater_equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "igamma":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "igammac":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "imag":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "in_top_k": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name="k_3",
+ input_signature=[
+ tf.TensorSpec([8], tf.int32),
+ tf.TensorSpec([8, 3])
+ ],
+ input_generator=tf_utils.ndarange,
+ kwargs=dict(k=3),
+ )
+ ],
+ "invert_permutation": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name="random",
+ input_signature=[tf.TensorSpec([8], tf.int32)],
+ input_generator=tf_utils.random_permutation,
+ )
+ ],
+ "is_finite":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+ }),
+ "is_inf":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+ }),
+ "is_nan":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+ }),
+ "is_non_decreasing":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "is_strictly_increasing":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "l2_normalize":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "lbeta":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "less":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "less_equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "lgamma":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "log":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "log1p":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "log_sigmoid":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "log_softmax":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "logical_and":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "logical_not":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "logical_or":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "logical_xor":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "maximum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "minimum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "mod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ input_generators={
+ "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1
+ }),
+ "multiply":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "multiply_no_nan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "ndtri":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "negative":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "nextafter":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES),
+ "not_equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "polygamma":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.ones(16), tf.linspace(0.5, 4, 16)]
+ }),
+ "polyval": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name="three_coeffs",
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE)] * 3,
+ tf.TensorSpec([])],
+ )
+ ],
+ "pow":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64],
+ input_generators={
+ "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1
+ }),
+ "real":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "reciprocal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "reciprocal_no_nan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "reduce_all": [
+ # Explicitly test all True inputs to be absolutely sure that some
+ # reduction axes return True.
+ *tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "all_true": [np.ones(RANK_7_SHAPE, np.bool)],
+ },
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ *tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ ],
+ "reduce_any": [
+ # Explicitly test all False inputs to be absolutely sure that some
+ # reduction axes return False.
+ *tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "all_false": [np.zeros(RANK_7_SHAPE, np.bool)],
+ },
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ *tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ ],
+ "reduce_euclidean_norm":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_logsumexp":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_max":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_mean":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_min":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_prod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_std":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_sum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_variance":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "rint":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "round":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "rsqrt":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "scalar_mul":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[[[], [8]]]),
+ "segment_max":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_mean":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_min":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_prod":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_sum":
+ SEGMENT_UNIT_TEST_SPECS,
+ "sigmoid":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "sign":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "sin":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "sinh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "sobol_sample":
+ tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={"simple": [4, 3]}),
+ "softmax":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "softplus":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "softsign":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "sqrt":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "square":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "squared_difference":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "subtract":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "tan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "tanh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "top_k":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32],
+ kwargs_to_values={"k": [1, 2]}),
+ "truediv":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "unsorted_segment_max":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_mean":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_min":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_prod":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_sqrt_n":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_sum":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "xdivy":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "xlog1py":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "xlogy":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "zero_fraction":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "zeta":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ # The function is poorly behaved near zero, so we test this range
+ # to avoid outputing all nans.
+ input_generators={
+ "uniform_3_4":
+ lambda *args: tf_utils.uniform(*args, low=3.0, high=4.0)
+ },
+ )
+}
+
+for function, specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
+ # Update using 'with_name' to avoid updating shared UnitTestSpecs.
+ specs = [
+ spec.with_name(f"{function}__{spec.unit_test_name}") for spec in specs
+ ]
+ FUNCTIONS_TO_UNIT_TEST_SPECS[function] = specs
+
+ # Validate that there are not multiple UnitTestSpecs with the same name.
+ seen_unit_test_names = set()
+ for spec in specs:
+ if spec.unit_test_name in seen_unit_test_names:
+ raise ValueError(
+ f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")
+ seen_unit_test_names.add(spec.unit_test_name)
+
+flags.DEFINE_list(
+ "functions", None,
+ f"Any of {list(FUNCTIONS_TO_UNIT_TEST_SPECS.keys())}. If more than one "
+ "function is provided then len(--target_backends) must be one.")
+flags.DEFINE_bool(
+ "dynamic_dims", False,
+ "Whether or not to compile the layer with dynamic dimensions.")
+flags.DEFINE_bool(
+ "test_complex", False,
+ "Whether or not to test or ignore function signatures with complex types.")
+flags.DEFINE_bool(
+ 'list_functions_with_complex_tests', False,
+ 'Whether or not to print out all functions with complex inputs '
+ '(and skip running the tests).')
+
+
+def create_function_unit_test(
+ function_name: str,
+ unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
+ """Creates a tf_function_unit_test from the provided UnitTestSpec."""
+ function = getattr(tf.math, function_name)
+ signature = unit_test_spec.input_signature
+
+ if tf_utils.is_complex(signature):
+ function, signature = tf_utils.rewrite_complex_signature(
+ function, signature)
+ wrapped_function = lambda *args: function(*args, **unit_test_spec.kwargs)
+
+ if FLAGS.dynamic_dims:
+ signature = tf_utils.apply_function(signature, tf_utils.make_dims_dynamic)
+
+ return tf_test_utils.tf_function_unit_test(
+ input_signature=signature,
+ input_generator=unit_test_spec.input_generator,
+ input_args=unit_test_spec.input_args,
+ name=unit_test_spec.unit_test_name,
+ rtol=1e-5,
+ atol=1e-5)(wrapped_function)
+
+
+class TfMathModule(tf_test_utils.TestModule):
+
+ def __init__(self):
+ super().__init__()
+ for function in FLAGS.functions:
+ for unit_test_spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function]:
+ if not FLAGS.test_complex and tf_utils.is_complex(
+ unit_test_spec.input_signature):
+ continue
+ function_unit_test = create_function_unit_test(function, unit_test_spec)
+ setattr(self, unit_test_spec.unit_test_name, function_unit_test)
+
+
+class TfMathTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(
+ TfMathModule, exported_names=TfMathModule.get_tf_function_unit_tests())
+
+
+def main(argv):
+ del argv # Unused.
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+
+ if FLAGS.list_functions_with_complex_tests:
+ for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
+ for spec in unit_test_specs:
+ if tf_utils.is_complex(spec.input_signature):
+ print(f' "{function_name}",')
+ return
+
+ if FLAGS.functions is None:
+ raise flags.IllegalFlagValueError(
+ "'--functions' must be specified if "
+ "'--list_functions_with_complex_tests' isn't")
+
+ if len(FLAGS.functions) > 1:
+ # We only allow testing multiple functions with a single target backend
+ # so that we can store the artifacts under:
+ # 'artifacts_dir/multiple_functions__backend/...'
+ # We specialize the 'multiple_functions' dir by backend to avoid overwriting
+ # tf_input.mlir and iree_input.mlir. These are typically identical across
+ # backends, but are not when the functions to compile change per-backend.
+ if len(FLAGS.target_backends) != 1:
+ raise flags.IllegalFlagValueError(
+ "Expected len(target_backends) == 1 when len(functions) > 1, but got "
+ f"the following values for target_backends: {FLAGS.target_backends}.")
+ function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
+ else:
+ function_str = FLAGS.functions[0]
+ dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
+ settings_str = os.path.join(function_str, dim_str)
+ # The relative artifacts directory path is calculated from the module name
+ # TODO(meadowlark): provide a better way of overridding this default.
+ TfMathModule.__name__ = os.path.join("tf", "math", settings_str)
+
+ TfMathTest.generate_unit_tests(TfMathModule)
+ tf.test.main()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/math_dyn_test.py b/integrations/tensorflow/e2e/math_dyn_test.py
deleted file mode 100644
index cca0918..0000000
--- a/integrations/tensorflow/e2e/math_dyn_test.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class MathModule(tf.Module):
-
- @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
- def abs(self, x):
- return tf.math.abs(x)
-
- @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
- def ceil(self, x):
- return tf.math.ceil(x)
-
- @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
- def cos(self, x):
- return tf.math.cos(x)
-
- @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
- def log(self, x):
- return tf.math.log(x)
-
- @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
- def mod(self, x):
- return tf.math.mod(x, 2.0)
-
- @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
- def fake_quant(self, x):
- return tf.quantization.fake_quant_with_min_max_args(x,
- min=-6,
- max=6,
- num_bits=8,
- narrow_range=False,
- name=None)
-
-
-class MathTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(MathModule)
-
- # yapf: disable
- def test_abs(self):
- def abs(module):
- module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(abs, self._modules)
-
- def test_ceil(self):
- def ceil(module):
- module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(ceil, self._modules)
-
- def test_cos(self):
- def cos(module):
- module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(cos, self._modules)
-
- def test_log(self):
- def log(module):
- module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(log, self._modules)
-
- def test_mod(self):
- def mod(module):
- module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(mod, self._modules)
-
- def test_fake_quant(self):
- def abs(module):
- module.fake_quant(np.array([-0.123, 0.1234, 0.743, 4.3], dtype=np.float32))
- self.compare_backends(abs, self._modules)
- # yapf: enable
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
deleted file mode 100644
index add9ea4..0000000
--- a/integrations/tensorflow/e2e/math_test.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class MathModule(tf.Module):
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def abs(self, x):
- return tf.math.abs(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def ceil(self, x):
- return tf.math.ceil(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def cos(self, x):
- return tf.math.cos(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def log(self, x):
- return tf.math.log(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def mod(self, x):
- return tf.math.mod(x, 2.0)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def fake_quant(self, x):
- return tf.quantization.fake_quant_with_min_max_args(x,
- min=-6,
- max=6,
- num_bits=8,
- narrow_range=False,
- name=None)
-
-
-class MathTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(MathModule)
-
- # yapf: disable
- def test_abs(self):
- def abs(module):
- module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(abs, self._modules)
-
- def test_ceil(self):
- def ceil(module):
- module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(ceil, self._modules)
-
- def test_cos(self):
- def cos(module):
- module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(cos, self._modules)
-
- def test_log(self):
- def log(module):
- module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(log, self._modules)
-
- def test_mod(self):
- def mod(module):
- module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(mod, self._modules)
-
- def test_fake_quant(self):
- def abs(module):
- module.fake_quant(np.array([-0.123, 0.1234, 0.743, 4.3], dtype=np.float32))
- self.compare_backends(abs, self._modules)
- # yapf: enable
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/quantization_dyn_test.py b/integrations/tensorflow/e2e/quantization_dyn_test.py
new file mode 100644
index 0000000..d1c44ef
--- /dev/null
+++ b/integrations/tensorflow/e2e/quantization_dyn_test.py
@@ -0,0 +1,58 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for ops in the tf.math module."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class QuantizationDynModule(tf.Module):
+
+ @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
+ def fake_quant(self, x):
+ return tf.quantization.fake_quant_with_min_max_args(x,
+ min=-6,
+ max=6,
+ num_bits=8,
+ narrow_range=False,
+ name=None)
+
+
+class QuantizationDynTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(QuantizationDynModule)
+
+ def test_fake_quant(self):
+
+ def abs(module):
+ module.fake_quant(tf_utils.uniform([32], low=-6, high=6))
+
+ self.compare_backends(abs, self._modules)
+
+
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/quantization_test.py b/integrations/tensorflow/e2e/quantization_test.py
new file mode 100644
index 0000000..2ccf8f8
--- /dev/null
+++ b/integrations/tensorflow/e2e/quantization_test.py
@@ -0,0 +1,55 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for ops in the tf.math module."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class QuantizationModule(tf_test_utils.TestModule):
+
+ @tf_test_utils.tf_function_unit_test(
+ input_signature=[tf.TensorSpec([32], tf.float32)],
+ input_generator=lambda *args: tf_utils.uniform(*args, low=-6, high=6))
+ def fake_quant(self, x):
+ return tf.quantization.fake_quant_with_min_max_args(x,
+ min=-6,
+ max=6,
+ num_bits=8,
+ narrow_range=False,
+ name=None)
+
+
+class QuantizationTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(QuantizationModule)
+
+
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+
+ QuantizationTest.generate_unit_tests(QuantizationModule)
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index fed8c12..3b5509c 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -43,7 +43,12 @@
KWS_LINK = f'[Keyword Spotting Streaming]({KWS_LINK})'
COVERAGE_GROUP_TO_TEST_SUITES = {
- 'tf_base_coverage': ['//integrations/tensorflow/e2e:e2e_tests'],
+ 'tf_base_coverage': [
+ '//integrations/tensorflow/e2e:e2e_tests',
+ '//integrations/tensorflow/e2e/math:math_tests',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests',
+ '//integrations/tensorflow/e2e/math:math_complex_tests',
+ ],
'tf_keras_coverage': [
'//integrations/tensorflow/e2e/keras/layers:layers_tests',
'//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests',
@@ -81,8 +86,16 @@
}
TEST_SUITES_TO_HEADERS = {
+ # tf_base_coverage
'//integrations/tensorflow/e2e:e2e_tests':
'End to end TensorFlow tests',
+ '//integrations/tensorflow/e2e/math:math_tests':
+ 'End to end tests of tf.math functions with static dimensions',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+ 'End to end tests of tf.math functions with dynamic dimensions',
+ '//integrations/tensorflow/e2e/math:math_complex_tests':
+ 'End to end tests of tf.math functions with complex numbers',
+ # tf_keras_coverage
'//integrations/tensorflow/e2e/keras/layers:layers_tests':
'End to end tests of tf.keras layers (with default configuration and '
'static batch sizes in inference mode)',
@@ -95,12 +108,14 @@
'//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
'End to end tests of tf.keras layers in training mode (with default'
'configuration and static batch sizes)',
+ # language_and_speech_coverage
'//integrations/tensorflow/e2e:mobile_bert_squad_tests':
'End to end test of MobileBert on SQuAD',
'//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
f'End to end tests of {KWS_LINK} models',
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
f'End to end tests of {KWS_LINK} models in internal streaming mode',
+ # vision_coverage
'//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
'End to end tests of tf.keras.applications vision models on Imagenet',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -108,18 +123,31 @@
}
TEST_SUITES_TO_NOTES = {
+ '//integrations/tensorflow/e2e/math:math_tests':
+ ('**Note:** To be thorough, these tests use high rank tensors and\n'
+ 'test int dtypes where TensorFlow allows them to be used. Both of\n'
+ 'these choices disproportionately affect TFLite coverage, and\n'
+ 'don\'t represent coverage for simple use cases.\n'),
'//integrations/tensorflow/e2e/keras/layers:layers_tests': (
'**Note:** Layers like `Dropout` are listed as passing in this table,\n'
'but they function similar to identity layers in these tests. **See \n'
'the third table for the coverage of these layers during training.**\n'
'\n'
'These tests also only modify required `tf.keras.layers` arguments.\n'
- 'See the full API tests below for coverage on of non-default '
+ 'See the full API tests below for coverage on of non-default\n'
'layer configurations.'),
}
# Key to use as the name of the rows in the left column for each test in the
# suite.
TEST_SUITE_TO_ROW_ID_KEY = {
+ # tf_base_coverage
+ '//integrations/tensorflow/e2e/math:math_tests':
+ 'functions',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+ 'functions',
+ '//integrations/tensorflow/e2e/math:math_complex_tests':
+ 'functions',
+ # tf_keras_coverage
'//integrations/tensorflow/e2e/keras/layers:layers_tests':
'layer',
'//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
@@ -128,10 +156,12 @@
'layer',
'//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
'layer',
+ # language_and_speech_coverage
'//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
'model',
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'model',
+ # vision_coverage
'//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
'model',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -141,6 +171,14 @@
# Some test suites are generated from a single source. This allows us to point
# to the right test file when generating test URLs.
SINGLE_SOURCE_SUITES = {
+ # tf_base_coverage
+ '//integrations/tensorflow/e2e/math:math_tests':
+ 'math_test',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+ 'math_test',
+ '//integrations/tensorflow/e2e/math:math_complex_tests':
+ 'math_test',
+ # tf_keras_coverage
'//integrations/tensorflow/e2e/keras/layers:layers_tests':
'layers_test',
'//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
@@ -149,10 +187,12 @@
'layers_test',
'//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
'layers_test',
+ # language_and_speech_coverage
'//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
'keyword_spotting_streaming_test',
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'keyword_spotting_streaming_test',
+ # vision_coverage
'//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
'vision_model_test',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':