Add decorator for generating unittests from tf.functions (#3782)

This change adds `tf_test_utils.tf_function_unittest`, which is a wrapper
around `tf.function` that allows us to specify the metadata that we need
to generate unit tests for a large portion of our e2e tests. 

This can be used declaratively (see `conv_test.py` and `simple_arithmetic_test.py`
as examples), but is particularly useful when generating `tf.function`s to test 
from a configuration. We use `tf_function_unittest` in `keras/layers/layers_test.py`
to add tests of multiple configurations for each layer (e.g. testing `Conv2D` 
with padding=same, strides, and dilation).
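
A minimal sketch of the declarative usage (mirroring `simple_arithmetic_test.py`):

```python
class SimpleArithmeticModule(tf_test_utils.TestModule):

  @tf_test_utils.tf_function_unittest(input_signature=[
      tf.TensorSpec([4], tf.float32),
      tf.TensorSpec([4], tf.float32),
  ])
  def simple_mul(self, a, b):
    return a * b
```

Calling `generate_unittests(SimpleArithmeticModule)` on the corresponding
`TracedModuleTestCase` generates a `test_simple_mul` unittest that compares all
of the target backends against the reference backend on uniform random inputs.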

This change also simplifies the `keras/layers/BUILD` file by separating the failing
configurations for `layers_tests`, `layers_full_api_tests`, `layers_dynamic_batch_tests`
and `layers_training_tests` into their own variables. This adds redundancy, but should
be easier for others to maintain, as they don't have to handle the situation
where enabling one test causes failures in another test suite.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 111ad46..b2b10ec 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -693,6 +693,104 @@
   return _global_modules
 
 
+def tf_function_unittest(input_generator: tf_utils.InputGeneratorType = None,
+                         input_args: Sequence[Any] = None,
+                         atol: float = None,
+                         rtol: float = None,
+                         name: str = None,
+                         **tf_function_kwargs):
+  """Creates a tf.function that can be used to generate unittests.
+
+  If 'input_generator' and 'input_args' are unspecified, then the function will
+  be tested using random uniform data.
+
+  Args:
+    input_generator:
+      an optional callable taking a shape and dtype that returns input data for
+      the unittest.
+    input_args:
+      an optional sequence of values to pass as positional args to the function.
+    atol:
+      optional, the absolute tolerance to use when comparing the decorated
+      function's output.
+    rtol:
+      optional, the relative tolerance to use when comparing the decorated
+      function's output.
+    name:
+      optional, the name to reference this function with. Must be used if
+      decorating a lambda.
+
+  Raises:
+    ValueError: if 'input_generator' and 'input_args' are both specified.
+
+  Returns:
+    A tf.function with the additional attributes 'get_trace_args' (derived
+    from 'input_generator' or 'input_args' above) and 'trace_kwargs' (from
+    'atol' and 'rtol' above), and with an updated __name__ attribute if
+    'name' was specified.
+  """
+
+  def _store_unittest_info(function):
+    # Validate arguments.
+    if input_generator is not None and input_args is not None:
+      raise ValueError(
+          "'input_generator' and 'input_args' cannot both be specified.")
+
+    function = tf.function(**tf_function_kwargs)(function)
+
+    # Used to identify that the tf.function was created by this decorator.
+    function.is_tf_function_unittest = True
+
+    # Set function.get_trace_args.
+    if input_generator is not None:
+      # Use the user-specified input_generator.
+      function.get_trace_args = lambda: tf_utils.generate_inputs(
+          function.input_signature, input_generator)
+    elif input_args is not None:
+      # Use the user-specified input_args.
+      function.get_trace_args = lambda: copy.deepcopy(input_args)
+    else:
+      # No user data specification – default to using random uniform data.
+      function.get_trace_args = lambda: tf_utils.generate_inputs(
+          function.input_signature, tf_utils.uniform)
+
+    # Set function.trace_kwargs.
+    function.trace_kwargs = dict(atol=atol, rtol=rtol)
+
+    # Set function.__name__.
+    if name is not None:
+      function.__name__ = name
+    elif function.__name__ == "<lambda>":
+      raise ValueError("The 'name' kwarg must be provided when decorating a "
+                       "lambda function.")
+
+    return function
+
+  return _store_unittest_info
+
+
+class TestModule(tf.Module):
+  """Thin wrapper around tf.Module with helper methods for tf_function_unittests."""
+
+  @classmethod
+  def get_tf_function_unittests(cls):
+    """Get all tf_function_unittest-created tf.functions on the class."""
+    tf_function_unittests = []
+    for name in dir(cls):
+      value = getattr(cls, name)
+      if hasattr(value, 'is_tf_function_unittest'):
+        tf_function_unittests.append(value)
+
+    if not tf_function_unittests:
+      raise ValueError(
+          "'get_tf_function_unittests' was called but no unittests were found.")
+    return tf_function_unittests
+
+  @classmethod
+  def get_exported_names(cls):
+    """Get the names of all tf_function_unittest-created tf.functions."""
+    return [function.__name__ for function in cls.get_tf_function_unittests()]
+
+
 class TracedModuleTestCase(tf.test.TestCase):
   """Compiles a tf.Module to multiple backends to test their correctness."""
 
@@ -703,6 +801,37 @@
     for module in self._modules.tar_modules:
       module.reinitialize()
 
+  @classmethod
+  def generate_unittests(cls, module_class: Type[TestModule]):
+    """Generates unittests for each 'tf_function_unittest' on 'module_class'."""
+    for function in module_class.get_tf_function_unittests():
+      # We have to pass the closure argument 'function' to 'trace' via a kwarg
+      # instead of using it directly in the body because 'function' is
+      # overwritten in each iteration of this loop, and python will only use
+      # the most recent version of 'function'. If we didn't do this, then we
+      # would only test the last function in this loop. The same is true for
+      # passing 'trace' to 'unittest'.
+
+      # Runs the inputs through a (traced) module.
+      def trace(module, function=function):
+        getattr(module, function.__name__)(*function.get_trace_args(),
+                                           **function.trace_kwargs)
+
+      # Give the trace the name of the tf.function that it is testing.
+      trace.__name__ = function.__name__
+
+      # Runs 'trace' on modules compiled to each backend and compares them.
+      def unittest(self, trace=trace):
+        self.compare_backends(trace, self._modules)
+
+      # Make 'unittest' a function on the TracedModuleTestCase, which tells
+      # the test runner to run it.
+      unittest.__name__ = f"test_{function.__name__}"
+      if hasattr(cls, unittest.__name__):
+        raise ValueError("Tried to generate multiple instances of the unittest "
+                         f"'{unittest.__name__}'.")
+      setattr(cls, unittest.__name__, unittest)
+
   def compare_backends(self, trace_function: Callable[[TracedModule], None],
                        modules: Modules) -> None:
     """Run the reference and target backends on trace_function and compare them.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index efd09ff..63cfa2f 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -52,7 +52,55 @@
     self.count.assign_sub(tf.constant([1.]))
 
 
-class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+class TfFunctionUnittestModule(tf_test_utils.TestModule):
+
+  @tf_test_utils.tf_function_unittest(input_signature=[])
+  def no_args(self):
+    return np.array([True], dtype=np.bool)
+
+  @tf_test_utils.tf_function_unittest(input_signature=[
+      tf.TensorSpec([4]),
+      tf.TensorSpec([4]),
+  ])
+  def default_uniform_inputs(self, a, b):
+    return a + b
+
+  @tf_test_utils.tf_function_unittest(
+      input_signature=[
+          tf.TensorSpec([4]),
+          tf.TensorSpec([4]),
+      ],
+      input_generator=tf_utils.ndarange,
+  )
+  def custom_input_generator(self, a, b):
+    return a + b
+
+  @tf_test_utils.tf_function_unittest(
+      input_signature=[
+          tf.TensorSpec([4]),
+          tf.TensorSpec([4]),
+      ],
+      input_args=[
+          np.array([0, 1, 2, 3], np.float32),
+          -np.array([0, 1, 2, 3], np.float32),
+      ],
+  )
+  def custom_input_args(self, a, b):
+    return a + b
+
+  # This test will fail if atol is not successfully set.
+  @tf_test_utils.tf_function_unittest(
+      input_signature=[
+          tf.TensorSpec([128, 3072], tf.float32),
+          tf.TensorSpec([3072, 256], tf.float32),
+      ],
+      atol=1e-2,
+  )
+  def high_tolerance(self, a, b):
+    return tf.matmul(a, b)
+
+
+class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters([
       {
@@ -205,6 +253,27 @@
         if key != 'calls':
           self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
 
+  def test_tf_function_unittest(self):
+
+    class TfFunctionUnittestTest(tf_test_utils.TracedModuleTestCase):
+
+      def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._modules = tf_test_utils.compile_tf_module(
+            TfFunctionUnittestModule)
+
+    TfFunctionUnittestTest.generate_unittests(TfFunctionUnittestModule)
+    test_case = TfFunctionUnittestTest()
+    self.assertTrue(hasattr(test_case, 'test_no_args'))
+    self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs'))
+    self.assertTrue(hasattr(test_case, 'test_custom_input_generator'))
+    self.assertTrue(hasattr(test_case, 'test_custom_input_args'))
+    self.assertTrue(hasattr(test_case, 'test_high_tolerance'))
+
+    # Will raise an error if 'atol' and 'rtol' are not set.
+    test_case = TfFunctionUnittestTest()
+    test_case.test_high_tolerance()
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 66699ac..8d383c9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -37,14 +37,44 @@
   np.random.seed(seed)
 
 
-def uniform(shape: Sequence[int], dtype: np.dtype = np.float32) -> np.ndarray:
-  return np.random.uniform(size=shape).astype(dtype)
+InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
+                              np.ndarray]
 
 
-def ndarange(shape: Sequence[int], dtype: np.dtype = np.float32) -> np.ndarray:
+def uniform(shape: Sequence[int],
+            dtype: Union[tf.DType, np.dtype] = np.float32,
+            low: float = 0.,
+            high: float = 1.) -> np.ndarray:
+  """np.random.uniform with simplified API and dtype control."""
+  dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype
+  return np.random.uniform(size=shape, low=low, high=high).astype(dtype)
+
+
+def ndarange(shape: Sequence[int],
+             dtype: Union[tf.DType, np.dtype] = np.float32) -> np.ndarray:
+  """np.arange for arbitrary input shapes."""
+  dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype
   return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
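+
+# For example, ndarange([2, 2]) returns
+# np.array([[0., 1.], [2., 3.]], dtype=np.float32).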
 
 
+def generate_inputs(
+    spec,  # Union[Sequence[tf.TensorSpec], tf.TensorSpec]
+    input_generator: InputGeneratorType,
+) -> Sequence[np.ndarray]:
+  """Generates inputs for a given input signature using 'input_generator'."""
+  if isinstance(spec, Sequence):
+    # 'spec' is a sequence of 'tf.TensorSpec'.
+    # Recursively generate inputs.
+    return [generate_inputs(s, input_generator) for s in spec]
+  elif isinstance(spec, tf.TensorSpec):
+    # Handle dynamic shapes (e.g. batches) by substituting an int for None.
+    shape = [size if size is not None else 2 for size in spec.shape]
+    return input_generator(shape, spec.dtype)
+  else:
+    raise TypeError("Expected 'spec' to be a sequence of 'tf.TensorSpec' or "
+                    f"'tf.TensorSpec', but got '{type(spec)}'")
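+
+# For example, generate_inputs([tf.TensorSpec([None, 4], tf.float32)], uniform)
+# returns a list containing one np.ndarray of shape [2, 4] (dynamic dimensions
+# are replaced with 2).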
+
+
 def to_mlir_type(dtype: np.dtype) -> str:
   """Returns a string that denotes the type 'dtype' in MLIR style."""
   bits = dtype.itemsize * 8
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 6aaf9e0..0ac1605 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -72,14 +72,132 @@
 
 ## Writing Tests
 
+There are two ways to write tests: via `tf_test_utils.tf_function_unittest` and
+via test methods on a child of `tf_test_utils.TracedModuleTestCase`.
+
+### Via `tf_test_utils.tf_function_unittest`
+
+This is preferred in the cases where
+
+1. Only a single call to the module needs to be tested at once.
+2. The inputs are simple to automatically generate or specify inline.
+3. The functions that you want to test are generated automatically from a
+   configuration (e.g. in `.../e2e/keras/layers/layers_test.py`).
+
+Tests are specified by writing modules that inherit from
+`tf_test_utils.TestModule` (which is a thin wrapper around `tf.Module`) with
+methods decorated with `@tf_test_utils.tf_function_unittest` (which is a thin
+wrapper around `tf.function`).
+
+#### Basic example
+
+We use part of `.../e2e/conv_test.py` as an example. The first component is
+the `TestModule` itself:
+
+```python
+class Conv2dModule(tf_test_utils.TestModule):
+
+  # This decorator tells the testing infra to generate a unittest for this
+  # function. The 'input_signature' is required. If no other arguments are
+  # specified then uniform random data is generated from the input signature
+  # to numerically test the function.
+  @tf_test_utils.tf_function_unittest(input_signature=[
+      tf.TensorSpec([1, 4, 5, 1], tf.float32),
+      tf.TensorSpec([1, 1, 1, 1], tf.float32),
+  ])
+  def conv2d_1451x1111_valid(self, img, kernel):
+    return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
+
+  @tf_test_utils.tf_function_unittest(input_signature=[
+      tf.TensorSpec([2, 4, 5, 1], tf.float32),
+      tf.TensorSpec([1, 1, 1, 1], tf.float32),
+  ])
+  def conv2d_2451x1111_valid(self, img, kernel):
+    return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
+```
+
+Second, you need to write a test case that inherits from
+`tf_test_utils.TracedModuleTestCase`. This is essentially boilerplate that
+tells `tf.test.main()` what `tf.Module` to test and allows us to generate
+the unittests we specified above.
+
+```python
+class ConvTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(Conv2dModule)
+```
+
+Finally, in the `main` function, you need to call
+`.generate_unittests(module_class)` on your `TestCase` to actually generate
+the unittests that we specified:
+
+```python
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
+    tf.enable_v2_behavior()
+  # Generates unittests for all @tf_test_utils.tf_function_unittest decorated
+  # functions on the module class.
+  # Note: if you are automatically generating functions to test, they need to be
+  # specified via a `classmethod` prior to this call _as well_ as via `__init__`
+  # to properly handle stateful `tf.function`s.
+  ConvTest.generate_unittests(Conv2dModule)
+  tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
+```
+
+This generates two unittests: `test_conv2d_1451x1111_valid` and
+`test_conv2d_2451x1111_valid`.
+
+#### Configuring `@tf_test_utils.tf_function_unittest`
+
+By default, `@tf_test_utils.tf_function_unittest` uses uniform random input data
+to numerically test the function, but you can specify an `input_generator` or
+`input_args` to test data-specific behaviors:
+
+- `input_generator` can be `tf_utils.uniform`, `tf_utils.ndarange`, or any
+  function that takes a `shape` and `dtype` as positional args and returns an
+  `np.ndarray`.
+- `input_args` is a list of `np.ndarray`s to use as positional arguments.
+
+The comparison tolerances `atol` and `rtol` can also be specified in the
+decorator.
+
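+For example, a hypothetical method testing `tf_utils.ndarange` inputs with a
+loosened absolute tolerance might look like this:
+
+```python
+@tf_test_utils.tf_function_unittest(
+    input_signature=[tf.TensorSpec([4], tf.float32)],
+    input_generator=tf_utils.ndarange,
+    atol=1e-2,
+)
+def scale(self, x):
+  return 2 * x
+```
+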
+### Via test methods
+
+This is preferred in the cases where
+
+1. The `tf.function` that you want to test is already defined on the module
+   (e.g. on a downloaded model like in `mobile_bert_test.py`).
+2. The inputs are difficult to specify inline and require multiple function
+   calls / reshaping to create.
+3. You want to test multiple consecutive calls to a `tf.function` (e.g. to test
+   mutated state in `ring_buffer_test.py`).
+
 Our tests use a class `TracedModule` to capture and store all of the inputs and
 outputs of a `CompiledModule` in a `Trace`. Each unittest on a `TestCase` uses
 the `compare_backends` method. This method runs the function it is passed with a
 `TracedModule` once for each reference and target backend. The inputs and
 outputs to these modules are then checked for correctness, using the reference
-backend as a source of truth. For example:
+backend as a source of truth.
+
+We use `simple_arithmetic_test.py` as an example:
 
 ```python
+# Create a tf.Module with one or more `@tf.function` decorated methods to test.
+class SimpleArithmeticModule(tf.Module):
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([4], tf.float32),
+      tf.TensorSpec([4], tf.float32)
+  ])
+  def simple_mul(self, a, b):
+    return a * b
+
 # Inherit from `TracedModuleTestCase`.
 class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
 
@@ -153,12 +271,20 @@
 ## Generated Artifacts
 
 By default, running an E2E test generates a number of compilation, debugging and
-benchmarking artifacts in `/tmp/iree/modules/`. The location of these artifacts
-can be changed via the `--artifacts_dir` flag. The generated directory structure
-for each module is as follows:
+benchmarking artifacts. These artifacts will be saved
+
+- in `/tmp/iree/modules/` when using `bazel run` or `bazel test` with
+  `--test_arg=--artifacts_dir=/tmp/iree/modules/`.
+- in `bazel-testlogs/integrations/tensorflow/e2e/test_suite_target_name` when
+  using `bazel test` without specifying `--artifacts_dir`.
+
+The generated directory structure for each module is as follows:
 
 ```shell
 /tmp/iree/modules/ModuleName
+  ├── reproducer__backend.mlir
+  │   # If there is a compilation error, an MLIR file that reproduces the error
+  │   # for a specific backend is included.
   ├── tf_input.mlir
   │   # MLIR for ModuleName in TF's input dialect.
   ├── iree_input.mlir
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
index 13adebc..df09ecb 100644
--- a/integrations/tensorflow/e2e/bool_test.py
+++ b/integrations/tensorflow/e2e/bool_test.py
@@ -20,20 +20,28 @@
 import tensorflow.compat.v2 as tf
 
 
-class BooleanModule(tf.Module):
+class BooleanModule(tf_test_utils.TestModule):
 
-  @tf.function(input_signature=[])
+  @tf_test_utils.tf_function_unittest(input_signature=[])
   def constant(self):
     return np.array([True, False, True], dtype=np.bool)
 
-  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
+  @tf_test_utils.tf_function_unittest(
+      input_signature=[tf.TensorSpec([4], tf.float32)],
+      input_args=[np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)])
   def greater_than(self, x):
     return x > 1.0
 
-  @tf.function(input_signature=[
-      tf.TensorSpec([4], tf.bool),
-      tf.TensorSpec([4], tf.bool)
-  ])
+  @tf_test_utils.tf_function_unittest(
+      input_signature=[
+          tf.TensorSpec([4], tf.bool),
+          tf.TensorSpec([4], tf.bool)
+      ],
+      input_args=[
+          np.array([True, True, False, False], dtype=np.bool),
+          np.array([True, False, False, True], dtype=np.bool)
+      ],
+  )
   def logical_and(self, x, y):
     return tf.math.logical_and(x, y)
 
@@ -44,33 +52,12 @@
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(BooleanModule)
 
-  def test_constant(self):
-
-    def constant(module):
-      module.constant()
-
-    self.compare_backends(constant, self._modules)
-
-  def test_greater_than(self):
-
-    def greater_than(module):
-      module.greater_than(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
-
-    self.compare_backends(greater_than, self._modules)
-
-  def test_logical_and(self):
-
-    def logical_and(module):
-      module.logical_and(np.array([True, True, False, False], dtype=np.bool),
-                         np.array([True, False, False, True], dtype=np.bool))
-
-    self.compare_backends(logical_and, self._modules)
-
 
 def main(argv):
   del argv  # Unused
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
+  BooleanTest.generate_unittests(BooleanModule)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index f9e6f3d..11ca8d6 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -20,16 +20,16 @@
 import tensorflow.compat.v2 as tf
 
 
-class Conv2dModule(tf.Module):
+class Conv2dModule(tf_test_utils.TestModule):
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 1], tf.float32),
   ])
   def conv2d_1451x1111_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 2, 1, 1], tf.float32),
   ])
@@ -40,7 +40,7 @@
                         dilations=[1, 2, 1, 1],
                         name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
@@ -51,70 +51,70 @@
                         dilations=[1, 2, 1, 1],
                         name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([2, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 1], tf.float32),
   ])
   def conv2d_2451x1111_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 3, 1, 1], tf.float32),
   ])
   def conv2d_1451x2311_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 3, 1, 1], tf.float32),
   ])
   def conv2d_1451x2311_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([2, 4, 5, 1], tf.float32),
       tf.TensorSpec([2, 3, 1, 1], tf.float32),
   ])
   def conv2d_2451x2311_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([3, 2, 2, 1], tf.float32),
   ])
   def conv2d_1452x3221_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 1], tf.float32),
       tf.TensorSpec([1, 1, 1, 2], tf.float32),
   ])
   def conv2d_1451x1112_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([1, 1, 2, 2], tf.float32),
   ])
   def conv2d_1452x1122_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
   def conv2d_1452x2223_same(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([1, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
   def conv2d_1452x2223_valid(self, img, kernel):
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
-  @tf.function(input_signature=[
+  @tf_test_utils.tf_function_unittest(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
@@ -128,107 +128,12 @@
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(Conv2dModule)
 
-  # yapf: disable
-  def test_id_batch_size_1(self):
-    def id_batch_size_1(module):
-      i = tf_utils.uniform([1, 4, 5, 1])
-      k = tf_utils.uniform([1, 1, 1, 1], dtype=np.float32)
-      module.conv2d_1451x1111_valid(i, k)
-    self.compare_backends(id_batch_size_1, self._modules)
-
-  def test_id_dilated(self):
-    def id_batch_size_1(module):
-      i = tf_utils.uniform([1, 4, 5, 1])
-      k = tf_utils.uniform([2, 2, 1, 1], dtype=np.float32)
-      module.conv2d_1451x2211_dilated_valid(i, k)
-    self.compare_backends(id_batch_size_1, self._modules)
-
-  def test_id_multichannel_dilated(self):
-    def id_batch_size_1(module):
-      i = tf_utils.uniform([1, 4, 5, 2])
-      k = tf_utils.uniform([2, 2, 2, 3], dtype=np.float32)
-      module.conv2d_1452x2223_dilated_valid(i, k)
-    self.compare_backends(id_batch_size_1, self._modules)
-
-  def test_id_batch_size_2(self):
-    def id_batch_size_2(module):
-      i = tf_utils.uniform([2, 4, 5, 1])
-      k = tf_utils.uniform([1, 1, 1, 1], dtype=np.float32)
-      module.conv2d_2451x1111_valid(i, k)
-    self.compare_backends(id_batch_size_2, self._modules)
-
-  def test_asymmetric_kernel(self):
-    def asymmetric_kernel(module):
-      i = tf_utils.uniform([1, 4, 5, 1])
-      k = np.array([[1, 4, 2], [-2, 0, 1]],
-                   dtype=np.float32).reshape(2, 3, 1, 1)
-      module.conv2d_1451x2311_valid(i, k)
-    self.compare_backends(asymmetric_kernel, self._modules)
-
-  def test_padding(self):
-    def padding(module):
-      i = tf_utils.ndarange([1, 4, 5, 1])
-      k = np.array([[1, 4, 2], [-2, 0, 1]],
-                   dtype=np.float32).reshape(2, 3, 1, 1)
-      module.conv2d_1451x2311_same(i, k)
-    self.compare_backends(padding, self._modules)
-
-  def test_batched_padding(self):
-    def batched_padding(module):
-      i = tf_utils.ndarange([2, 4, 5, 1])
-      k = np.array([[1, 4, 2], [-2, 0, 1]],
-                   dtype=np.float32).reshape(2, 3, 1, 1)
-      module.conv2d_2451x2311_same(i, k)
-    self.compare_backends(batched_padding, self._modules)
-
-  def test_feature_reduce(self):
-    def feature_reduce(module):
-      i = tf_utils.ndarange([1, 4, 5, 2])
-      k = np.ones([3, 2, 2, 1], dtype=np.float32)
-      module.conv2d_1452x3221_same(i, k)
-    self.compare_backends(feature_reduce, self._modules)
-
-  def test_feature_inflate(self):
-    def feature_inflate(module):
-      i = tf_utils.ndarange([1, 4, 5, 1])
-      k = tf_utils.ndarange([1, 1, 1, 2])
-      module.conv2d_1451x1112_same(i, k)
-    self.compare_backends(feature_inflate, self._modules)
-
-  def test_feature_mix(self):
-    def feature_mix(module):
-      i = tf_utils.ndarange([1, 4, 5, 2])
-      k = tf_utils.ndarange([1, 1, 2, 2])
-      module.conv2d_1452x1122_same(i, k)
-    self.compare_backends(feature_mix, self._modules)
-
-  def test_feature_padded(self):
-    def feature_padded(module):
-      i = tf_utils.ndarange([1, 4, 5, 2])
-      k = tf_utils.ndarange([2, 2, 2, 3])
-      module.conv2d_1452x2223_same(i, k)
-    self.compare_backends(feature_padded, self._modules)
-
-  def test_feature_unpadded(self):
-    def feature_unpadded(module):
-      i = tf_utils.ndarange([1, 4, 5, 2])
-      k = tf_utils.ndarange([2, 2, 2, 3])
-      module.conv2d_1452x2223_valid(i, k)
-    self.compare_backends(feature_unpadded, self._modules)
-
-  def test_batched_feature_unpadded(self):
-    def batched_feature_unpadded(module):
-      i = tf_utils.ndarange([2, 4, 5, 2])
-      k = tf_utils.ndarange([2, 2, 2, 3])
-      module.conv2d_2452x2223_valid(i, k)
-    self.compare_backends(batched_feature_unpadded, self._modules)
-  # yapf: enable
-
 
 def main(argv):
   del argv  # Unused
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
+  ConvTest.generate_unittests(Conv2dModule)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
index 9fae94b..fea2a71 100644
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -75,7 +75,7 @@
     "Conv2DTranspose",
     "Conv3D",
     "Conv3DTranspose",
-    "ConvLSTM2D",
+    # "ConvLSTM2D",  # TODO(meadowlark): Debug flakiness.
     "Cropping1D",
     "Cropping2D",
     "Cropping3D",
@@ -137,10 +137,6 @@
 
 FAILING_STATIC = [
     {
-        # This layer is numerically flaky – TODO(meadowlark) create minimal reproducer.
-        "layer": "ConvLSTM2D",  # Flaky
-    },
-    {
         # Wrapping these in a tf.function appears to cause a keras bug.
         "layer": [
             "ConvLSTM2D",
@@ -222,6 +218,7 @@
         "layer": LAYERS,
         "dynamic_batch": False,
         "training": False,
+        "test_full_api": False,
         "target_backends": [
             "tf",
             "tflite",
@@ -236,7 +233,145 @@
     ],
 )
 
-FAILING_DYNAMIC = FAILING_STATIC + [
+# A list of all layers with non-default API tests can be generated by running:
+#   bazel run integrations/tensorflow/e2e/keras/layers:layers_test_manual -- \
+#     --list_layers_with_full_api_tests
+LAYERS_WITH_FULL_API_TESTS = [
+    "AdditiveAttention",
+    "Attention",
+    "AveragePooling1D",
+    "AveragePooling2D",
+    "AveragePooling3D",
+    "BatchNormalization",
+    "Conv1D",
+    "Conv1DTranspose",
+    "Conv2D",
+    "Conv2DTranspose",
+    "Conv3D",
+    "Conv3DTranspose",
+    # "ConvLSTM2D",  # TODO(meadowlark): Debug flakiness.
+    "DepthwiseConv2D",
+    "GlobalAveragePooling1D",
+    "GlobalAveragePooling2D",
+    "GlobalAveragePooling3D",
+    "GlobalMaxPool1D",
+    "GlobalMaxPool2D",
+    "GlobalMaxPool3D",
+    "GRU",
+    "LocallyConnected1D",
+    "LocallyConnected2D",
+    "LSTM",
+    "MaxPool1D",
+    "MaxPool2D",
+    "MaxPool3D",
+    "SeparableConv1D",
+    "SeparableConv2D",
+    "SimpleRNN",
+]
+
+FAILING_FULL_API = [
+    {
+        # Failing on TFLite
+        "layer": [
+            "AveragePooling3D",
+            "Conv2DTranspose",
+            "Conv3D",
+            "Conv3DTranspose",
+            "ConvLSTM2D",
+            "DepthwiseConv2D",
+            "GRU",
+            "LocallyConnected1D",
+            "LocallyConnected2D",
+            "LSTM",
+            "MaxPool1D",
+            "MaxPool3D",
+            "SimpleRNN",
+        ],
+        "target_backends": "tflite",
+    },
+    {
+        # Failing on IREE
+        "layer": [
+            "Conv1D",
+            "Conv2D",
+            "Conv2DTranspose",
+            "Conv3DTranspose",
+            "Conv3D",
+            "ConvLSTM2D",
+            "DepthwiseConv2D",
+            "GRU",
+            "LocallyConnected1D",
+            "LocallyConnected2D",
+            "LSTM",
+            "SeparableConv1D",
+            "SeparableConv2D",
+            "SimpleRNN",
+        ],
+        "target_backends": [
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    {
+        # Failing on LLVM and Vulkan
+        "layer": [
+            "AdditiveAttention",
+            "Attention",
+            "AveragePooling1D",
+            "AveragePooling2D",
+            "AveragePooling3D",
+            "Conv1DTranspose",
+        ],
+        "target_backends": [
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    {
+        # Failing on Vulkan
+        "layer": [
+            "MaxPool1D",
+            "MaxPool2D",
+            "MaxPool3D",
+        ],
+        "target_backends": "iree_vulkan",
+    },
+]
+
+iree_e2e_cartesian_product_test_suite(
+    name = "layers_full_api_tests",
+    srcs = ["layers_test.py"],
+    failing_configurations = FAILING_FULL_API,
+    flags_to_values = {
+        "reference_backend": "tf",
+        "layer": LAYERS_WITH_FULL_API_TESTS,
+        "dynamic_batch": False,
+        "training": False,
+        "test_full_api": True,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "layers_test.py",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+FAILING_DYNAMIC = [
+    {
+        # Wrapping these in a tf.function appears to cause a keras bug.
+        "layer": [
+            "GRUCell",
+            "LSTMCell",
+            "SimpleRNNCell",
+        ],
+    },
     {
         # TFLite does not support dynamic shapes.
         "target_backends": "tflite",
@@ -249,12 +384,13 @@
             "AveragePooling2D",
             "AveragePooling3D",
             "Concatenate",
-            "Conv1DTranspose",
             "Conv1D",
-            "Conv2DTranspose",
+            "Conv1DTranspose",
             "Conv2D",
-            "Conv3DTranspose",
+            "Conv2DTranspose",
             "Conv3D",
+            "Conv3DTranspose",
+            "ConvLSTM2D",
             "Cropping1D",
             "Cropping2D",
             "Cropping3D",
@@ -262,17 +398,24 @@
             "ELU",
             "Flatten",
             "GRU",
+            "LayerNormalization",
+            "LeakyReLU",
             "LocallyConnected1D",
+            "LocallyConnected2D",
             "LSTM",  # TODO(silvasean): Get this test working on IREE.
+            "Masking",
             "MaxPool1D",
             "MaxPool2D",
             "MaxPool3D",
+            "MultiHeadAttention",
             "RepeatVector",
             "Reshape",
             "SeparableConv1D",
             "SeparableConv2D",
+            "SimpleRNN",
             "ThresholdedReLU",
             "UpSampling1D",
+            "UpSampling2D",
             "UpSampling3D",
             "ZeroPadding1D",
             "ZeroPadding2D",
@@ -297,6 +440,7 @@
             "GlobalAveragePooling1D",
             "GlobalAveragePooling2D",
             "GlobalAveragePooling3D",
+            "Lambda",
             "Maximum",
             "Minimum",
             "Multiply",
@@ -326,6 +470,7 @@
         "layer": LAYERS,
         "dynamic_batch": True,
         "training": False,
+        "test_full_api": False,
         "target_backends": [
             "tf",
             "tflite",
@@ -340,50 +485,13 @@
     ],
 )
 
-FAILING_TRAINING = FAILING_STATIC + [
-    {
-        # Failing on TFLite:
-        "layer": [
-            "AlphaDropout",
-            "BatchNormalization",
-            "GRU",
-            "GaussianDropout",
-            "GaussianNoise",
-            "LSTM",
-            "SimpleRNN",
-        ],
-        "target_backends": "tflite",
-    },
-    {
-        # Failing on IREE
-        "layer": [
-            "AdditiveAttention",
-            "AlphaDropout",
-            "Attention",
-            "Dropout",
-            "GRU",
-            "GaussianDropout",
-            "GaussianNoise",
-            "LSTM",
-            "SpatialDropout1D",
-            "SpatialDropout2D",
-            "SpatialDropout3D",
-        ],
-        "target_backends": [
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-]
-
 # Layers that mention a training kwarg in their doc.
 LAYERS_WITH_TRAINING_BEHAVIOR = [
     "AdditiveAttention",
     "AlphaDropout",
     "Attention",
     "BatchNormalization",
-    "ConvLSTM2D",
+    # "ConvLSTM2D",  # TODO(meadowlark): Debug flakiness.
     "Dropout",
     "GRU",
     "GRUCell",
@@ -399,6 +507,55 @@
     "SpatialDropout3D",
 ]
 
+FAILING_TRAINING = [
+    {
+        # Wrapping these in a tf.function appears to cause a keras bug.
+        "layer": [
+            "GRUCell",
+            "LSTMCell",
+            "SimpleRNNCell",
+        ],
+    },
+    {
+        # Failing on TFLite:
+        "layer": [
+            "AlphaDropout",
+            "BatchNormalization",
+            "ConvLSTM2D",
+            "GaussianDropout",
+            "GaussianNoise",
+            "GRU",
+            "LSTM",
+            "SimpleRNN",
+        ],
+        "target_backends": "tflite",
+    },
+    {
+        # Failing on IREE
+        "layer": [
+            "AdditiveAttention",
+            "AlphaDropout",
+            "Attention",
+            "ConvLSTM2D",
+            "Dropout",
+            "GaussianDropout",
+            "GaussianNoise",
+            "GRU",
+            "LSTM",
+            "MultiHeadAttention",
+            "SimpleRNN",
+            "SpatialDropout1D",
+            "SpatialDropout2D",
+            "SpatialDropout3D",
+        ],
+        "target_backends": [
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+]
+
 iree_e2e_cartesian_product_test_suite(
     name = "layers_training_tests",
     srcs = ["layers_test.py"],
@@ -408,6 +565,7 @@
         "layer": LAYERS_WITH_TRAINING_BEHAVIOR,
         "dynamic_batch": False,
         "training": True,
+        "test_full_api": False,
         "target_backends": [
             "tf",
             "tflite",
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index bdebbb7..0ab5f32 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -14,8 +14,10 @@
 # limitations under the License.
 """Tests of tf.keras.layer.Layer subclasses."""
 
+import collections
+import copy
 import os
-from typing import Any, Sequence, Union
+from typing import Any, Dict, Sequence, Union
 
 from absl import app
 from absl import flags
@@ -35,164 +37,472 @@
 CONV_2D_INPUT = [2, 8, 8, 3]
 CONV_3D_INPUT = [2, 8, 8, 8, 3]
 
-LAYER_TO_INPUT_SHAPES = {
-    'Activation': [RANK_2_INPUT],
-    'ActivityRegularization': [RANK_2_INPUT],
-    'Add': [RANK_2_INPUT, RANK_2_INPUT],
-    'AdditiveAttention': [RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],
-    'AlphaDropout': [RANK_2_INPUT],
-    'Attention': [RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],
-    'Average': [RANK_2_INPUT, RANK_2_INPUT],
-    'AveragePooling1D': [CONV_1D_INPUT],
-    'AveragePooling2D': [CONV_2D_INPUT],
-    'AveragePooling3D': [CONV_3D_INPUT],
-    'BatchNormalization': [RANK_2_INPUT],
-    'Concatenate': [RANK_4_INPUT, RANK_4_INPUT],
-    'Conv1D': [CONV_1D_INPUT],
-    'Conv1DTranspose': [CONV_1D_INPUT],
-    'Conv2D': [CONV_2D_INPUT],
-    'Conv2DTranspose': [CONV_2D_INPUT],
-    'Conv3D': [CONV_3D_INPUT],
-    'Conv3DTranspose': [CONV_3D_INPUT],
-    'ConvLSTM2D': [CONV_3D_INPUT],
-    'Cropping1D': [CONV_1D_INPUT],
-    'Cropping2D': [CONV_2D_INPUT],
-    'Cropping3D': [CONV_3D_INPUT],
-    'Dense': [RANK_2_INPUT],
-    'DepthwiseConv2D': [CONV_2D_INPUT],
-    'Dot': [RANK_3_INPUT, RANK_3_INPUT],
-    'Dropout': [RANK_3_INPUT],
-    'ELU': [RANK_2_INPUT],
-    'Embedding': [RANK_2_INPUT],
-    'Flatten': [RANK_2_INPUT],
-    'GRU': [RANK_3_INPUT],
-    'GRUCell': [RANK_2_INPUT, RANK_2_INPUT],
-    'GaussianDropout': [RANK_2_INPUT],
-    'GaussianNoise': [RANK_2_INPUT],
-    'GlobalAveragePooling1D': [CONV_1D_INPUT],
-    'GlobalAveragePooling2D': [CONV_2D_INPUT],
-    'GlobalAveragePooling3D': [CONV_3D_INPUT],
-    'GlobalMaxPool1D': [CONV_1D_INPUT],
-    'GlobalMaxPool2D': [CONV_2D_INPUT],
-    'GlobalMaxPool3D': [CONV_3D_INPUT],
-    'InputLayer': [RANK_2_INPUT],
-    'LSTM': [RANK_3_INPUT],
-    'LSTMCell': [RANK_2_INPUT, RANK_2_INPUT],
-    'Lambda': [RANK_2_INPUT],
-    'LayerNormalization': [RANK_2_INPUT],
-    'LeakyReLU': [RANK_2_INPUT],
-    'LocallyConnected1D': [CONV_1D_INPUT],
-    'LocallyConnected2D': [CONV_2D_INPUT],
-    'Masking': [RANK_2_INPUT],
-    'MaxPool1D': [CONV_1D_INPUT],
-    'MaxPool2D': [CONV_2D_INPUT],
-    'MaxPool3D': [CONV_3D_INPUT],
-    'Maximum': [RANK_2_INPUT, RANK_2_INPUT],
-    'Minimum': [RANK_2_INPUT, RANK_2_INPUT],
-    'MultiHeadAttention': [RANK_3_INPUT, RANK_3_INPUT],
-    'Multiply': [RANK_2_INPUT, RANK_2_INPUT],
-    'PReLU': [RANK_2_INPUT],
-    'Permute': [RANK_4_INPUT],
-    'ReLU': [RANK_2_INPUT],
-    'RepeatVector': [RANK_2_INPUT],
-    'Reshape': [RANK_3_INPUT],
-    'SeparableConv1D': [CONV_1D_INPUT],
-    'SeparableConv2D': [CONV_2D_INPUT],
-    'SimpleRNN': [RANK_3_INPUT],
-    'SimpleRNNCell': [RANK_2_INPUT, RANK_2_INPUT],
-    'Softmax': [RANK_2_INPUT],
-    'SpatialDropout1D': [CONV_1D_INPUT],
-    'SpatialDropout2D': [CONV_2D_INPUT],
-    'SpatialDropout3D': [CONV_3D_INPUT],
-    'Subtract': [RANK_2_INPUT, RANK_2_INPUT],
-    'ThresholdedReLU': [RANK_2_INPUT],
-    'UpSampling1D': [CONV_1D_INPUT],
-    'UpSampling2D': [CONV_2D_INPUT],
-    'UpSampling3D': [CONV_3D_INPUT],
-    'ZeroPadding1D': [CONV_1D_INPUT],
-    'ZeroPadding2D': [CONV_2D_INPUT],
-    'ZeroPadding3D': [CONV_3D_INPUT],
+# Configs are namedtuples storing keyword arguments and shapes to test a
+# tf.keras.layers.Layer with. They are used in two ways:
+#   1. To directly specify the kwargs and shapes for a layers test.
+#   2. In 'generate_configs', to specify how to modify a default config into a
+#      non-default test. In this case, the overriding Config replaces the
+#      shapes of the test if its 'shapes' field is not None, and its kwargs
+#      extend/update the kwargs of the default Config.
+Config = collections.namedtuple('Config', ['kwargs', 'shapes'])
+# Use old default API for compatibility with Python 3.6.
+Config.__new__.__defaults__ = (dict(), None)
+
+
+def generate_configs(default_config: Config,
+                     override_configs: Dict[str, Config]) -> Dict[str, Config]:
+  """Generates a dict of 'Config's based on changes to a default Config."""
+  configs = {'default': default_config}
+  for exported_name, config in override_configs.items():
+    shapes = default_config.shapes if config.shapes is None else config.shapes
+
+    # Deep copy to avoid inplace mutation of the default.
+    kwargs = copy.deepcopy(default_config.kwargs)
+    kwargs.update(config.kwargs)  # Adds new and overwrites old kwargs.
+
+    configs[exported_name] = Config(kwargs, shapes)
+  return configs
+
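+# For example (an illustrative sketch):
+#   generate_configs(Config(dict(units=4), [RANK_3_INPUT]),
+#                    {'stateful': Config(dict(stateful=True))})
+# returns {'default': Config(dict(units=4), [RANK_3_INPUT]),
+#          'stateful': Config(dict(units=4, stateful=True), [RANK_3_INPUT])}.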
+
+# A dict mapping tf.keras.layers names to either a single Config (representing
+# the kwargs and shapes to use to test a Layer) or a dict mapping exported_names
+# to Configs. The latter case is usually automatically generated via
+# 'generate_configs', with the 'Config's in 'override_configs' specifying how
+# to modify the 'default_config's kwargs and shapes.
+#
+# Each entry will be normalized to be a dict mapping exported_names to Configs,
+# with a default exported_name of 'default'.
+LAYER_TO_UNITTEST_CONFIGURATIONS = {
+    'Activation':
+        Config(dict(activation='relu'), [RANK_2_INPUT]),
+    'ActivityRegularization':
+        Config(dict(l1=0.1, l2=0.1), shapes=[RANK_2_INPUT]),
+    'Add':
+        Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
+    'AdditiveAttention':
+        generate_configs(
+            default_config=Config(
+                shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
+            override_configs={
+                'causal': Config(dict(causal=True)),
+            },
+        ),
+    'AlphaDropout':
+        Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
+    'Attention':
+        generate_configs(
+            default_config=Config(
+                shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
+            override_configs={
+                'causal': Config(dict(causal=True)),
+            },
+        ),
+    'Average':
+        Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
+    'AveragePooling1D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_1D_INPUT]),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'AveragePooling2D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_2D_INPUT]),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: Default AvgPoolingOp only supports NHWC on device type CPU
+                # 'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'AveragePooling3D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_3D_INPUT]),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'BatchNormalization':
+        generate_configs(
+            default_config=Config(shapes=[RANK_2_INPUT]),
+            override_configs={'renorm': Config(dict(renorm=True))},
+        ),
+    'Concatenate':
+        Config(shapes=[RANK_4_INPUT, RANK_4_INPUT]),
+    'Conv1D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_1D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: The Conv2D op currently only supports the NHWC tensor
+                #     format on the CPU.
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                'dilation_rate': Config(dict(dilation_rate=3)),
+            },
+        ),
+    'Conv1DTranspose':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_1D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: Conv2DCustomBackpropInputOp only supports NHWC
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                # TF: Current libxsmm and customized CPU implementations do not
+                # yet support dilation rates larger than 1.
+                # 'dilation_rate': Config(dict(dilation_rate=3)),
+            },
+        ),
+    'Conv2D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_2D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: The Conv2D op currently only supports the NHWC tensor
+                #     format on the CPU.
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                'dilation_rate': Config(dict(dilation_rate=3)),
+            },
+        ),
+    'Conv2DTranspose':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_2D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+                'dilation_rate': Config(dict(dilation_rate=3)),
+            },
+        ),
+    'Conv3D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_3D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: The Conv3D op currently only supports the NHWC tensor
+                #     format on the CPU.
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                'dilation_rate': Config(dict(dilation_rate=3)),
+            },
+        ),
+    'Conv3DTranspose':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_3D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                'dilation_rate': Config(dict(dilation_rate=3)),
+            },
+        ),
+    'ConvLSTM2D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3, return_state=True),
+                shapes=[CONV_3D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+                'dilation_rate': Config(dict(dilation_rate=3)),
+                'go_backwards': Config(dict(go_backwards=True)),
+                'stateful': Config(dict(stateful=True)),
+            },
+        ),
+    'Cropping1D':
+        Config(dict(cropping=2), [CONV_1D_INPUT]),
+    'Cropping2D':
+        Config(dict(cropping=2), [CONV_2D_INPUT]),
+    'Cropping3D':
+        Config(dict(cropping=2), [CONV_3D_INPUT]),
+    'Dense':
+        Config(dict(units=4), [RANK_2_INPUT]),
+    'DepthwiseConv2D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(kernel_size=3),
+                shapes=[CONV_2D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+                'depth_multiplier': Config(dict(depth_multiplier=2)),
+                'dilation_rate': Config(dict(dilation_rate=2)),
+            },
+        ),
+    'Dot':
+        Config(dict(axes=(1, 2)), [RANK_3_INPUT, RANK_3_INPUT]),
+    'Dropout':
+        Config(dict(rate=DROPOUT), [RANK_3_INPUT]),
+    'ELU':
+        Config(shapes=[RANK_2_INPUT]),
+    'Embedding':
+        Config(dict(input_dim=4, output_dim=2), [RANK_2_INPUT]),
+    'Flatten':
+        Config(shapes=[RANK_2_INPUT]),
+    'GRU':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(units=4, return_sequences=True),
+                shapes=[RANK_3_INPUT],
+            ),
+            override_configs={
+                'implementation_1': Config(dict(implementation=1)),
+                'go_backwards': Config(dict(go_backwards=True)),
+                'time_major': Config(dict(time_major=True)),
+                'stateful': Config(dict(stateful=True)),
+            },
+        ),
+    'GRUCell':
+        Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
+    'GaussianDropout':
+        Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
+    'GaussianNoise':
+        Config(dict(stddev=1.0), [RANK_2_INPUT]),
+    'GlobalAveragePooling1D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_1D_INPUT]),
+            override_configs={
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'GlobalAveragePooling2D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_2D_INPUT]),
+            override_configs={
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'GlobalAveragePooling3D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_3D_INPUT]),
+            override_configs={
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'GlobalMaxPool1D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_1D_INPUT]),
+            override_configs={
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'GlobalMaxPool2D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_2D_INPUT]),
+            override_configs={
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'GlobalMaxPool3D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_3D_INPUT]),
+            override_configs={
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'InputLayer':
+        Config(shapes=[RANK_2_INPUT]),
+    'LSTM':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(units=4, return_sequences=True),
+                shapes=[RANK_3_INPUT],
+            ),
+            override_configs={
+                'implementation_1': Config(dict(implementation=1)),
+                'go_backwards': Config(dict(go_backwards=True)),
+                'time_major': Config(dict(time_major=True)),
+                'stateful': Config(dict(stateful=True)),
+            },
+        ),
+    'LSTMCell':
+        Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
+    'Lambda':
+        Config(dict(function=lambda x: x**2), [RANK_2_INPUT]),
+    'LayerNormalization':
+        Config(shapes=[RANK_2_INPUT]),
+    'LeakyReLU':
+        Config(shapes=[RANK_2_INPUT]),
+    'LocallyConnected1D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_1D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same', implementation=2)),
+                'channels_first': Config(dict(data_format='channels_first')),
+                'sparse_implementation': Config(dict(implementation=3)),
+            },
+        ),
+    'LocallyConnected2D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_2D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same', implementation=2)),
+                'channels_first': Config(dict(data_format='channels_first')),
+                'sparse_implementation': Config(dict(implementation=3)),
+            },
+        ),
+    'Masking':
+        Config(shapes=[RANK_2_INPUT]),
+    'MaxPool1D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_1D_INPUT]),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'MaxPool2D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_2D_INPUT]),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: Default MaxPoolingOp only supports NHWC on device type CPU
+                # 'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'MaxPool3D':
+        generate_configs(
+            default_config=Config(shapes=[CONV_3D_INPUT]),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                'channels_first': Config(dict(data_format='channels_first')),
+            },
+        ),
+    'Maximum':
+        Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
+    'Minimum':
+        Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
+    'MultiHeadAttention':
+        Config(dict(num_heads=2, key_dim=3), [RANK_3_INPUT, RANK_3_INPUT]),
+    'Multiply':
+        Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
+    'PReLU':
+        Config(shapes=[RANK_2_INPUT]),
+    'Permute':
+        Config(dict(dims=(3, 1, 2)), [RANK_4_INPUT]),
+    'ReLU':
+        Config(shapes=[RANK_2_INPUT]),
+    'RepeatVector':
+        Config(dict(n=3), [RANK_2_INPUT]),
+    'Reshape':
+        Config(dict(target_shape=[1, 1, 1] + RANK_3_INPUT[1:]), [RANK_3_INPUT]),
+    'SeparableConv1D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_1D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: Depthwise convolution on CPU is only supported for NHWC
+                #     format
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                'depth_multiplier': Config(dict(depth_multiplier=2)),
+                'dilation_rate': Config(dict(dilation_rate=2)),
+            },
+        ),
+    'SeparableConv2D':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(filters=4, kernel_size=3),
+                shapes=[CONV_2D_INPUT],
+            ),
+            override_configs={
+                'strides': Config(dict(strides=3)),
+                'padding_same': Config(dict(padding='same')),
+                # TF: Depthwise convolution on CPU is only supported for NHWC
+                #     format
+                # 'channels_first': Config(dict(data_format='channels_first')),
+                'depth_multiplier': Config(dict(depth_multiplier=2)),
+                'dilation_rate': Config(dict(dilation_rate=2)),
+            },
+        ),
+    'SimpleRNN':
+        generate_configs(
+            default_config=Config(
+                kwargs=dict(units=4, return_sequences=True),
+                shapes=[RANK_3_INPUT],
+            ),
+            override_configs={
+                'go_backwards': Config(dict(go_backwards=True)),
+                'stateful': Config(dict(stateful=True)),
+            },
+        ),
+    'SimpleRNNCell':
+        Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
+    'Softmax':
+        Config(shapes=[RANK_2_INPUT]),
+    'SpatialDropout1D':
+        Config(dict(rate=DROPOUT), [CONV_1D_INPUT]),
+    'SpatialDropout2D':
+        Config(dict(rate=DROPOUT), [CONV_2D_INPUT]),
+    'SpatialDropout3D':
+        Config(dict(rate=DROPOUT), [CONV_3D_INPUT]),
+    'Subtract':
+        Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
+    'ThresholdedReLU':
+        Config(shapes=[RANK_2_INPUT]),
+    'UpSampling1D':
+        Config(shapes=[CONV_1D_INPUT]),
+    'UpSampling2D':
+        Config(shapes=[CONV_2D_INPUT]),
+    'UpSampling3D':
+        Config(shapes=[CONV_3D_INPUT]),
+    'ZeroPadding1D':
+        Config(shapes=[CONV_1D_INPUT]),
+    'ZeroPadding2D':
+        Config(shapes=[CONV_2D_INPUT]),
+    'ZeroPadding3D':
+        Config(shapes=[CONV_3D_INPUT]),
 }
 
-LAYER_TO_KWARGS = {
-    'Activation': dict(activation='relu'),
-    'ActivityRegularization': dict(),
-    'Add': dict(),
-    'AdditiveAttention': dict(),
-    'AlphaDropout': dict(rate=DROPOUT),
-    'Attention': dict(),
-    'Average': dict(),
-    'AveragePooling1D': dict(),
-    'AveragePooling2D': dict(),
-    'AveragePooling3D': dict(),
-    'BatchNormalization': dict(),
-    'Concatenate': dict(),
-    'Conv1D': dict(filters=4, kernel_size=3),
-    'Conv1DTranspose': dict(filters=4, kernel_size=3),
-    'Conv2D': dict(filters=4, kernel_size=3),
-    'Conv2DTranspose': dict(filters=4, kernel_size=3),
-    'Conv3D': dict(filters=4, kernel_size=3),
-    'Conv3DTranspose': dict(filters=4, kernel_size=3),
-    'ConvLSTM2D': dict(filters=4, kernel_size=3),
-    'Cropping1D': dict(cropping=2),
-    'Cropping2D': dict(cropping=2),
-    'Cropping3D': dict(cropping=2),
-    'Dense': dict(units=4),
-    'DepthwiseConv2D': dict(kernel_size=3),
-    'Dot': dict(axes=(1, 2)),
-    'Dropout': dict(rate=DROPOUT),
-    'ELU': dict(),
-    'Embedding': dict(input_dim=4, output_dim=2),
-    'Flatten': dict(),
-    'GRU': dict(units=4, return_sequences=True),
-    'GRUCell': dict(units=4),
-    'GaussianDropout': dict(rate=DROPOUT),
-    'GaussianNoise': dict(stddev=1.0),
-    'GlobalAveragePooling1D': dict(),
-    'GlobalAveragePooling2D': dict(),
-    'GlobalAveragePooling3D': dict(),
-    'GlobalMaxPool1D': dict(),
-    'GlobalMaxPool2D': dict(),
-    'GlobalMaxPool3D': dict(),
-    'InputLayer': dict(),
-    'LSTM': dict(units=4, return_sequences=True),
-    'LSTMCell': dict(units=4),
-    'Lambda': dict(function=lambda x: x**2),
-    'Layer': dict(),
-    'LayerNormalization': dict(),
-    'LeakyReLU': dict(),
-    'LocallyConnected1D': dict(filters=4, kernel_size=3),
-    'LocallyConnected2D': dict(filters=4, kernel_size=3),
-    'Masking': dict(),
-    'MaxPool1D': dict(),
-    'MaxPool2D': dict(),
-    'MaxPool3D': dict(),
-    'Maximum': dict(),
-    'Minimum': dict(),
-    'MultiHeadAttention': dict(num_heads=2, key_dim=3),
-    'Multiply': dict(),
-    'PReLU': dict(),
-    'Permute': dict(dims=(3, 1, 2)),
-    'ReLU': dict(),
-    'RepeatVector': dict(n=3),
-    'Reshape': dict(target_shape=[1, 1, 1] + RANK_3_INPUT[1:]),
-    'SeparableConv1D': dict(filters=4, kernel_size=3),
-    'SeparableConv2D': dict(filters=4, kernel_size=3),
-    'SimpleRNN': dict(units=4, return_sequences=True),
-    'SimpleRNNCell': dict(units=4),
-    'Softmax': dict(),
-    'SpatialDropout1D': dict(rate=DROPOUT),
-    'SpatialDropout2D': dict(rate=DROPOUT),
-    'SpatialDropout3D': dict(rate=DROPOUT),
-    'Subtract': dict(),
-    'ThresholdedReLU': dict(),
-    'UpSampling1D': dict(),
-    'UpSampling2D': dict(),
-    'UpSampling3D': dict(),
-    'ZeroPadding1D': dict(),
-    'ZeroPadding2D': dict(),
-    'ZeroPadding3D': dict(),
-}
+# Normalize LAYER_TO_UNITTEST_CONFIGURATIONS so that every value is a dict
+# mapping unittest names to Configs.
+for key, value in LAYER_TO_UNITTEST_CONFIGURATIONS.items():
+  if isinstance(value, Config):
+    LAYER_TO_UNITTEST_CONFIGURATIONS[key] = {'default': value}
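+# For example, a bare entry such as
+#   'ReLU': Config(shapes=[RANK_2_INPUT])
+# is rewritten to
+#   'ReLU': {'default': Config(shapes=[RANK_2_INPUT])}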
 
 # Layers that allow specifying the 'dropout' kwarg.
 DROPOUT_LAYERS = [
@@ -201,68 +511,105 @@
 ]
 
 flags.DEFINE_string('layer', 'Dense',
-                    f'One of {list(LAYER_TO_INPUT_SHAPES.keys())}.')
+                    f'One of {list(LAYER_TO_UNITTEST_CONFIGURATIONS.keys())}.')
 flags.DEFINE_bool(
     'dynamic_batch', False,
     'Whether or not to compile the layer with a dynamic batch size.')
 flags.DEFINE_bool('training', False,
                   'Whether or not to compile the layer in training mode.')
+flags.DEFINE_bool(
+    'test_full_api', False,
+    'Whether or not to test multiple layer configurations using non-required '
+    'kwargs.')
+flags.DEFINE_bool(
+    'list_layers_with_full_api_tests', False,
+    'Whether or not to print out all layers with non-default configurations '
+    '(and skip running the tests).')
+
+
+def get_configs() -> Dict[str, Config]:
+  """Gets the configs that we want to test for FLAGS.layer."""
+  configs = LAYER_TO_UNITTEST_CONFIGURATIONS[FLAGS.layer]
+  if not FLAGS.test_full_api:
+    return {'default': configs['default']}
+  return configs  # pytype: disable=bad-return-type
 
 
 def get_input(shape: Sequence[int]) -> tf.keras.layers.Input:
+  """Gets the input shape(s) that we want to test."""
   batch_size = None if FLAGS.dynamic_batch else shape[0]
   return tf.keras.layers.Input(batch_size=batch_size, shape=shape[1:])
 
 
-def normalize(inputs: Sequence[Any]) -> Union[Any, Sequence[Any]]:
+def keras_input_normalizer(inputs: Sequence[Any]) -> Union[Any, Sequence[Any]]:
   """Unpacks inputs if it has length one."""
   return inputs[0] if len(inputs) == 1 else inputs
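+# e.g. keras_input_normalizer([x]) returns x, while
+# keras_input_normalizer([x, y]) returns [x, y] unchanged.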
 
 
-class KerasLayersModule(tf.Module):
+def keras_arg_wrapper(*args):
+  """Wrapper to convert multiple positional args into a list of values."""
+  return list(args) if isinstance(args, tuple) else args
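+# e.g. keras_arg_wrapper(a, b) returns [a, b], matching the list format that
+# multi-input keras Models expect.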
+
+
+def create_wrapped_keras_layer(config: Config) -> tf.keras.Model:
+  """Wraps a keras layer in a model for compilation."""
+  layer_class = getattr(tf.keras.layers, FLAGS.layer)
+
+  if FLAGS.training and FLAGS.layer in DROPOUT_LAYERS:
+    config.kwargs['dropout'] = DROPOUT
+
+  inputs = keras_input_normalizer([get_input(shape) for shape in config.shapes])
+  if FLAGS.layer == 'MultiHeadAttention':
+    # TODO(meadowlark): Remove specialization if API changes.
+    outputs = layer_class(**config.kwargs)(*inputs)
+  else:
+    outputs = layer_class(**config.kwargs)(inputs)
+  return tf.keras.Model(inputs, outputs)
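+# For example, with --layer=Conv2D and its default config this builds the
+# equivalent of:
+#   inputs = get_input(CONV_2D_INPUT)
+#   outputs = tf.keras.layers.Conv2D(filters=4, kernel_size=3)(inputs)
+#   model = tf.keras.Model(inputs, outputs)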
+
+
+def create_tf_function_unittest(config: Config, exported_name: str,
+                                model: tf.keras.Model) -> tf.function:
+  """Wrap the model's __call__ function in a tf.function for testing."""
+  input_shapes = config.shapes
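+  # Under --dynamic_batch, replace the static batch dimension with None so
+  # that the compiled function accepts inputs of any batch size.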
+  if FLAGS.dynamic_batch:
+    input_shapes = [[None] + shape[1:] for shape in input_shapes]
+
+  input_signature = [tf.TensorSpec(shape) for shape in input_shapes]
+  if len(input_signature) > 1:
+    input_signature = [input_signature]
+
+  call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
+  return tf_test_utils.tf_function_unittest(input_signature=input_signature,
+                                            name=exported_name)(call)
+
+
+class KerasLayersModule(tf_test_utils.TestModule):
+
+  @classmethod
+  def configure_class(cls):
+    """Configure each tf_function_unittest and define it on the cls."""
+    for exported_name, config in get_configs().items():
+      model = create_wrapped_keras_layer(config)
+      setattr(cls, exported_name,
+              create_tf_function_unittest(config, exported_name, model))
 
   def __init__(self):
     super().__init__()
-    layer_class = getattr(tf.keras.layers, FLAGS.layer)
-    input_shapes = LAYER_TO_INPUT_SHAPES[FLAGS.layer]
-    kwargs = LAYER_TO_KWARGS[FLAGS.layer]
-    if FLAGS.training and FLAGS.layer in DROPOUT_LAYERS:
-      kwargs['dropout'] = DROPOUT
-
-    # Create a wrapped keras layer.
-    inputs = normalize([get_input(shape) for shape in input_shapes])
-    if FLAGS.layer == 'MultiHeadAttention':
-      # TODO(meadowlark): Fix this if keras updates their API to be consistent.
-      outputs = layer_class(**kwargs)(*inputs)
-    else:
-      outputs = layer_class(**kwargs)(inputs)
-    self.m = tf.keras.Model(inputs, outputs)
-
-    # Wrap the layer in a tf.function.
-    if FLAGS.dynamic_batch:
-      input_shapes = [[None] + shape[1:] for shape in input_shapes]
-    input_signature = [tf.TensorSpec(shape) for shape in input_shapes]
-    if len(input_signature) > 1:
-      input_signature = [input_signature]
-    self.call = tf.function(input_signature=input_signature)(
-        lambda x: self.m(x, training=FLAGS.training))
+    self.models = []
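+    # Storing the models on the module keeps their variables tracked by
+    # tf.Module.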
+    for exported_name, config in get_configs().items():
+      model = create_wrapped_keras_layer(config)
+      self.models.append(model)
+      setattr(self, exported_name,
+              create_tf_function_unittest(config, exported_name, model))
 
 
 class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(KerasLayersModule,
-                                                    exported_names=['call'])
-
-  def test_call(self):
-
-    def call(module):
-      input_shapes = LAYER_TO_INPUT_SHAPES[FLAGS.layer]
-      inputs = normalize([tf_utils.uniform(shape) for shape in input_shapes])
-      module.call(inputs)
-
-    self.compare_backends(call, self._modules)
+    self._modules = tf_test_utils.compile_tf_module(
+        KerasLayersModule,
+        exported_names=KerasLayersModule.get_exported_names())
 
 
 def main(argv):
@@ -270,14 +617,28 @@
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
 
-  if FLAGS.layer not in LAYER_TO_INPUT_SHAPES:
+  if FLAGS.layer not in LAYER_TO_UNITTEST_CONFIGURATIONS:
     raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'.")
+
+  if FLAGS.list_layers_with_full_api_tests:
+    for layer, configs in sorted(LAYER_TO_UNITTEST_CONFIGURATIONS.items()):
+      if len(configs) > 1:
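+        # Print in a form that can be copy/pasted directly into a list of
+        # layer names, e.g. in the BUILD file for the full API test suites.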
+        print(f'    "{layer}",')
+    return
+
+  # Set up name for saving artifacts.
   dynamic_batch_str = 'dynamic_batch' if FLAGS.dynamic_batch else 'static_batch'
   training_str = 'training' if FLAGS.training else 'non_training'
-  settings_str = f'{dynamic_batch_str}_{training_str}'
+  full_api_str = 'full_api' if FLAGS.test_full_api else 'default_api'
+  settings_str = f'{full_api_str}_{dynamic_batch_str}_{training_str}'
   KerasLayersModule.__name__ = os.path.join('keras_layers', FLAGS.layer,
                                             settings_str)
 
+  # Use the configurations for FLAGS.layer to define the tf.functions that we
+  # want to test on KerasLayersModule, and then generate a unittest for each
+  # of them.
+  KerasLayersModule.configure_class()
+  KerasLayersTest.generate_unittests(KerasLayersModule)
   tf.test.main()
 
 
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index a012ebf..da54697 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -85,11 +85,17 @@
     '//integrations/tensorflow/e2e:e2e_tests':
         'End to end TensorFlow tests',
     '//integrations/tensorflow/e2e/keras/layers:layers_tests':
-        'End to end tests of tf.keras layers with static batch sizes in inference mode',
+        'End to end tests of tf.keras layers (with default configuration and '
+        'static batch sizes in inference mode)',
+    '//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
+        'End to end tests of the full APIs of tf.keras layers '
+        '(with static batch sizes in inference mode)',
     '//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests':
-        'End to end tests of tf.keras layers with dynamic batch sizes',
+        'End to end tests of tf.keras layers with dynamic batch sizes '
+        '(with default configuration in inference mode)',
     '//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
-        'End to end tests of tf.keras layers in training mode',
+        'End to end tests of tf.keras layers in training mode (with default '
+        'configuration and static batch sizes)',
     '//integrations/tensorflow/e2e:mobile_bert_squad_tests':
         'End to end test of MobileBert on SQuAD',
     '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
@@ -106,13 +112,19 @@
     '//integrations/tensorflow/e2e/keras/layers:layers_tests': (
         '**Note:** Layers like `Dropout` are listed as passing in this table,\n'
         'but they function similar to identity layers in these tests. **See \n'
-        'the third table for the coverage of these layers during training.**'),
+        'the third table for the coverage of these layers during training.**\n'
+        '\n'
+        'These tests also only modify required `tf.keras.layers` arguments.\n'
+        'See the full API tests below for coverage of non-default '
+        'layer configurations.'),
 }
 # Key to use as the name of the rows in the left column for each test in the
 # suite.
 TEST_SUITE_TO_ROW_ID_KEY = {
     '//integrations/tensorflow/e2e/keras/layers:layers_tests':
         'layer',
+    '//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
+        'layer',
     '//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests':
         'layer',
     '//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
@@ -132,6 +144,8 @@
 SINGLE_SOURCE_SUITES = {
     '//integrations/tensorflow/e2e/keras/layers:layers_tests':
         'layers_test',
+    '//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
+        'layers_test',
     '//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests':
         'layers_test',
     '//integrations/tensorflow/e2e/keras/layers:layers_training_tests':