Centralize tf.math coverage and expand to all functions (#3826)

Adds tests for the functions in the `tf.math` API, with dynamically shaped and complex-dtype variants.

To avoid adding the `bazel` overhead of another 1332 tests to the CI while still generating coverage tables, I created two versions of the test suites: one that CI runs to detect any new errors, and one that creates a separate test for each function and backend (a sketch of the new spec-generation helpers is included after the list below). The tests that this change adds to the CI are:

```bazel
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__iree_llvmjit PASSED in 26.9s
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__iree_vmla PASSED in 15.3s
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__iree_vulkan PASSED in 72.1s
//integrations/tensorflow/e2e/math:math_complex_tests_multiple__tf       PASSED in 38.1s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__iree_llvmjit PASSED in 7.9s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__iree_vmla PASSED in 13.9s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__iree_vulkan PASSED in 4.5s
//integrations/tensorflow/e2e/math:math_dynamic_dims_tests_multiple__tf  PASSED in 38.0s
//integrations/tensorflow/e2e/math:math_tests_multiple__iree_llvmjit     PASSED in 41.5s
//integrations/tensorflow/e2e/math:math_tests_multiple__iree_vmla        PASSED in 19.7s
//integrations/tensorflow/e2e/math:math_tests_multiple__iree_vulkan      PASSED in 106.6s
//integrations/tensorflow/e2e/math:math_tests_multiple__tf               PASSED in 38.2s
//integrations/tensorflow/e2e/math:math_tests_multiple__tflite           PASSED in 21.1s
```
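
The per-function variants are built on the new spec-generation helpers added to `tf_test_utils.py` (shown in the diff below). A minimal sketch of how they can be invoked; the shapes and dtypes here are illustrative rather than the exact values the math suites use:

```python
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf

# Each UnitTestSpec pairs an input signature with an input generator and
# kwargs; the spec names become the names of the generated unit_tests.
specs = tf_test_utils.unit_test_specs_from_signatures(
    signature_shapes=[[[2, 2]]],  # One signature with a single 2x2 argument.
    signature_dtypes=[tf.float32, tf.complex64],
)
print([str(spec) for spec in specs])  # ['2x2__f32__uniform', '2x2__c64__uniform']
```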

The following tests are deleted:

- `bool_test.py`
- `complex_test.py`
- `finite_test.py`
- `logical_ops_test.py`

`math_test.py` and `math_dyn_test.py` are replaced by `quantization_test.py` and `quantization_dyn_test.py`, since a fake quant test was added to them after this PR was opened.

Supporting changes (a usage sketch of the renamed test utilities follows this list):

- Allow `tf.Tensor`s to be passed to traces (working around NumPy's `float64` default is cumbersome).
- Add a `set_difference` helper to the `bazel` e2e test suite macros.
- Expand support for input generation (e.g. uniform bools).
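
As a usage sketch of the renamed utilities (`tf_function_unit_test` and `generate_unit_tests`), here is roughly what a test file looks like after this change; the module and method names are illustrative, and the structure mirrors the updated `conv_test.py` in the diff below:

```python
from absl import app
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf


class SimpleMathModule(tf_test_utils.TestModule):

  # With only an input_signature, uniform random inputs are generated.
  @tf_test_utils.tf_function_unit_test(
      input_signature=[tf.TensorSpec([4], tf.float32)])
  def abs(self, x):
    return tf.math.abs(x)


class SimpleMathTest(tf_test_utils.TracedModuleTestCase):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._modules = tf_test_utils.compile_tf_module(SimpleMathModule)


def main(argv):
  del argv  # Unused.
  if hasattr(tf, 'enable_v2_behavior'):
    tf.enable_v2_behavior()
  # Creates a 'test_abs' method on SimpleMathTest.
  SimpleMathTest.generate_unit_tests(SimpleMathModule)
  tf.test.main()


if __name__ == '__main__':
  app.run(main)
```

Running it compares the generated `test_abs` case between the reference backend and each entry in `--target_backends`.
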
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 49eb68a..e18b5ce 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -26,11 +26,13 @@
 import copy
 import glob
 import inspect
+import itertools
 import os
 import pickle
+import re
 import sys
 import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Sequence, Set, Tuple, Type, Union
 
 from absl import flags
 from absl import logging
@@ -40,8 +42,8 @@
 
 flags.DEFINE_string("reference_backend", "tf",
                     "The backend to treat as a source of truth.")
-flags.DEFINE_string("target_backends", None,
-                    "Explicit comma-delimited list of target backends.")
+flags.DEFINE_list("target_backends", None,
+                  "Explicit comma-delimited list of target backends.")
 flags.DEFINE_string(
     "artifacts_dir", None,
     "Specifies a directory to dump compilation artifacts and traces to. "
@@ -56,6 +58,7 @@
     "Creates and stores a SavedModel for the tf.Module class to be tested.")
 FLAGS = flags.FLAGS
 NUMPY_LINEWIDTH = 120
+DEFAULT_INPUT_GENERATOR = tf_utils.uniform
 
 
 def _setup_artifacts_dir(module_name: str) -> str:
@@ -76,7 +79,7 @@
 
 def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
   """Decodes --target_backends and creates unique ids for them."""
-  backend_names = FLAGS.target_backends.split(",")
+  backend_names = FLAGS.target_backends
   backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
   backend_ids = []
 
@@ -141,11 +144,6 @@
     """Records the details of a call to a CompiledModule."""
     self.method = method
 
-    for value in inputs:
-      if isinstance(value, tf.Tensor):
-        raise TypeError("Expected inputs to be native python types or numpy "
-                        f"arrays, but got {type(value)}")
-
     # Deepcopy to safegard against mutation.
     self.inputs = copy.deepcopy(inputs)
     if outputs is not None:
@@ -574,6 +572,10 @@
       # Only pass these to ModuleCall if they were specified by the user.
       tolerances = {k: v for k, v in tolerances.items() if v is not None}
 
+      # Convert any tf.Tensor inputs to numpy arrays.
+      args = tf_utils.convert_to_numpy(args)
+      kwargs = tf_utils.convert_to_numpy(kwargs)
+
       # Run the method and record the details of the call.
       outputs = method(*args, **kwargs)
       serialized_inputs, serialized_outputs = method.get_serialized_values()
@@ -602,7 +604,7 @@
 
 # We have to use a global variable to store the compiled modules so that we can
 # avoid recompilation. This is because the TestCase class resets it's entire
-# state and calls __init__ before each unittest. It also calls __init__ one
+# state and calls __init__ before each unit_test. It also calls __init__ one
 # additional time before that for good measure, which means without storing the
 # modules somewhere else we would have to compile each of them at least twice.
 # We can't store the modules on the class itself via setUpClass because of #2900
@@ -693,13 +695,229 @@
   return _global_modules
 
 
-def tf_function_unittest(input_generator: tf_utils.InputGeneratorType = None,
-                         input_args: Sequence[Any] = None,
-                         atol: float = None,
-                         rtol: float = None,
-                         name: str = None,
-                         **tf_function_kwargs):
-  """Creates a tf.function that can be used to generate unittests.
+# We use global variables to store the configuration information for
+# tf_function_unit_tests because tensorflow.python.eager.def_function.Function
+# is not an API that we can subclass, and storing the information directly
+# on that class results in it being deleted at tf.Module initialization.
+# _global_unit_test_configs is a dict mapping exported_names to dicts containing
+# a get-function for input args and the tolerance kwargs for the trace.
+global _global_unit_test_configs
+_global_unit_test_configs = dict()
+
+
+class UnitTestSpec:
+
+  def __init__(self,
+               unit_test_name: str,
+               input_signature: Sequence[tf.TensorSpec],
+               input_generator: tf_utils.InputGeneratorType = None,
+               input_args: Union[Sequence[Any], None] = None,
+               kwargs: Dict[str, Any] = None):
+    self.unit_test_name = tf_utils.remove_special_characters(unit_test_name)
+    self.input_signature = input_signature
+    self.input_args = input_args
+    self.kwargs = dict() if kwargs is None else kwargs
+    self.input_generator = input_generator
+
+  def with_name(self, new_name: str) -> "UnitTestSpec":
+    return UnitTestSpec(new_name, self.input_signature, self.input_generator,
+                        self.input_args, self.kwargs)
+
+  def __str__(self):
+    return self.unit_test_name
+
+
+def _dictionary_product(dictionary: Dict[Any, Any]) -> List[Dict[Any, Any]]:
+  """Returns a named Cartesian product of the dictionary's values.
+
+  Converts {'a': [1, 2], 'b': [3, 4]} into
+  [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}]
+  """
+  product = [[]]
+  for values in dictionary.values():
+    # Iteratively grow the elements of the product.
+    product = [element + [value] for element in product for value in values]
+  dicts = [{k: v for k, v in zip(dictionary, element)} for element in product]
+  return dicts
+
+
+def _named_kwargs_product(
+    kwargs_to_values: Dict[str, Sequence[Any]]) -> Dict[str, Dict[str, Any]]:
+  """Splits kwargs_to_values into a Cartesian product of its elements."""
+  # Validate 'kwargs_to_values'
+  if kwargs_to_values is None:
+    kwargs_to_values = dict()  # Use only default kwargs.
+  for kwarg_key, kwarg_values in kwargs_to_values.items():
+    if not isinstance(kwarg_values, Sequence):
+      raise TypeError(f"Expected kwargs_to_values[{repr(kwarg_key)}] to be a "
+                      f"sequence, but got '{type(kwarg_values)}'")
+
+  # Expand across a Cartesian product.
+  kwargs_product = _dictionary_product(kwargs_to_values)
+  # {'a': 1, 'b': 3} -> "a_1__b_3"
+  dict_to_str = lambda d: "__".join([f"{k}_{v}" for k, v in d.items()])
+  return {dict_to_str(kwargs): kwargs for kwargs in kwargs_product}
+
+
+def unit_test_specs_from_signatures(
+    signature_shapes: Sequence[Sequence[Sequence[int]]],
+    signature_dtypes: Sequence[tf.DType] = [tf.float32],
+    input_generators: Union[Sequence[tf_utils.InputGeneratorType],
+                            Dict[str, tf_utils.InputGeneratorType]] = [
+                                DEFAULT_INPUT_GENERATOR
+                            ],
+    kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
+  """Generates a Cartesian product of UnitTestSpecs from the given arguments.
+
+  Args:
+    signature_shapes:
+      A sequence (representing multiple signatures to test) of sequences
+      (representing the shapes of the args in those signatures) of ints
+      (representing the individual sizes of those shapes).
+    signature_dtypes:
+      A sequence of dtypes to test each signature with.
+    input_generators:
+      Either:
+        1. a sequence of input generators to test each of the signature-dtype
+           pairs with
+        2. a dictionary mapping input generator names to input generators to
+           test each of the signature-dtype pairs with. This format must be used
+           if any of the generators are lambda functions.
+    kwargs_to_values:
+      A dict mapping kwarg names to sequences of values that they can take.
+
+  Returns:
+    A list of 'UnitTestSpec's generated from the provided arguments.
+  """
+  # Validate 'signature_shapes'
+  for i, shapes in enumerate(signature_shapes):
+    if not isinstance(shapes, Sequence):
+      raise TypeError(f"Expected signature_shapes[{i}] to be a sequence, but "
+                      f"got '{type(shapes)}'")
+    for j, shape in enumerate(shapes):
+      if not isinstance(shape, Sequence):
+        raise TypeError(f"Expected signature_shapes[{i}][{j}] to be a "
+                        f"sequence, but got '{type(shape)}'")
+      for k, size in enumerate(shape):
+        if not isinstance(size, int):
+          raise TypeError(f"Expected signature_shapes[{i}][{j}][{k}] to be an "
+                          f"int, but got '{type(size)}'")
+
+  # Parse 'signature_shapes'
+  names_to_shapes = dict()
+  for signature in signature_shapes:
+    # Converts [[1, 2, 3], [4, 5]] into 1x2x3_4x5.
+    signature_key = "_".join(
+        ["x".join(str(size) for size in shape) for shape in signature])
+    names_to_shapes[signature_key] = signature
+
+  # Validate 'signature_dtypes'
+  for i, dtype in enumerate(signature_dtypes):
+    if not isinstance(dtype, tf.DType):
+      raise TypeError(
+          f"Expected dtypes[{i}] to be a tf.DType, but got '{type(dtype)}'")
+
+  # Parse 'signature_dtypes'
+  # 'complex64' -> 'c64'
+  abbreviate = lambda dtype: re.sub(r"([a-z])[a-z]*([0-9]+)", r"\1\2", dtype)
+  names_to_dtypes = {
+      abbreviate(dtype.name): dtype for dtype in signature_dtypes
+  }
+
+  # Validate 'input_generators'
+  if not isinstance(input_generators, (Sequence, Dict)):
+    raise TypeError("Expected 'input_generators' to be a sequence or "
+                    f"dictionary, but got '{type(input_generators)}'")
+  if isinstance(input_generators, Sequence):
+    for i, generator in enumerate(input_generators):
+      if generator.__name__ == "<lambda>":
+        raise TypeError(
+            f"'input_generators' was a sequence but input_generators[{i}] was "
+            "a lambda function. 'input_generators' must be a dictionary if "
+            "lambda functions are used.")
+
+  # Parse 'input_generators'
+  if isinstance(input_generators, Sequence):
+    names_to_generators = {gen.__name__: gen for gen in input_generators}
+  else:
+    names_to_generators = input_generators
+
+  # Validate and parse 'kwargs_to_values'
+  names_to_kwargs = _named_kwargs_product(kwargs_to_values)
+
+  # Create a Cartesian product through all specifications and their names.
+  specs = [
+      names_to_shapes, names_to_dtypes, names_to_generators, names_to_kwargs
+  ]
+  # pytype: disable=attribute-error
+  key_product = itertools.product(*[list(spec.keys()) for spec in specs])
+  value_product = itertools.product(*[list(spec.values()) for spec in specs])
+  # pytype: enable=attribute-error
+
+  # Generate a UnitTestSpec for each element in the above product.
+  unit_tests = []
+  for keys, (shapes, dtype, generator, kwargs) in zip(key_product,
+                                                      value_product):
+    unit_test_name = "__".join(key for key in keys if key)
+    input_signature = [tf.TensorSpec(shape, dtype) for shape in shapes]
+    unit_tests.append(
+        UnitTestSpec(
+            unit_test_name=unit_test_name,
+            input_signature=input_signature,
+            input_generator=generator,
+            input_args=None,
+            kwargs=kwargs,
+        ))
+  return unit_tests
+
+
+def unit_test_specs_from_args(
+    names_to_input_args: Dict[str, Sequence[Any]],
+    kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
+  """Generates a Cartesian product of UnitTestSpecs from the given arguments.
+
+  Args:
+    names_to_input_args:
+      A dict mapping names for input arguments to the arguments themselves.
+    kwargs_to_values:
+      A dict mapping kwarg names to sequences of values that they can take.
+
+  Returns:
+    A list of 'UnitTestSpec's generated from the provided arguments.
+  """
+  # Validate and parse 'kwargs_to_values'
+  names_to_kwargs = _named_kwargs_product(kwargs_to_values)
+
+  # Create a Cartesian product through all specifications and their names.
+  specs = [names_to_input_args, names_to_kwargs]
+  key_product = itertools.product(*[list(spec.keys()) for spec in specs])
+  value_product = itertools.product(*[list(spec.values()) for spec in specs])
+
+  # Generate a UnitTestSpec for each element in the above product.
+  unit_tests = []
+  for keys, (input_args, kwargs) in zip(key_product, value_product):
+    unit_test_name = "__".join(key for key in keys if key)
+    input_signature = tf_utils.apply_function(
+        input_args,
+        lambda x: tf.TensorSpec.from_tensor(tf.convert_to_tensor(x)))
+    unit_tests.append(
+        UnitTestSpec(
+            unit_test_name=unit_test_name,
+            input_signature=input_signature,
+            input_generator=None,
+            input_args=input_args,
+            kwargs=kwargs,
+        ))
+  return unit_tests
+
+
+def tf_function_unit_test(input_generator: tf_utils.InputGeneratorType = None,
+                          input_args: Sequence[Any] = None,
+                          atol: float = None,
+                          rtol: float = None,
+                          name: str = None,
+                          **tf_function_kwargs):
+  """Creates a tf.function that can be used to generate unit_tests.
 
   If 'input_generator' and 'input_args' are unspecified then the function will
   be tested using random uniform data.
@@ -707,7 +925,7 @@
   Args:
     input_generator:
       an optional callable taking a shape and dtype that returns input data for
-      the unittest.
+      the unit_test.
     input_args:
       an optional sequence of values to pass as positional args to the function.
     atol:
@@ -729,7 +947,7 @@
     __name__ attribute if 'name' was specified.
   """
 
-  def _store_unittest_info(function):
+  def _store_unit_test_info(function):
     # Validate arguments.
     if input_generator is not None and input_args is not None:
       raise ValueError(
@@ -737,58 +955,54 @@
 
     function = tf.function(**tf_function_kwargs)(function)
 
-    # Used to identify that the tf.function was created by this decorator.
-    function.is_tf_function_unittest = True
-
-    # Set function.get_trace_args.
-    if input_generator is not None:
-      # Use the user-specificed input_generator.
-      function.get_trace_args = lambda: tf_utils.generate_inputs(
-          function.input_signature, input_generator)
-    elif input_args is not None:
-      # Use the user-specified input_args.
-      function.get_trace_args = lambda: copy.deepcopy(input_args)
-    else:
-      # No user data specification – default to using random uniform data.
-      function.get_trace_args = lambda: tf_utils.generate_inputs(
-          function.input_signature, tf_utils.uniform)
-
-    # Set function.trace_kwargs.
-    function.trace_kwargs = dict(atol=atol, rtol=rtol)
-
-    # Set function.__name__.
+    # Set function.__name__
     if name is not None:
       function.__name__ = name
     elif function.__name__ == "<lambda>":
       raise ValueError("The 'name' kwarg must be provided when decorating a "
                        "lambda function.")
 
+    global _global_unit_test_configs
+    if function.__name__ not in _global_unit_test_configs:
+
+      if input_generator is not None:
+        # Use the user-specified input_generator.
+        get_trace_args = lambda: tf_utils.generate_inputs(
+            function.input_signature, input_generator)
+      elif input_args is not None:
+        # Use the user-specified input_args.
+        get_trace_args = lambda: copy.deepcopy(input_args)
+      else:
+        # No user data specification – default to using random uniform data.
+        get_trace_args = lambda: tf_utils.generate_inputs(
+            function.input_signature, DEFAULT_INPUT_GENERATOR)
+
+      _global_unit_test_configs[function.__name__] = dict(
+          get_trace_args=get_trace_args,
+          trace_kwargs=dict(atol=atol, rtol=rtol))
+
     return function
 
-  return _store_unittest_info
+  return _store_unit_test_info
 
 
 class TestModule(tf.Module):
-  """Thin wrapper of tf.Module with helper methods for tf_function_unittests."""
+  """Thin tf.Module wrapper with helper methods for tf_function_unit_tests."""
 
   @classmethod
-  def get_tf_function_unittests(cls):
-    """Get all tf_function_unittest-created tf.functions on the class."""
-    tf_function_unittests = []
-    for name in dir(cls):
-      value = getattr(cls, name)
-      if hasattr(value, 'is_tf_function_unittest'):
-        tf_function_unittests.append(value)
+  def get_tf_function_unit_tests(cls):
+    """Get the names of all tf_function_unit_test-created tf.functions."""
+    # Initialize the module to ensure that _global_unit_test_configs has the
+    # info for all of the unit_tests. (Only doing this if
+    # _global_unit_test_configs is empty wouldn't address the case where some
+    # unit_tests are defined on the class and some are generated by __init__).
+    cls()
 
-    if not len(tf_function_unittests):
+    tf_function_unit_tests = list(_global_unit_test_configs.keys())
+    if not len(tf_function_unit_tests):
       raise ValueError(
-          "'get_tf_function_unittests' was called but no unittests were found.")
-    return tf_function_unittests
-
-  @classmethod
-  def get_exported_names(cls):
-    """Get the names of all tf_function_unittest-created tf.functions"""
-    return [function.__name__ for function in cls.get_tf_function_unittests()]
+          "'get_tf_function_unit_tests' was called but no tests were found.")
+    return tf_function_unit_tests
 
 
 class TracedModuleTestCase(tf.test.TestCase):
@@ -802,35 +1016,38 @@
       module.reinitialize()
 
   @classmethod
-  def generate_unittests(cls, module_class: Type[TestModule]):
-    """Generates unittests for each 'tf_function_unittest' on 'module_class'."""
-    for function in module_class.get_tf_function_unittests():
-      # We have to pass the closure argument 'funcion' to 'trace' via a kwarg
-      # instead of using it directly in the body because 'function' is
-      # overwritten in each iteration of this loop, and python will only use
-      # the most recent version of 'function'. If we didn't do this, then we
-      # would only test the last function in this loop. The same is true for
-      # passing 'trace' to 'unittest'.
+  def generate_unit_tests(cls, module_class: Type[TestModule]):
+    """Generates tests for each 'tf_function_unit_test' on 'module_class'."""
+    for function_name in module_class.get_tf_function_unit_tests():
+      # We have to pass the closure arguments 'function_name', 'get_args' and
+      # 'kwargs' to 'trace' via kwargs instead of using them directly in the body
+      # because 'function_name' and 'unit_test_config' are overwritten in each
+      # iteration of this loop, and Python will only use the most recent version
+      # of each. If we didn't do this, then we would only test the last function
+      # in this loop. The same is true for passing 'trace' to 'unit_test'.
+      unit_test_config = _global_unit_test_configs[function_name]
 
       # Runs the inputs through a (traced) module.
-      def trace(module, function=function):
-        getattr(module, function.__name__)(*function.get_trace_args(),
-                                           **function.trace_kwargs)
+      def trace(module,
+                function_name=function_name,
+                get_args=unit_test_config["get_trace_args"],
+                kwargs=unit_test_config["trace_kwargs"]):
+        getattr(module, function_name)(*get_args(), **kwargs)
 
       # Give the trace the name of the tf.function that it is testing.
-      trace.__name__ = function.__name__
+      trace.__name__ = function_name
 
       # Runs 'trace' on modules compiled to each backend and compares them.
-      def unittest(self, trace=trace):
+      def unit_test(self, trace=trace):
         self.compare_backends(trace, self._modules)
 
-      # Make 'unittest' a function on the TracedModuleTestCase, which tells
+      # Make 'unit_test' a function on the TracedModuleTestCase, which tells
       # the test runner to run it.
-      unittest.__name__ = f"test_{function.__name__}"
-      if hasattr(cls, unittest.__name__):
-        raise ValueError("Tried to generate multiple instances of the unittest "
-                         f"'{unittest.__name__}'.")
-      setattr(cls, unittest.__name__, unittest)
+      unit_test.__name__ = f"test_{function_name}"
+      if hasattr(cls, unit_test.__name__):
+        raise ValueError("Tried to generate multiple instances of the "
+                         f"unit_test '{unit_test.__name__}'.")
+      setattr(cls, unit_test.__name__, unit_test)
 
   def compare_backends(self, trace_function: Callable[[TracedModule], None],
                        modules: Modules) -> None:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 63cfa2f..6069688 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -54,18 +54,18 @@
 
 class TfFunctionUnittestModule(tf_test_utils.TestModule):
 
-  @tf_test_utils.tf_function_unittest(input_signature=[])
+  @tf_test_utils.tf_function_unit_test(input_signature=[])
   def no_args(self):
     return np.array([True], dtype=np.bool)
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([4]),
       tf.TensorSpec([4]),
   ])
   def default_uniform_inputs(self, a, b):
     return a + b
 
-  @tf_test_utils.tf_function_unittest(
+  @tf_test_utils.tf_function_unit_test(
       input_signature=[
           tf.TensorSpec([4]),
           tf.TensorSpec([4]),
@@ -75,7 +75,7 @@
   def custom_input_generator(self, a, b):
     return a + b
 
-  @tf_test_utils.tf_function_unittest(
+  @tf_test_utils.tf_function_unit_test(
       input_signature=[
           tf.TensorSpec([4]),
           tf.TensorSpec([4]),
@@ -89,7 +89,7 @@
     return a + b
 
   # This test will fail if atol is not successfully set.
-  @tf_test_utils.tf_function_unittest(
+  @tf_test_utils.tf_function_unit_test(
       input_signature=[
           tf.TensorSpec([128, 3072], tf.float32),
           tf.TensorSpec([3072, 256], tf.float32),
@@ -262,7 +262,7 @@
         self._modules = tf_test_utils.compile_tf_module(
             TfFunctionUnittestModule)
 
-    TfFunctionUnittestTest.generate_unittests(TfFunctionUnittestModule)
+    TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnittestModule)
     test_case = TfFunctionUnittestTest()
     self.assertTrue(hasattr(test_case, 'test_no_args'))
     self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs'))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 591df43..aa94744 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -45,9 +45,15 @@
             dtype: Union[tf.DType, np.dtype] = np.float32,
             low: float = -1.0,
             high: float = 1.0) -> np.ndarray:
-  """np.random.uniform with simplified API and dtype control."""
+  """np.random.uniform with a simplified API plus dtype and bool support."""
   dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype
-  return np.random.uniform(size=shape, low=low, high=high).astype(dtype)
+  if dtype == np.bool:
+    return np.random.choice(2, shape).astype(np.bool)
+  else:
+    values = np.random.uniform(size=shape, low=low, high=high)
+    if np.issubdtype(dtype, np.integer):
+      values = np.round(values)
+    return values.astype(dtype)
 
 
 def ndarange(shape: Sequence[int],
@@ -57,22 +63,35 @@
   return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
 
 
+def random_permutation(
+    shape: Sequence[int],
+    dtype: Union[tf.DType, np.dtype] = np.float32) -> np.ndarray:
+  """Returns a random permutation of [0, np.prod(shape))."""
+  values = ndarange(shape, dtype)
+  np.random.shuffle(values)
+  return values
+
+
+def apply_function(values, function):
+  """Recursively applies 'function' to the leaves of 'values'."""
+  if isinstance(values, list):
+    return [apply_function(v, function) for v in values]
+  elif isinstance(values, tuple):
+    return tuple(apply_function(v, function) for v in values)
+  elif isinstance(values, dict):
+    return {k: apply_function(v, function) for k, v in values.items()}
+  else:
+    return function(values)
+
+
 def generate_inputs(
     spec,  # Union[Sequence[tf.TensorSpec], tf.TensorSpec]
     input_generator: InputGeneratorType,
 ) -> Sequence[np.ndarray]:
   """Generates inputs for a given input signature using 'input_generator'."""
-  if isinstance(spec, Sequence):
-    # 'spec' is a sequence of 'tf.TensorSpec'.
-    # Recursively generate inputs.
-    return [generate_inputs(s, input_generator) for s in spec]
-  elif isinstance(spec, tf.TensorSpec):
-    # Handle dynamic shapes (e.g. batches) by substituting an int for None.
-    shape = [size if size is not None else 2 for size in spec.shape]
-    return input_generator(shape, spec.dtype)
-  else:
-    raise TypeError("Expected 'spec' to be a sequence of 'tf.TensorSpec' or "
-                    f"'tf.TensorSpec', but got '{type(spec)}'")
+  make_static = lambda shape: [dim if dim is not None else 2 for dim in shape]
+  generate = lambda spec: input_generator(make_static(spec.shape), spec.dtype)
+  return apply_function(spec, generate)
 
 
 def to_mlir_type(dtype: np.dtype) -> str:
@@ -116,6 +135,63 @@
   return result
 
 
+def remove_special_characters(value: str) -> str:
+  """Replaces special characters with '_' while keeping instances of '__'."""
+  normalized_parts = []
+  for part in value.split("__"):
+    part = re.sub(r"[^a-zA-Z0-9_]", "_", part)  # Remove special characters.
+    part = re.sub(r"_+", "_", part)  # Remove duplicate "_".
+    part = part.strip("_")  # Don't end or start in "_".
+    normalized_parts.append(part)
+  return "__".join(normalized_parts)
+
+
+def is_complex(tensors: Union[Sequence[tf.TensorSpec], tf.TensorSpec]) -> bool:
+  if isinstance(tensors, Sequence):
+    for tensor in tensors:
+      if is_complex(tensor):
+        return True
+    return False
+  else:
+    return tensors.dtype.is_complex  # pytype: disable=attribute-error
+
+
+def _complex_wrapper(function):
+  """Wraps a tf.function to allow compiling functions of complex numbers."""
+
+  def decorator(*args, **kwargs):
+    inputs = []
+    for real, imag in zip(args[::2], args[1::2]):
+      inputs.append(tf.complex(real, imag))
+    result = function(*inputs, **kwargs)
+    # TODO(meadowlark): Support returning complex numbers.
+    return tf.math.real(result) + tf.math.imag(result)
+
+  return decorator
+
+
+def rewrite_complex_signature(function, signature: Sequence[tf.TensorSpec]):
+  """Compatibility layer for testing complex numbers."""
+  if not all([spec.dtype.is_complex for spec in signature]):
+    raise NotImplementedError("Signatures with mixed complex and non-complex "
+                              "tensor specs are not supported.")
+
+  # Rewrite the signature, replacing all complex tensors with pairs of real
+  # and imaginary tensors.
+  real_imag_signature = []
+  for spec in signature:
+    new_dtype = tf.float32 if spec.dtype.size == 8 else tf.float64
+    real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype))
+    real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype))
+
+  return _complex_wrapper(function), real_imag_signature
+
+
+def make_dims_dynamic(spec: tf.TensorSpec) -> tf.TensorSpec:
+  """Gives a tf.TensorSpec dynamic dims."""
+  return tf.TensorSpec([None] * len(spec.shape), spec.dtype)
+
+
 def _setup_mlir_crash_reproducer(
     function: Any,  # pytype doesn't support arbitrary Callable[*args, **kwargs]
     artifacts_dir: str,
@@ -557,28 +633,27 @@
   return result
 
 
+def _convert_to_numpy(tensor: Any) -> Any:
+  if not isinstance(tensor, tf.Tensor):
+    return tensor
+  return _normalize_numpy(tensor.numpy())
+
+
+def convert_to_numpy(values: Any) -> Any:
+  """Converts any tf.Tensor in values to numpy."""
+  return apply_function(values, _convert_to_numpy)
+
+
 class _TfFunctionWrapper(_FunctionWrapper):
   """Wraps a TF function, normalizing it to numpy."""
 
   def __init__(self, f: Callable[..., Any]):
     self._f = f
 
-  def _convert_to_numpy(self, tensor: Any) -> Any:
-    if not isinstance(tensor, tf.Tensor):
-      return tensor
-    return _normalize_numpy(tensor.numpy())
-
   def __call__(self, *args, **kwargs):
     # TensorFlow will auto-convert all inbound args.
     results = self._f(*args, **kwargs)
-    # Then unmarshal them to numpy in the same way that the other backends do.
-    # Handle single result (technically ambiguous with return of a tuple,
-    # which is sad).
-    if not isinstance(results, tuple):
-      results = (results,)
-    return tf.nest.map_structure(self._convert_to_numpy,
-                                 *results,
-                                 check_types=False)
+    return convert_to_numpy(results)
 
 
 def _convert_inputs_to_tensors(function):
@@ -738,9 +813,22 @@
   tflite_modules = []
   methods, method_names, instance = _get_concrete_functions(
       module_class, exported_names)
-  for method in methods:
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
-    tflite_modules.append(converter.convert())
+  failed_methods = []
+  for method, method_name in zip(methods, method_names):
+    logging.info("Attempting to convert '%s' to tflite...", method_name)
+    try:
+      converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+      tflite_modules.append(converter.convert())
+      logging.info("...converted '%s' to tflite.", method_name)
+    except Exception as e:
+      logging.error("Failed to convert '%s' to tflite.", method_name)
+      logging.error("TFLite exception: %s", e)
+      failed_methods.append(method_name)
+
+  if failed_methods:
+    raise RuntimeError(
+        f"Failed to convert the following methods to tflite: {failed_methods}")
+
   # Keep variables alive until TFLite has done the conversion; ConcreteFunctions
   # themselves only keep weak references to variables.
   del instance
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 9dfc5d4..2cdaaf2 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -59,7 +59,6 @@
 # keep sorted
 TFLITE_FAILING = [
     "broadcasting_test.py",
-    "complex_test.py",
     "concat_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
@@ -67,12 +66,11 @@
     "einsum_static_test.py",
     "einsum_vector_test.py",
     "fft_test.py",
-    "finite_test.py",
     "gather_test.py",
     "image_resize_test.py",
     "mandelbrot_test.py",
-    "math_dyn_test.py",
     "matrix_ops_dynamic_test.py",
+    "quantization_dyn_test.py",
     "resource_ops_test.py",
     "ring_buffer_test.py",
     "scatter_update_test.py",
@@ -105,10 +103,9 @@
     "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "linspace_test.py",  # TODO(https://github.com/google/iree/issues/1521)
-    "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
-    "math_dyn_test.py",
     "matrix_ops_dynamic_test.py",
+    "quantization_dyn_test.py",
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
@@ -117,7 +114,6 @@
 
 # keep sorted
 VULKAN_FAILING = [
-    "bool_test.py",
     "broadcast_to_test.py",
     "broadcasting_test.py",
     "conv_transpose_test.py",
@@ -129,10 +125,9 @@
     "fft_test.py",  # TODO(natashaknk): Get this working after kernel is in.
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
     "linspace_test.py",  # TODO(https://github.com/google/iree/issues/1521)
-    "logical_ops_test.py",
     "mandelbrot_test.py",  # TODO(silvasean): Get this working on IREE.
-    "math_dyn_test.py",
     "matrix_ops_dynamic_test.py",
+    "quantization_dyn_test.py",
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 4b457b5..e806496 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -52,17 +52,17 @@
 preferred.
 
 ```shell
-# Run math_test on all backends.
-bazel run //integrations/tensorflow/e2e:math_test_manual
+# Run conv_test on all backends.
+bazel run //integrations/tensorflow/e2e:conv_test_manual
 
-# Run math_test comparing TensorFlow to itself (e.g. to debug randomization).
-bazel run //integrations/tensorflow/e2e:math_test_manual -- --target_backends=tf
+# Run conv_test comparing TensorFlow to itself (e.g. to debug randomization).
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- --target_backends=tf
 
-# Run math_test comparing the VMLA backend and TensorFlow.
-bazel run //integrations/tensorflow/e2e:math_test_manual -- --target_backends=iree_vmla
+# Run conv_test comparing the VMLA backend and TensorFlow.
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- --target_backends=iree_vmla
 
-# Run math_test comparing the VMLA backend to itself multiple times.
-bazel run //integrations/tensorflow/e2e:math_test_manual -- \
+# Run conv_test comparing the VMLA backend to itself multiple times.
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- \
   --reference_backend=iree_vmla --target_backends=iree_vmla,iree_vmla
 ```
 
@@ -72,10 +72,10 @@
 
 ## Writing Tests
 
-There are two ways to write tests – via `tf_test_utils.tf_function_unittest` and
+There are two ways to write tests – via `tf_test_utils.tf_function_unit_test` and
 via test methods on a child of `tf_test_utils.TracedModuleTestCase`.
 
-### Via `tf_test_utils.tf_function_unittest`
+### Via `tf_test_utils.tf_function_unit_test`
 
 This is preferred in the cases where
 
@@ -86,7 +86,7 @@
 
 Tests are specified by writing modules that inherit from
 `tf_test_utils.TestModule` (which is a thin wrapper around `tf.Module`) with
-methods decorated with `@tf_test_utils.tf_function_unittest` (with is a thin
+methods decorated with `@tf_test_utils.tf_function_unit_test` (which is a thin
 wrapper around `tf.function`).
 
 #### Basic example
@@ -101,14 +101,14 @@
   # function. The 'input_signature' is required. If no other arguments are
   # specified then uniform random data is generated from the input signature
   # to numerically test the function.
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 1], tf.float32),
   ])
   def conv2d_1451x1111_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([2, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 1], tf.float32),
   ])
@@ -130,7 +130,7 @@
 ```
 
 Finally, in the `main` function, you need to call
-`.generate_unittests(module_class)` on your `TestCase` to actually generate
+`.generate_unit_tests(module_class)` on your `TestCase` to actually generate
 the unittests that we specified:
 
 ```python
@@ -138,12 +138,12 @@
   del argv  # Unused
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
-  # Generates unittests for all @tf_test_utils.tf_function_unittest decorated
+  # Generates unittests for all @tf_test_utils.tf_function_unit_test decorated
   # functions on the module class.
   # Note: if you are automatically generating functions to test they need to be
   # specified via a `classmethod` prior to this call _as well_ as via `__init__`
   # to properly handle stateful `tf.function`s.
-  ConvTest.generate_unittests(Conv2dModule)
+  ConvTest.generate_unit_tests(Conv2dModule)
   tf.test.main()
 
 
@@ -154,9 +154,9 @@
 This generates two unittests: `test_conv2d_1451x1111_valid` and
 `test_conv2d_2451x1111_valid`.
 
-#### Configuring `@tf_test_utils.tf_function_unittest`
+#### Configuring `@tf_test_utils.tf_function_unit_test`
 
-By default `@tf_test_utils.tf_function_unittest` uses uniform random input data
+By default `@tf_test_utils.tf_function_unit_test` uses uniform random input data
 to numerically test the function, but you can specify an `input_generator` or
 `input_args` to test data-specific behaviors:
 
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
deleted file mode 100644
index df09ecb..0000000
--- a/integrations/tensorflow/e2e/bool_test.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class BooleanModule(tf_test_utils.TestModule):
-
-  @tf_test_utils.tf_function_unittest(input_signature=[])
-  def constant(self):
-    return np.array([True, False, True], dtype=np.bool)
-
-  @tf_test_utils.tf_function_unittest(
-      input_signature=[tf.TensorSpec([4], tf.float32)],
-      input_args=[np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)])
-  def greater_than(self, x):
-    return x > 1.0
-
-  @tf_test_utils.tf_function_unittest(
-      input_signature=[
-          tf.TensorSpec([4], tf.bool),
-          tf.TensorSpec([4], tf.bool)
-      ],
-      input_args=[
-          np.array([True, True, False, False], dtype=np.bool),
-          np.array([True, False, False, True], dtype=np.bool)
-      ],
-  )
-  def logical_and(self, x, y):
-    return tf.math.logical_and(x, y)
-
-
-class BooleanTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(BooleanModule)
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-  BooleanTest.generate_unittests(BooleanModule)
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
deleted file mode 100644
index e60f532..0000000
--- a/integrations/tensorflow/e2e/complex_test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class ComplexModule(tf.Module):
-
-  def __init__(self):
-    pass
-
-  @tf.function(input_signature=[
-      tf.TensorSpec([2], tf.float32),
-      tf.TensorSpec([2], tf.float32)
-  ])
-  def complex_exp(self, real, imag):
-    tensor = tf.complex(real, imag)
-    exp = tf.exp(tensor)
-    return tf.math.real(exp)
-
-
-class ComplexTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(ComplexModule)
-
-  def test_complex(self):
-
-    def complex_exp(module):
-      real = np.array([2., 3.], dtype=np.float32)
-      imag = np.array([-1., 0.4], dtype=np.float32)
-      module.complex_exp(real, imag)
-
-    self.compare_backends(complex_exp, self._modules)
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 11ca8d6..ca0076e 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -22,14 +22,14 @@
 
 class Conv2dModule(tf_test_utils.TestModule):
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 1], tf.float32),
   ])
   def conv2d_1451x1111_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 2, 1, 1], tf.float32),
   ])
@@ -40,7 +40,7 @@
                         dilations=[1, 2, 1, 1],
                         name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
@@ -51,70 +51,70 @@
                         dilations=[1, 2, 1, 1],
                         name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([2, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 1], tf.float32),
   ])
   def conv2d_2451x1111_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 3, 1, 1], tf.float32),
   ])
   def conv2d_1451x2311_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 3, 1, 1], tf.float32),
   ])
   def conv2d_1451x2311_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([2, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 3, 1, 1], tf.float32),
   ])
   def conv2d_2451x2311_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([3, 2, 2, 1], tf.float32),
   ])
   def conv2d_1452x3221_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 2], tf.float32),
   ])
   def conv2d_1451x1112_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([1, 1, 2, 2], tf.float32),
   ])
   def conv2d_1452x1122_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
   def conv2d_1452x2223_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
   def conv2d_1452x2223_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf_test_utils.tf_function_unittest(input_signature=[
+  @tf_test_utils.tf_function_unit_test(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
@@ -133,7 +133,7 @@
   del argv  # Unused
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
-  ConvTest.generate_unittests(Conv2dModule)
+  ConvTest.generate_unit_tests(Conv2dModule)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
deleted file mode 100644
index 761cae4..0000000
--- a/integrations/tensorflow/e2e/finite_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class FiniteModule(tf.Module):
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def finite(self, x):
-    return tf.math.is_finite(x)
-
-
-class FiniteTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(FiniteModule)
-
-  def test_finite(self):
-
-    def finite(module):
-      module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
-
-    self.compare_backends(finite, self._modules)
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 58be0f6..661cc30 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -25,6 +25,13 @@
         driver = ""
     return driver
 
+def set_difference(include, exclude):
+    return [
+        value
+        for value in include
+        if value not in exclude
+    ]
+
 def iree_e2e_test_suite(
         name,
         backends_to_srcs,
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
index 148c808..75dbeda 100644
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -287,7 +287,7 @@
             "LSTM",
             "MaxPool1D",
             "MaxPool3D",
-            "SeparableConv1D",
+            "SeparableConv1D",  # Failing on Kokoro.
             "SimpleRNN",
         ],
         "target_backends": "tflite",
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index 0ab5f32..4755450 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -567,8 +567,8 @@
   return tf.keras.Model(inputs, outputs)
 
 
-def create_tf_function_unittest(config: Config, exported_name: str,
-                                model: tf.keras.Model) -> tf.function:
+def create_tf_function_unit_test(config: Config, exported_name: str,
+                                 model: tf.keras.Model) -> tf.function:
   """Wrap the model's __call__ function in a tf.function for testing."""
   input_shapes = config.shapes
   if FLAGS.dynamic_batch:
@@ -579,19 +579,19 @@
     input_signature = [input_signature]
 
   call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
-  return tf_test_utils.tf_function_unittest(input_signature=input_signature,
-                                            name=exported_name)(call)
+  return tf_test_utils.tf_function_unit_test(input_signature=input_signature,
+                                             name=exported_name)(call)
 
 
 class KerasLayersModule(tf_test_utils.TestModule):
 
   @classmethod
   def configure_class(cls):
-    """Configure each tf_function_unittest and define it on the cls."""
+    """Configure each tf_function_unit_test and define it on the cls."""
     for i, (exported_name, config) in enumerate(get_configs().items()):
       model = create_wrapped_keras_layer(config)
       setattr(cls, exported_name,
-              create_tf_function_unittest(config, exported_name, model))
+              create_tf_function_unit_test(config, exported_name, model))
 
   def __init__(self):
     super().__init__()
@@ -600,7 +600,7 @@
       model = create_wrapped_keras_layer(config)
       self.models.append(model)
       setattr(self, exported_name,
-              create_tf_function_unittest(config, exported_name, model))
+              create_tf_function_unit_test(config, exported_name, model))
 
 
 class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
@@ -609,7 +609,7 @@
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(
         KerasLayersModule,
-        exported_names=KerasLayersModule.get_exported_names())
+        exported_names=KerasLayersModule.get_tf_function_unit_tests())
 
 
 def main(argv):
@@ -638,7 +638,7 @@
   # to test to the KerasLayersModule, and then generate unittests for each of
   # them.
   KerasLayersModule.configure_class()
-  KerasLayersTest.generate_unittests(KerasLayersModule)
+  KerasLayersTest.generate_unit_tests(KerasLayersModule)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
deleted file mode 100644
index cb25db2..0000000
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module that specifically handle logical ops."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class LogicalOpsModule(tf.Module):
-
-  @tf.function(input_signature=[
-      tf.TensorSpec([4], tf.bool),
-      tf.TensorSpec([4], tf.bool)
-  ])
-  def logical_and(self, x, y):
-    return tf.math.logical_and(x, y)
-
-  @tf.function(input_signature=[
-      tf.TensorSpec([4], tf.bool),
-      tf.TensorSpec([4], tf.bool)
-  ])
-  def logical_or(self, x, y):
-    return tf.math.logical_or(x, y)
-
-  @tf.function(input_signature=[
-      tf.TensorSpec([4], tf.bool),
-      tf.TensorSpec([4], tf.bool)
-  ])
-  def logical_xor(self, x, y):
-    return tf.math.logical_xor(x, y)
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.bool)])
-  def logical_not(self, x):
-    return tf.math.logical_not(x)
-
-
-class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
-
-  # yapf: disable
-  def test_logical_and(self):
-    def logical_and(module):
-      module.logical_and(
-          np.array([1, 1, 0, 0], dtype=np.bool),
-          np.array([0, 1, 1, 0], dtype=np.bool))
-    self.compare_backends(logical_and, self._modules)
-
-  def test_logical_or(self):
-    def logical_or(module):
-      module.logical_or(
-          np.array([1, 1, 0, 0], dtype=np.bool),
-          np.array([0, 1, 1, 0], dtype=np.bool))
-    self.compare_backends(logical_or, self._modules)
-
-  def test_logical_xor(self):
-    def logical_xor(module):
-      module.logical_xor(
-          np.array([1, 1, 0, 0], dtype=np.bool),
-          np.array([0, 1, 1, 0], dtype=np.bool))
-    self.compare_backends(logical_xor, self._modules)
-
-  def test_logical_not(self):
-    def logical_not(module):
-      module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
-    self.compare_backends(logical_not, self._modules)
-  # yapf: enable
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/math/BUILD b/integrations/tensorflow/e2e/math/BUILD
new file mode 100644
index 0000000..242a9c9
--- /dev/null
+++ b/integrations/tensorflow/e2e/math/BUILD
@@ -0,0 +1,971 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. A coverage table generated from this file can be viewed here:
+#   https://google.github.io/iree/tf-e2e-coverage
+# Updates made to test suite names should also be reflected here:
+#   https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+    "//bindings/python:build_defs.oss.bzl",
+    "INTREE_TENSORFLOW_PY_DEPS",
+    "NUMPY_DEPS",
+    "iree_py_binary",
+    "iree_py_test",
+)
+load(
+    "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
+    "set_difference",
+)
+load(
+    "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+    "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+[
+    iree_py_binary(
+        name = src.replace(".py", "_manual"),
+        srcs = [src],
+        main = src,
+        python_version = "PY3",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for src in glob(
+        ["*_test.py"],
+        exclude = ["keyword_spotting_streaming_test.py"],
+    )
+]
+
+# These functions were selected from all of the functions in the tf.math docs:
+#   https://www.tensorflow.org/api_docs/python/tf/math
+TF_MATH_FUNCTIONS = [
+    "abs",
+    "accumulate_n",
+    "acos",
+    "acosh",
+    "add",
+    "add_n",
+    "angle",
+    "argmax",
+    "argmin",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bessel_i0",
+    "bessel_i0e",
+    "bessel_i1",
+    "bessel_i1e",
+    "betainc",
+    "bincount",
+    "ceil",
+    "confusion_matrix",
+    "cos",
+    "cosh",
+    "count_nonzero",
+    "cumprod",
+    "cumsum",
+    "cumulative_logsumexp",
+    "digamma",
+    "divide",
+    "divide_no_nan",
+    "equal",
+    "erf",
+    "erfc",
+    "erfinv",
+    "exp",
+    "expm1",
+    "floor",
+    "floordiv",
+    "floormod",
+    "greater",
+    "greater_equal",
+    "igamma",
+    "igammac",
+    "imag",
+    "in_top_k",
+    "invert_permutation",
+    "is_finite",
+    "is_inf",
+    "is_nan",
+    "is_non_decreasing",
+    "is_strictly_increasing",
+    "lbeta",
+    "less",
+    "less_equal",
+    "lgamma",
+    "log",
+    "log1p",
+    "log_sigmoid",
+    "log_softmax",
+    "logical_and",
+    "logical_not",
+    "logical_or",
+    "logical_xor",
+    "maximum",
+    "minimum",
+    "mod",
+    "multiply",
+    "multiply_no_nan",
+    "ndtri",
+    "negative",
+    "nextafter",
+    "not_equal",
+    "polygamma",
+    "polyval",
+    "pow",
+    "real",
+    "reciprocal",
+    "reciprocal_no_nan",
+    "reduce_all",
+    "reduce_any",
+    "reduce_euclidean_norm",
+    "reduce_logsumexp",
+    "reduce_max",
+    "reduce_mean",
+    "reduce_min",
+    "reduce_prod",
+    "reduce_std",
+    "reduce_sum",
+    "reduce_variance",
+    "rint",
+    "round",
+    "rsqrt",
+    "scalar_mul",
+    "segment_max",
+    "segment_mean",
+    "segment_min",
+    "segment_prod",
+    "segment_sum",
+    "sigmoid",
+    "sign",
+    "sin",
+    "sinh",
+    "sobol_sample",
+    "softmax",
+    "softplus",
+    "softsign",
+    "sqrt",
+    "square",
+    "squared_difference",
+    "subtract",
+    "tan",
+    "tanh",
+    # "top_k",  # TODO(meadowlark): Enable once list outputs are supported.
+    "truediv",
+    "unsorted_segment_max",
+    "unsorted_segment_mean",
+    "unsorted_segment_min",
+    "unsorted_segment_prod",
+    "unsorted_segment_sqrt_n",
+    "unsorted_segment_sum",
+    "xdivy",
+    "xlog1py",
+    "xlogy",
+    "zero_fraction",
+    "zeta",
+]
+
+# keep sorted
+TFLITE_FAILING = [
+    "abs",  # Failing for integer inputs.
+    "acos",
+    "acosh",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bessel_i0",
+    "bessel_i0e",
+    "bessel_i1",
+    "bessel_i1e",
+    "betainc",
+    "bincount",
+    "confusion_matrix",
+    "conj",
+    "cosh",
+    "cumprod",
+    "cumulative_logsumexp",
+    "digamma",
+    "divide",  # Failing for integer inputs.
+    "erf",
+    "erfc",
+    "erfinv",
+    "expm1",
+    "igamma",
+    "igammac",
+    "in_top_k",
+    "invert_permutation",
+    "is_finite",
+    "is_non_decreasing",
+    "is_strictly_increasing",
+    "l2_normalize",
+    "lbeta",
+    "lgamma",
+    "log1p",
+    "log_sigmoid",
+    "ndtri",
+    "nextafter",
+    "polygamma",
+    "polyval",
+    "pow",  # Failing for integer inputs.
+    "reduce_all",
+    "reduce_euclidean_norm",
+    "reduce_logsumexp",
+    "reduce_mean",
+    "reduce_std",
+    "reduce_variance",
+    "rint",
+    "segment_max",
+    "segment_mean",
+    "segment_min",
+    "segment_prod",
+    "sign",
+    "sinh",
+    "sobol_sample",
+    "softmax",
+    "softplus",
+    "softsign",
+    "tan",
+    "unsorted_segment_max",
+    "unsorted_segment_mean",
+    "unsorted_segment_min",
+    "unsorted_segment_prod",
+    "unsorted_segment_sqrt_n",
+    "unsorted_segment_sum",
+    "xdivy",
+    "xlog1py",
+    "xlogy",
+    "zeta",
+]
+
+# Note: The VMLA_FAILING_DYNAMIC specification extends this list. Newly-passing
+# functions removed from this list may need to be added to VMLA_FAILING_DYNAMIC.
+# keep sorted
+VMLA_FAILING = [
+    "acosh",
+    "argmax",
+    "argmin",
+    "asin",
+    "asinh",
+    "atan2",
+    "atanh",
+    "bessel_i0",
+    "bessel_i0e",
+    "bessel_i1",
+    "bessel_i1e",
+    "betainc",
+    "bincount",
+    "confusion_matrix",
+    "cosh",
+    "count_nonzero",
+    "cumprod",
+    "cumulative_logsumexp",
+    "digamma",
+    "divide",  # Failing for integer inputs because iree doesn't output 'f64'.
+    "erf",
+    "erfc",
+    "erfinv",
+    "expm1",
+    "igamma",
+    "igammac",
+    "in_top_k",
+    "invert_permutation",
+    "is_nan",
+    "is_non_decreasing",
+    "is_strictly_increasing",
+    "ndtri",
+    "nextafter",
+    "polygamma",
+    "pow",  # Failing for integer inputs.
+    "reduce_all",
+    "reduce_any",
+    "reduce_euclidean_norm",
+    "reduce_prod",
+    "rint",
+    "segment_max",
+    "segment_mean",
+    "segment_min",
+    "segment_prod",
+    "segment_sum",
+    "sign",
+    "sobol_sample",
+    "softsign",
+    "unsorted_segment_max",
+    "unsorted_segment_mean",
+    "unsorted_segment_min",
+    "unsorted_segment_prod",
+    "unsorted_segment_sqrt_n",
+    "unsorted_segment_sum",
+    "xdivy",
+    "xlog1py",
+    "xlogy",
+    "zeta",
+]
+
+# Note: The LLVM_FAILING_DYNAMIC specification extends this list. Newly-passing
+# functions removed from this list may need to be added to LLVM_FAILING_DYNAMIC.
+# keep sorted
+LLVM_FAILING = [
+    "acos",
+    "acosh",
+    "argmax",
+    "argmin",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bessel_i0",
+    "bessel_i0e",
+    "bessel_i1",
+    "bessel_i1e",
+    "betainc",
+    "bincount",
+    "confusion_matrix",
+    "cosh",
+    "count_nonzero",
+    "cumprod",
+    "cumulative_logsumexp",
+    "digamma",
+    "divide",  # Failing for integer inputs because iree doesn't output 'f64'.
+    "erf",
+    "erfc",
+    "erfinv",
+    "expm1",
+    "igamma",
+    "igammac",
+    "in_top_k",
+    "invert_permutation",
+    "is_nan",
+    "is_non_decreasing",
+    "is_strictly_increasing",
+    "l2_normalize",
+    "logical_or",
+    "logical_xor",
+    "ndtri",
+    "nextafter",
+    "polygamma",
+    "pow",
+    "reduce_all",
+    "reduce_any",
+    "reduce_euclidean_norm",
+    "reduce_logsumexp",
+    "reduce_max",
+    "reduce_mean",
+    "reduce_min",
+    "reduce_prod",
+    "reduce_std",
+    "reduce_sum",
+    "reduce_variance",
+    "rint",
+    "segment_max",
+    "segment_mean",
+    "segment_min",
+    "segment_prod",
+    "segment_sum",
+    "sign",
+    "sobol_sample",
+    "softsign",
+    "unsorted_segment_max",
+    "unsorted_segment_mean",
+    "unsorted_segment_min",
+    "unsorted_segment_prod",
+    "unsorted_segment_sqrt_n",
+    "unsorted_segment_sum",
+    "xdivy",
+    "xlog1py",
+    "xlogy",
+    "zeta",
+]
+
+# Note: The VULKAN_FAILING_DYNAMIC specification extends this list.
+# Newly-passing functions removed from this list may need to be added to
+# VULKAN_FAILING_DYNAMIC.
+# keep sorted
+VULKAN_FAILING = [
+    "acos",
+    "acosh",
+    "argmax",
+    "argmin",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bessel_i0",
+    "bessel_i0e",
+    "bessel_i1",
+    "bessel_i1e",
+    "betainc",
+    "bincount",
+    "confusion_matrix",
+    "cosh",
+    "count_nonzero",
+    "cumprod",
+    "cumsum",
+    "cumulative_logsumexp",
+    "digamma",
+    "divide",  # Failing for integer inputs because iree doesn't output 'f64'.
+    "erf",
+    "erfc",
+    "erfinv",
+    "expm1",
+    "igamma",
+    "igammac",
+    "in_top_k",
+    "invert_permutation",
+    "is_nan",
+    "is_non_decreasing",
+    "is_strictly_increasing",
+    "l2_normalize",
+    "logical_and",
+    "logical_not",
+    "logical_or",
+    "logical_xor",
+    "ndtri",
+    "nextafter",
+    "polygamma",
+    "pow",
+    "reduce_all",
+    "reduce_any",
+    "reduce_euclidean_norm",
+    "reduce_logsumexp",
+    "reduce_max",
+    "reduce_mean",
+    "reduce_min",
+    "reduce_prod",
+    "reduce_std",
+    "reduce_sum",
+    "reduce_variance",
+    "rint",
+    "segment_max",
+    "segment_mean",
+    "segment_min",
+    "segment_prod",
+    "segment_sum",
+    "sign",
+    "sobol_sample",
+    "softsign",
+    "unsorted_segment_max",
+    "unsorted_segment_mean",
+    "unsorted_segment_min",
+    "unsorted_segment_prod",
+    "unsorted_segment_sqrt_n",
+    "unsorted_segment_sum",
+    "xdivy",
+    "xlog1py",
+    "xlogy",
+    "zeta",
+]
+
+# ---- INDIVIDUAL STATIC TESTS ----------------------------------------------- #
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+    name = "math_tests",
+    srcs = ["math_test.py"],
+    failing_configurations = [
+        {
+            # Failing on TFLite.
+            "functions": TFLITE_FAILING,
+            "target_backends": "tflite",
+        },
+        {
+            # Failing on vmla.
+            "functions": VMLA_FAILING,
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing on llvm.
+            "functions": LLVM_FAILING,
+            "target_backends": "iree_llvmjit",
+        },
+        {
+            # Failing on vulkan.
+            "functions": VULKAN_FAILING,
+            "target_backends": "iree_vulkan",
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "functions": TF_MATH_FUNCTIONS,
+        "dynamic_dims": False,
+        "test_complex": False,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "math_test.py",
+    tags = [
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# ---- MULTIPLE STATIC TESTS ------------------------------------------------ #
+
+# These tests compile all of the passing functions in tf.math into a single
+# module so that the CI only needs 5 additional targets instead of 640.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_MULTIPLE = VMLA_FAILING + ["multiply"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+LLVM_FAILING_MULTIPLE = LLVM_FAILING + ["square"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+VULKAN_FAILING_MULTIPLE = VULKAN_FAILING + ["square"]
+
+[
+    iree_py_test(
+        name = "math_tests_multiple__{}".format(target_backend),
+        srcs = ["math_test.py"],
+        args = [
+            "--reference_backend=tf",
+            "--target_backends={}".format(target_backend),
+            "--functions={}".format(",".join(functions)),
+            "--dynamic_dims=False",
+        ],
+        main = "math_test.py",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for target_backend, functions in dict(
+        iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_MULTIPLE),
+        iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_MULTIPLE),
+        iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_MULTIPLE),
+        tf = TF_MATH_FUNCTIONS,
+        tflite = set_difference(TF_MATH_FUNCTIONS, TFLITE_FAILING),
+    ).items()
+]
+
+# ---- INDIVIDUAL DYNAMIC TESTS ---------------------------------------------- #
+
+# keep sorted
+VMLA_FAILING_DYNAMIC = VMLA_FAILING + [
+    "angle",
+    "cumsum",
+    "divide_no_nan",
+    "equal",
+    "floordiv",
+    "floormod",
+    "imag",
+    "lbeta",
+    "lgamma",
+    "log1p",
+    "log_sigmoid",
+    "logical_and",
+    "logical_not",
+    "logical_or",
+    "logical_xor",
+    "mod",
+    "multiply_no_nan",
+    "not_equal",
+    "reciprocal_no_nan",
+    "reduce_logsumexp",
+    "reduce_max",
+    "reduce_mean",
+    "reduce_min",
+    "reduce_std",
+    "reduce_sum",
+    "reduce_variance",
+    "round",
+    "softplus",
+    "zero_fraction",
+]
+
+# keep sorted
+LLVM_FAILING_DYNAMIC = LLVM_FAILING + [
+    "accumulate_n",
+    "add",
+    "add_n",
+    "angle",
+    "cumsum",
+    "divide",
+    "divide_no_nan",
+    "equal",
+    "floordiv",
+    "floormod",
+    "greater",
+    "greater_equal",
+    "is_finite",
+    "is_inf",
+    "lbeta",
+    "less",
+    "less_equal",
+    "lgamma",
+    "log1p",
+    "log_sigmoid",
+    "log_softmax",
+    "logical_and",
+    "logical_not",
+    "maximum",
+    "minimum",
+    "mod",
+    "multiply",
+    "multiply_no_nan",
+    "not_equal",
+    "polyval",
+    "reciprocal",
+    "reciprocal_no_nan",
+    "reduce_mean",
+    "round",
+    "scalar_mul",
+    "sigmoid",
+    "sinh",
+    "softmax",
+    "softplus",
+    "square",
+    "squared_difference",
+    "subtract",
+    "tan",
+    "truediv",
+    "zero_fraction",
+]
+
+# keep sorted
+VULKAN_FAILING_DYNAMIC = VULKAN_FAILING + [
+    "abs",
+    "accumulate_n",
+    "add",
+    "add_n",
+    "angle",
+    "ceil",
+    "cos",
+    "divide",
+    "divide_no_nan",
+    "equal",
+    "exp",
+    "floor",
+    "floordiv",
+    "floormod",
+    "greater",
+    "greater_equal",
+    "imag",
+    "is_finite",
+    "is_inf",
+    "lbeta",
+    "less",
+    "less_equal",
+    "lgamma",
+    "log",
+    "log1p",
+    "log_sigmoid",
+    "log_softmax",
+    "maximum",
+    "minimum",
+    "mod",
+    "multiply",
+    "multiply_no_nan",
+    "negative",
+    "not_equal",
+    "polyval",
+    "reciprocal",
+    "reciprocal_no_nan",
+    "reduce_max",
+    "reduce_mean",
+    "reduce_min",
+    "reduce_sum",
+    "round",
+    "rsqrt",
+    "scalar_mul",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "softmax",
+    "softplus",
+    "sqrt",
+    "square",
+    "squared_difference",
+    "subtract",
+    "tan",
+    "tanh",
+    "truediv",
+    "zero_fraction",
+]
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+    name = "math_dynamic_dims_tests",
+    srcs = ["math_test.py"],
+    failing_configurations = [
+        {
+            # TFLite does not support dynamic shapes.
+            "target_backends": "tflite",
+        },
+        {
+            # Failing on vmla.
+            "functions": VMLA_FAILING_DYNAMIC,
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing on llvm.
+            "functions": LLVM_FAILING_DYNAMIC,
+            "target_backends": "iree_llvmjit",
+        },
+        {
+            # Failing on vulkan.
+            "functions": VULKAN_FAILING_DYNAMIC,
+            "target_backends": "iree_vulkan",
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "functions": TF_MATH_FUNCTIONS,
+        "dynamic_dims": True,
+        "test_complex": False,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "math_test.py",
+    tags = [
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# ---- MULTIPLE DYNAMIC TESTS ----------------------------------------------- #
+
+# These tests compile all of the passing functions in tf.math into a single
+# module so that the CI only needs 4 additional targets instead of 512.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_DYNAMIC_MULTIPLE = VMLA_FAILING_DYNAMIC + ["multiply"]
+
+[
+    iree_py_test(
+        name = "math_dynamic_dims_tests_multiple__{}".format(target_backend),
+        srcs = ["math_test.py"],
+        args = [
+            "--reference_backend=tf",
+            "--target_backends={}".format(target_backend),
+            "--functions={}".format(",".join(functions)),
+            "--dynamic_dims=True",
+        ],
+        main = "math_test.py",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for target_backend, functions in dict(
+        iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_DYNAMIC),
+        iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_DYNAMIC_MULTIPLE),
+        iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_DYNAMIC),
+        tf = TF_MATH_FUNCTIONS,
+    ).items()
+]
+
+# ---- INDIVIDUAL COMPLEX TESTS ---------------------------------------------- #
+
+# This list was generated by running:
+#   bazel run integrations/tensorflow/e2e/math:math_test_manual -- --list_functions_with_complex_tests
+COMPLEX_FUNCTIONS = [
+    "abs",
+    "add",
+    "angle",
+    "asinh",
+    "atanh",
+    "conj",
+    "cos",
+    "cosh",
+    "count_nonzero",
+    "cumprod",
+    "cumsum",
+    "divide",
+    "divide_no_nan",
+    "exp",
+    "expm1",
+    "imag",
+    "l2_normalize",
+    "log",
+    "log1p",
+    "multiply",
+    "multiply_no_nan",
+    "negative",
+    "pow",
+    "real",
+    "reciprocal",
+    "reciprocal_no_nan",
+    "reduce_euclidean_norm",
+    "reduce_std",
+    "reduce_variance",
+    "rsqrt",
+    "sigmoid",
+    "sign",
+    "sin",
+    "sinh",
+    "sqrt",
+    "square",
+    "squared_difference",
+    "subtract",
+    "tan",
+    "tanh",
+    "truediv",
+    "xdivy",
+    "xlog1py",
+    "xlogy",
+    "zero_fraction",
+]
+
+# keep sorted
+FAILING_COMPLEX = [
+    "angle",
+    "cos",
+    "cumsum",
+    "divide_no_nan",
+    "log",
+    "log1p",
+    "multiply_no_nan",
+    "negative",
+    "reciprocal",
+    "reciprocal_no_nan",
+    "reduce_std",
+    "reduce_variance",
+    "rsqrt",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "zero_fraction",
+]
+
+VMLA_FAILING_COMPLEX = VMLA_FAILING + FAILING_COMPLEX
+
+LLVM_FAILING_COMPLEX = LLVM_FAILING + FAILING_COMPLEX
+
+VULKAN_FAILING_COMPLEX = VULKAN_FAILING + FAILING_COMPLEX
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+    name = "math_complex_tests",
+    srcs = ["math_test.py"],
+    failing_configurations = [
+        {
+            # TFLite does not support complex numbers.
+            "target_backends": "tflite",
+        },
+        {
+            # Failing on vmla.
+            "functions": VMLA_FAILING_COMPLEX,
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing on llvm.
+            "functions": LLVM_FAILING_COMPLEX,
+            "target_backends": "iree_llvmjit",
+        },
+        {
+            # Failing on vulkan.
+            "functions": VULKAN_FAILING_COMPLEX,
+            "target_backends": "iree_vulkan",
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "functions": COMPLEX_FUNCTIONS,
+        "dynamic_dims": False,
+        "test_complex": True,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "math_test.py",
+    tags = [
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# ---- MULTIPLE COMPLEX TESTS ----------------------------------------------- #
+
+# These tests compile all of the passing functions in tf.math into a single
+# module so that the CI only needs 4 additional targets instead of 512.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_COMPLEX_MULTIPLE = VMLA_FAILING_COMPLEX + ["multiply"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+LLVM_FAILING_COMPLEX_MULTIPLE = LLVM_FAILING_COMPLEX + ["square"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+VULKAN_FAILING_COMPLEX_MULTIPLE = VULKAN_FAILING_COMPLEX + ["square"]
+
+[
+    iree_py_test(
+        name = "math_complex_tests_multiple__{}".format(target_backend),
+        srcs = ["math_test.py"],
+        args = [
+            "--reference_backend=tf",
+            "--target_backends={}".format(target_backend),
+            "--functions={}".format(",".join(functions)),
+            "--dynamic_dims=False",
+            "--test_complex=True",
+        ],
+        main = "math_test.py",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for target_backend, functions in dict(
+        iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_COMPLEX_MULTIPLE),
+        iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_COMPLEX_MULTIPLE),
+        iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_COMPLEX_MULTIPLE),
+        tf = TF_MATH_FUNCTIONS,
+    ).items()
+]
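
Note on the `set_difference` helper loaded from `iree_e2e_test_suite.bzl` above: each per-backend "multiple" target subtracts that backend's failing list from `TF_MATH_FUNCTIONS`. The actual Starlark macro is not part of this diff, so the snippet below is only an illustrative Python sketch of the behavior the BUILD file relies on (order-preserving list difference), not the real implementation:

```python
def set_difference(include, exclude):
  """Illustrative sketch: names in `include` not in `exclude`, order kept."""
  exclude = set(exclude)
  return [name for name in include if name not in exclude]

# E.g. the functions compiled into math_tests_multiple__tflite are roughly
# set_difference(TF_MATH_FUNCTIONS, TFLITE_FAILING).
print(set_difference(["abs", "acos", "add"], ["acos"]))  # ['abs', 'add']
```
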
diff --git a/integrations/tensorflow/e2e/math/math_test.py b/integrations/tensorflow/e2e/math/math_test.py
new file mode 100644
index 0000000..7a1096f
--- /dev/null
+++ b/integrations/tensorflow/e2e/math/math_test.py
@@ -0,0 +1,735 @@
+import collections
+import os
+from typing import Any, Dict, Sequence, Type, Union
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# Rank 7 is the highest rank tf supports without breaking.
+RANK_7_SHAPE = [2] * 7
+UNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE]]
+BINARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 2]
+TERNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 3]
+
+# Reused UnitTestSpecs.
+SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
+    names_to_input_args={
+        "tf_doc_example": [
+            tf.constant([
+                [1, 2, 3, 4],
+                [4, 3, 2, 1],
+                [5, 6, 7, 8],
+            ], np.float32),
+            np.array([0, 0, 1], np.int32),
+        ]
+    })
+UNSORTED_SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
+    names_to_input_args={
+        "tf_doc_example": [
+            tf.constant([
+                [1, 2, 3, 4],
+                [4, 3, 2, 1],
+                [5, 6, 7, 8],
+            ], np.float32),
+            np.array([0, 0, 1], np.int32),
+            2,
+        ]
+    })
+
+REDUCE_KWARGS_TO_VALUES = {
+    "axis": [None, 1],
+    "keepdims": [False, True],
+}
+
+# A dictionary mapping tf.math function names to lists of UnitTestSpecs.
+# Each unit_test_name will have the tf.math function name prepended to it.
+FUNCTIONS_TO_UNIT_TEST_SPECS = {
+    "abs":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "accumulate_n": [
+        tf_test_utils.UnitTestSpec(
+            unit_test_name='f32',
+            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
+        tf_test_utils.UnitTestSpec(
+            unit_test_name='i32',
+            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
+    ],
+    "acos":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "acosh":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32],
+            input_generators=[tf_utils.ndarange]),
+    "add":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "add_n": [
+        tf_test_utils.UnitTestSpec(
+            unit_test_name='f32',
+            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
+        tf_test_utils.UnitTestSpec(
+            unit_test_name='i32',
+            input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
+    ],
+    "angle":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "argmax":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "argmin":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "asin":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "asinh":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "atan":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "atan2":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "atanh":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "bessel_i0":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "bessel_i0e":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "bessel_i1":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "bessel_i1e":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "betainc":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=TERNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "bincount":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.int32],
+            input_generators=[tf_utils.ndarange]),
+    "ceil":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "confusion_matrix":
+        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+            "five_classes": [tf.constant([1, 2, 4]),
+                             tf.constant([2, 2, 4])]
+        }),
+    "conj":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "cos":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "cosh":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "count_nonzero":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64],
+            input_generators=[tf_utils.ndarange]),
+    "cumprod":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "cumsum":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "cumulative_logsumexp":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "digamma":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "divide":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "divide_no_nan":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "equal":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "erf":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "erfc":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "erfinv":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "exp":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "expm1":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "floor":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "floordiv":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            # Avoid integer division by 0.
+            input_generators={
+                "uniform_1_3":
+                    lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
+            }),
+    "floormod":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            # Avoid integer division by 0.
+            input_generators={
+                "uniform_1_3":
+                    lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
+            }),
+    "greater":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "greater_equal":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "igamma":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "igammac":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "imag":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "in_top_k": [
+        tf_test_utils.UnitTestSpec(
+            unit_test_name="k_3",
+            input_signature=[
+                tf.TensorSpec([8], tf.int32),
+                tf.TensorSpec([8, 3])
+            ],
+            input_generator=tf_utils.ndarange,
+            kwargs=dict(k=3),
+        )
+    ],
+    "invert_permutation": [
+        tf_test_utils.UnitTestSpec(
+            unit_test_name="random",
+            input_signature=[tf.TensorSpec([8], tf.int32)],
+            input_generator=tf_utils.random_permutation,
+        )
+    ],
+    "is_finite":
+        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+            "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+        }),
+    "is_inf":
+        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+            "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+        }),
+    "is_nan":
+        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+            "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+        }),
+    "is_non_decreasing":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "is_strictly_increasing":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "l2_normalize":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "lbeta":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "less":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "less_equal":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "lgamma":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "log":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "log1p":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "log_sigmoid":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "log_softmax":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "logical_and":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.bool]),
+    "logical_not":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.bool]),
+    "logical_or":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.bool]),
+    "logical_xor":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.bool]),
+    "maximum":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "minimum":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "mod":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            input_generators={
+                "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1
+            }),
+    "multiply":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "multiply_no_nan":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "ndtri":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "negative":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "nextafter":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES),
+    "not_equal":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32]),
+    "polygamma":
+        tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+            "nan_and_inf": [tf.ones(16), tf.linspace(0.5, 4, 16)]
+        }),
+    "polyval": [
+        tf_test_utils.UnitTestSpec(
+            unit_test_name="three_coeffs",
+            input_signature=[[tf.TensorSpec(RANK_7_SHAPE)] * 3,
+                             tf.TensorSpec([])],
+        )
+    ],
+    "pow":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64],
+            input_generators={
+                "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1
+            }),
+    "real":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "reciprocal":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "reciprocal_no_nan":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "reduce_all": [
+        # Explicitly test all True inputs to be absolutely sure that some
+        # reduction axes return True.
+        *tf_test_utils.unit_test_specs_from_args(
+            names_to_input_args={
+                "all_true": [np.ones(RANK_7_SHAPE, np.bool)],
+            },
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+        *tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.bool],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    ],
+    "reduce_any": [
+        # Explicitly test all False inputs to be absolutely sure that some
+        # reduction axes return False.
+        *tf_test_utils.unit_test_specs_from_args(
+            names_to_input_args={
+                "all_false": [np.zeros(RANK_7_SHAPE, np.bool)],
+            },
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+        *tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.bool],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    ],
+    "reduce_euclidean_norm":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_logsumexp":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_max":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_mean":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_min":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_prod":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_std":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_sum":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "reduce_variance":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64],
+            kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+    "rint":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "round":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "rsqrt":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "scalar_mul":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=[[[], [8]]]),
+    "segment_max":
+        SEGMENT_UNIT_TEST_SPECS,
+    "segment_mean":
+        SEGMENT_UNIT_TEST_SPECS,
+    "segment_min":
+        SEGMENT_UNIT_TEST_SPECS,
+    "segment_prod":
+        SEGMENT_UNIT_TEST_SPECS,
+    "segment_sum":
+        SEGMENT_UNIT_TEST_SPECS,
+    "sigmoid":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "sign":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "sin":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "sinh":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "sobol_sample":
+        tf_test_utils.unit_test_specs_from_args(
+            names_to_input_args={"simple": [4, 3]}),
+    "softmax":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "softplus":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "softsign":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32]),
+    "sqrt":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "square":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "squared_difference":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "subtract":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "tan":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "tanh":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "top_k":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32],
+            kwargs_to_values={"k": [1, 2]}),
+    "truediv":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "unsorted_segment_max":
+        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+    "unsorted_segment_mean":
+        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+    "unsorted_segment_min":
+        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+    "unsorted_segment_prod":
+        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+    "unsorted_segment_sqrt_n":
+        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+    "unsorted_segment_sum":
+        UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+    "xdivy":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "xlog1py":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "xlogy":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.complex64]),
+    "zero_fraction":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=UNARY_SIGNATURE_SHAPES,
+            signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+    "zeta":
+        tf_test_utils.unit_test_specs_from_signatures(
+            signature_shapes=BINARY_SIGNATURE_SHAPES,
+            # The function is poorly behaved near zero, so we test this range
+            # to avoid outputting all NaNs.
+            input_generators={
+                "uniform_3_4":
+                    lambda *args: tf_utils.uniform(*args, low=3.0, high=4.0)
+            },
+        )
+}
+
+for function, specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
+  # Update using 'with_name' to avoid updating shared UnitTestSpecs.
+  specs = [
+      spec.with_name(f"{function}__{spec.unit_test_name}") for spec in specs
+  ]
+  FUNCTIONS_TO_UNIT_TEST_SPECS[function] = specs
+
+  # Validate that there are not multiple UnitTestSpecs with the same name.
+  seen_unit_test_names = set()
+  for spec in specs:
+    if spec.unit_test_name in seen_unit_test_names:
+      raise ValueError(
+          f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")
+    seen_unit_test_names.add(spec.unit_test_name)
+
+flags.DEFINE_list(
+    "functions", None,
+    f"Any of {list(FUNCTIONS_TO_UNIT_TEST_SPECS.keys())}. If more than one "
+    "function is provided then len(--target_backends) must be one.")
+flags.DEFINE_bool(
+    "dynamic_dims", False,
+    "Whether or not to compile the layer with dynamic dimensions.")
+flags.DEFINE_bool(
+    "test_complex", False,
+    "Whether or not to test or ignore function signatures with complex types.")
+flags.DEFINE_bool(
+    'list_functions_with_complex_tests', False,
+    'Whether or not to print out all functions with complex inputs '
+    '(and skip running the tests).')
+
+
+def create_function_unit_test(
+    function_name: str,
+    unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
+  """Creates a tf_function_unit_test from the provided UnitTestSpec."""
+  function = getattr(tf.math, function_name)
+  signature = unit_test_spec.input_signature
+
+  if tf_utils.is_complex(signature):
+    function, signature = tf_utils.rewrite_complex_signature(
+        function, signature)
+  wrapped_function = lambda *args: function(*args, **unit_test_spec.kwargs)
+
+  if FLAGS.dynamic_dims:
+    signature = tf_utils.apply_function(signature, tf_utils.make_dims_dynamic)
+
+  return tf_test_utils.tf_function_unit_test(
+      input_signature=signature,
+      input_generator=unit_test_spec.input_generator,
+      input_args=unit_test_spec.input_args,
+      name=unit_test_spec.unit_test_name,
+      rtol=1e-5,
+      atol=1e-5)(wrapped_function)
+
+
+class TfMathModule(tf_test_utils.TestModule):
+
+  def __init__(self):
+    super().__init__()
+    for function in FLAGS.functions:
+      for unit_test_spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function]:
+        if not FLAGS.test_complex and tf_utils.is_complex(
+            unit_test_spec.input_signature):
+          continue
+        function_unit_test = create_function_unit_test(function, unit_test_spec)
+        setattr(self, unit_test_spec.unit_test_name, function_unit_test)
+
+
+class TfMathTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(
+        TfMathModule, exported_names=TfMathModule.get_tf_function_unit_tests())
+
+
+def main(argv):
+  del argv  # Unused.
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+
+  if FLAGS.list_functions_with_complex_tests:
+    for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
+      for spec in unit_test_specs:
+        if tf_utils.is_complex(spec.input_signature):
+          print(f'    "{function_name}",')
+    return
+
+  if FLAGS.functions is None:
+    raise flags.IllegalFlagValueError(
+        "'--functions' must be specified if "
+        "'--list_functions_with_complex_tests' isn't")
+
+  if len(FLAGS.functions) > 1:
+    # We only allow testing multiple functions with a single target backend
+    # so that we can store the artifacts under:
+    #   'artifacts_dir/multiple_functions__backend/...'
+    # We specialize the 'multiple_functions' dir by backend to avoid overwriting
+    # tf_input.mlir and iree_input.mlir. These are typically identical across
+    # backends, but are not when the functions to compile change per-backend.
+    if len(FLAGS.target_backends) != 1:
+      raise flags.IllegalFlagValueError(
+          "Expected len(target_backends) == 1 when len(functions) > 1, but got "
+          f"the following values for target_backends: {FLAGS.target_backends}.")
+    function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
+  else:
+    function_str = FLAGS.functions[0]
+  dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
+  settings_str = os.path.join(function_str, dim_str)
+  # The relative artifacts directory path is calculated from the module name.
+  # TODO(meadowlark): provide a better way of overriding this default.
+  TfMathModule.__name__ = os.path.join("tf", "math", settings_str)
+
+  TfMathTest.generate_unit_tests(TfMathModule)
+  tf.test.main()
+
+
+if __name__ == "__main__":
+  app.run(main)
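
For readers unfamiliar with the generation helpers in `tf_test_utils` (their implementations are not part of this diff), each generated entry boils down to an exported `tf.function` attached to `TfMathModule` and traced across backends. As a rough sketch only, with hypothetical names and simplified behavior, the default `atan2` spec is roughly equivalent to:

```python
import tensorflow.compat.v2 as tf

# Hypothetical illustration: what one generated unit test roughly reduces to
# for tf.math.atan2 with the rank-7 binary float32 signature used above.
RANK_7_SHAPE = [2] * 7
signature = [tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 2

@tf.function(input_signature=signature)
def atan2__f32(x, y):
  # Inputs are generated for this signature (uniform floats by default) and
  # the outputs are compared against the reference backend within 1e-5.
  return tf.math.atan2(x, y)
```
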
diff --git a/integrations/tensorflow/e2e/math_dyn_test.py b/integrations/tensorflow/e2e/math_dyn_test.py
deleted file mode 100644
index cca0918..0000000
--- a/integrations/tensorflow/e2e/math_dyn_test.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class MathModule(tf.Module):
-
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
-  def abs(self, x):
-    return tf.math.abs(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
-  def ceil(self, x):
-    return tf.math.ceil(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
-  def cos(self, x):
-    return tf.math.cos(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
-  def log(self, x):
-    return tf.math.log(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
-  def mod(self, x):
-    return tf.math.mod(x, 2.0)
-
-  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
-  def fake_quant(self, x):
-    return tf.quantization.fake_quant_with_min_max_args(x,
-                                                        min=-6,
-                                                        max=6,
-                                                        num_bits=8,
-                                                        narrow_range=False,
-                                                        name=None)
-
-
-class MathTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(MathModule)
-
-  # yapf: disable
-  def test_abs(self):
-    def abs(module):
-      module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
-    self.compare_backends(abs, self._modules)
-
-  def test_ceil(self):
-    def ceil(module):
-      module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
-    self.compare_backends(ceil, self._modules)
-
-  def test_cos(self):
-    def cos(module):
-      module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
-    self.compare_backends(cos, self._modules)
-
-  def test_log(self):
-    def log(module):
-      module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
-    self.compare_backends(log, self._modules)
-
-  def test_mod(self):
-    def mod(module):
-      module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
-    self.compare_backends(mod, self._modules)
-
-  def test_fake_quant(self):
-    def abs(module):
-      module.fake_quant(np.array([-0.123, 0.1234, 0.743, 4.3], dtype=np.float32))
-    self.compare_backends(abs, self._modules)
-  # yapf: enable
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
deleted file mode 100644
index add9ea4..0000000
--- a/integrations/tensorflow/e2e/math_test.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class MathModule(tf.Module):
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def abs(self, x):
-    return tf.math.abs(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def ceil(self, x):
-    return tf.math.ceil(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def cos(self, x):
-    return tf.math.cos(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def log(self, x):
-    return tf.math.log(x)
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def mod(self, x):
-    return tf.math.mod(x, 2.0)
-
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
-  def fake_quant(self, x):
-    return tf.quantization.fake_quant_with_min_max_args(x,
-                                                        min=-6,
-                                                        max=6,
-                                                        num_bits=8,
-                                                        narrow_range=False,
-                                                        name=None)
-
-
-class MathTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(MathModule)
-
-  # yapf: disable
-  def test_abs(self):
-    def abs(module):
-      module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
-    self.compare_backends(abs, self._modules)
-
-  def test_ceil(self):
-    def ceil(module):
-      module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
-    self.compare_backends(ceil, self._modules)
-
-  def test_cos(self):
-    def cos(module):
-      module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
-    self.compare_backends(cos, self._modules)
-
-  def test_log(self):
-    def log(module):
-      module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
-    self.compare_backends(log, self._modules)
-
-  def test_mod(self):
-    def mod(module):
-      module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
-    self.compare_backends(mod, self._modules)
-
-  def test_fake_quant(self):
-    def abs(module):
-      module.fake_quant(np.array([-0.123, 0.1234, 0.743, 4.3], dtype=np.float32))
-    self.compare_backends(abs, self._modules)
-  # yapf: enable
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/quantization_dyn_test.py b/integrations/tensorflow/e2e/quantization_dyn_test.py
new file mode 100644
index 0000000..d1c44ef
--- /dev/null
+++ b/integrations/tensorflow/e2e/quantization_dyn_test.py
@@ -0,0 +1,58 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for ops in the tf.math module."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class QuantizationDynModule(tf.Module):
+
+  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
+  def fake_quant(self, x):
+    return tf.quantization.fake_quant_with_min_max_args(x,
+                                                        min=-6,
+                                                        max=6,
+                                                        num_bits=8,
+                                                        narrow_range=False,
+                                                        name=None)
+
+
+class QuantizationDynTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(QuantizationDynModule)
+
+  def test_fake_quant(self):
+
+    def fake_quant(module):
+      module.fake_quant(tf_utils.uniform([32], low=-6, high=6))
+
+    self.compare_backends(fake_quant, self._modules)
+
+
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
+    tf.enable_v2_behavior()
+  tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/quantization_test.py b/integrations/tensorflow/e2e/quantization_test.py
new file mode 100644
index 0000000..2ccf8f8
--- /dev/null
+++ b/integrations/tensorflow/e2e/quantization_test.py
@@ -0,0 +1,55 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for ops in the tf.math module."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class QuantizationModule(tf_test_utils.TestModule):
+
+  @tf_test_utils.tf_function_unit_test(
+      input_signature=[tf.TensorSpec([32], tf.float32)],
+      input_generator=lambda *args: tf_utils.uniform(*args, low=-6, high=6))
+  def fake_quant(self, x):
+    return tf.quantization.fake_quant_with_min_max_args(x,
+                                                        min=-6,
+                                                        max=6,
+                                                        num_bits=8,
+                                                        narrow_range=False,
+                                                        name=None)
+
+
+class QuantizationTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(QuantizationModule)
+
+
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
+    tf.enable_v2_behavior()
+
+  QuantizationTest.generate_unit_tests(QuantizationModule)
+  tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
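The `input_generator` above keeps inputs inside the fake-quant clamp range of [-6, 6]. As a rough reference for what `fake_quant_with_min_max_args` computes, here is a numpy approximation; it omits TensorFlow's nudging of min/max so that zero lands exactly on a quantization level, so boundary values can differ slightly from the real op.

```python
import numpy as np

def approx_fake_quant(x, min_val=-6.0, max_val=6.0, num_bits=8):
  # Affine fake quantization: clip, snap to one of 2**num_bits levels,
  # then map back to float. TensorFlow additionally nudges min/max so
  # that zero is exactly representable, which this sketch omits.
  scale = (max_val - min_val) / (2**num_bits - 1)
  clipped = np.clip(x, min_val, max_val)
  return np.round((clipped - min_val) / scale) * scale + min_val

x = np.random.uniform(low=-6, high=6, size=[32]).astype(np.float32)
print(approx_fake_quant(x))
```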
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index fed8c12..3b5509c 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -43,7 +43,12 @@
 KWS_LINK = f'[Keyword Spotting Streaming]({KWS_LINK})'
 
 COVERAGE_GROUP_TO_TEST_SUITES = {
-    'tf_base_coverage': ['//integrations/tensorflow/e2e:e2e_tests'],
+    'tf_base_coverage': [
+        '//integrations/tensorflow/e2e:e2e_tests',
+        '//integrations/tensorflow/e2e/math:math_tests',
+        '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests',
+        '//integrations/tensorflow/e2e/math:math_complex_tests',
+    ],
     'tf_keras_coverage': [
         '//integrations/tensorflow/e2e/keras/layers:layers_tests',
         '//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests',
@@ -81,8 +86,16 @@
 }
 
 TEST_SUITES_TO_HEADERS = {
+    # tf_base_coverage
     '//integrations/tensorflow/e2e:e2e_tests':
         'End to end TensorFlow tests',
+    '//integrations/tensorflow/e2e/math:math_tests':
+        'End to end tests of tf.math functions with static dimensions',
+    '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+        'End to end tests of tf.math functions with dynamic dimensions',
+    '//integrations/tensorflow/e2e/math:math_complex_tests':
+        'End to end tests of tf.math functions with complex numbers',
+    # tf_keras_coverage
     '//integrations/tensorflow/e2e/keras/layers:layers_tests':
         'End to end tests of tf.keras layers (with default configuration and '
         'static batch sizes in inference mode)',
@@ -95,12 +108,14 @@
     '//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
         'End to end tests of tf.keras layers in training mode (with default'
         'configuration and static batch sizes)',
+    # language_and_speech_coverage
     '//integrations/tensorflow/e2e:mobile_bert_squad_tests':
         'End to end test of MobileBert on SQuAD',
     '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
         f'End to end tests of {KWS_LINK} models',
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         f'End to end tests of {KWS_LINK} models in internal streaming mode',
+    # vision_coverage
     '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
         'End to end tests of tf.keras.applications vision models on Imagenet',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -108,18 +123,31 @@
 }
 
 TEST_SUITES_TO_NOTES = {
+    '//integrations/tensorflow/e2e/math:math_tests':
+        ('**Note:** To be thorough, these tests use high-rank tensors and\n'
+         'test int dtypes where TensorFlow allows them. Both of these\n'
+         'choices disproportionately affect TFLite coverage, so the results\n'
+         'below do not reflect coverage of simpler use cases.\n'),
     '//integrations/tensorflow/e2e/keras/layers:layers_tests': (
         '**Note:** Layers like `Dropout` are listed as passing in this table,\n'
         'but they function similar to identity layers in these tests. **See \n'
         'the third table for the coverage of these layers during training.**\n'
         '\n'
         'These tests also only modify required `tf.keras.layers` arguments.\n'
-        'See the full API tests below for coverage on of non-default '
+        'See the full API tests below for coverage of non-default\n'
         'layer configurations.'),
 }
 # Key to use as the name of the rows in the left column for each test in the
 # suite.
 TEST_SUITE_TO_ROW_ID_KEY = {
+    # tf_base_coverage
+    '//integrations/tensorflow/e2e/math:math_tests':
+        'functions',
+    '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+        'functions',
+    '//integrations/tensorflow/e2e/math:math_complex_tests':
+        'functions',
+    # tf_keras_coverage
     '//integrations/tensorflow/e2e/keras/layers:layers_tests':
         'layer',
     '//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
@@ -128,10 +156,12 @@
         'layer',
     '//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
         'layer',
+    # language_and_speech_coverage
     '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
         'model',
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         'model',
+    # vision_coverage
     '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
         'model',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -141,6 +171,14 @@
 # Some test suites are generated from a single source. This allows us to point
 # to the right test file when generating test URLs.
 SINGLE_SOURCE_SUITES = {
+    # tf_base_coverage
+    '//integrations/tensorflow/e2e/math:math_tests':
+        'math_test',
+    '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+        'math_test',
+    '//integrations/tensorflow/e2e/math:math_complex_tests':
+        'math_test',
+    # tf_keras_coverage
     '//integrations/tensorflow/e2e/keras/layers:layers_tests':
         'layers_test',
     '//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
@@ -149,10 +187,12 @@
         'layers_test',
     '//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
         'layers_test',
+    # language_and_speech_coverage
     '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
         'keyword_spotting_streaming_test',
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         'keyword_spotting_streaming_test',
+    # vision_coverage
     '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
         'vision_model_test',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':