blob: 1734d685650323a0443b675fdf72103d768ae40e [file] [log] [blame]
# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of tf.keras.layer.Layer subclasses."""
import collections
import copy
import inspect
import os
from typing import Any, Dict, List, Sequence, Tuple, Union
from absl import app
from absl import flags
from absl import logging
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
# Shared handle to absl's flag values. Every flag declared in this file
# (e.g. `layer`, `dynamic_dims`, `training`) is read through this object.
FLAGS = flags.FLAGS
# Value used for the `rate` kwarg of the dropout-style layers below and for the
# `dropout` kwarg injected by create_wrapped_keras_layer in training mode.
DROPOUT = 0.5

CONV_FILTERS = 2
CONV_KERNEL_SIZE = 2

DIM = 3

# Rank-3 shape used for the attention layers and the recurrent layers.
RANK_3_SHAPE = [DIM, DIM, DIM]
# Rank 5 is the highest rank that tf.keras will allow for all layers.
RANK_5_SHAPE = [DIM, DIM, DIM, DIM, DIM]

# Each *_SIGNATURE_SHAPES value is a list of signatures, and each signature is
# a list holding one shape per layer input.
UNARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE]]
BINARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE, RANK_5_SHAPE]]
TERNARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE, RANK_5_SHAPE, RANK_5_SHAPE]]

CONV_1D_SIGNATURE_SHAPES = [[[2, 8, 3]]]
CONV_2D_SIGNATURE_SHAPES = [[[2, 8, 8, 3]]]
CONV_3D_SIGNATURE_SHAPES = [[[2, 8, 8, 8, 3]]]

RNN_SIGNATURE_SHAPES = [[RANK_3_SHAPE]]

# Each *_KWARGS_TO_VALUES dict maps a layer kwarg to the list of values that
# should be tried for it.
RNN_KWARGS_TO_VALUES = {
    "units": [4],
    "return_sequences": [False, True],
    "stateful": [False, True],
}

POOLING_KWARGS_TO_VALUES = {
    "strides": [None, 2],
    "padding": ["valid", "same"],
    "data_format": [None, "channels_first"],
}

CONV_KWARGS_TO_VALUES = {
    "filters": [CONV_FILTERS],
    "kernel_size": [CONV_KERNEL_SIZE],
    "strides": [1, 2],
    "padding": ["valid", "same"],
    "data_format": [None, "channels_first"],
    "dilation_rate": [1, 2],
}

# The 1D pooling and conv layers use a different default value for
# 'data_format', so they get their own copies with that entry overridden. The
# override keeps 'data_format' in its original position in the dict.
POOLING_1D_KWARGS_TO_VALUES = {
    **copy.deepcopy(POOLING_KWARGS_TO_VALUES),
    "data_format": ["channels_last", "channels_first"],
}
CONV_1D_KWARGS_TO_VALUES = {
    **copy.deepcopy(CONV_KWARGS_TO_VALUES),
    "data_format": ["channels_last", "channels_first"],
}

# Layer name -> kwargs whose non-default values are unsupported by TensorFlow
# (at least on CPU).
LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS = {
    "AveragePooling2D": ["data_format"],
    "Conv1D": ["data_format"],
    "Conv1DTranspose": ["data_format", "dilation_rate"],
    "Conv2D": ["data_format"],
    "Conv3D": ["data_format"],
    "Conv3DTranspose": ["data_format"],
    "LocallyConnected1D": ["padding"],
    "LocallyConnected2D": ["padding"],
    "MaxPool2D": ["data_format"],
}

# Layer name -> kwargs which cannot all have non-default values at once.
LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS = {
    "Conv1D": ["strides", "dilation_rate"],
    "Conv2D": ["strides", "dilation_rate"],
    "Conv2DTranspose": ["strides", "dilation_rate"],
    "Conv3D": ["strides", "dilation_rate"],
    "ConvLSTM2D": ["strides", "dilation_rate"],
}
def get_default_kwargs_values(layer: str) -> Dict[str, Any]:
  """Collects the defaulted `__init__` parameters of a tf.keras.layers layer.

  Args:
    layer: Name of an attribute of `tf.keras.layers` (e.g. "Conv2D").

  Returns:
    A dict mapping each `__init__` parameter that declares a default to that
    default value. Parameters without a default are omitted.
  """
  init_signature = inspect.signature(getattr(tf.keras.layers, layer).__init__)
  defaults = {}
  for name, parameter in init_signature.parameters.items():
    if parameter.default is inspect.Parameter.empty:
      # No default declared, so this parameter is required: leave it out.
      continue
    defaults[name] = parameter.default
  return defaults
def _equal_or_splat_equal(value: Any, sequence: Any) -> bool:
  """Checks whether `value` equals `sequence` or every element nested in it.

  This lets a scalar such as `1` match a "splatted" default such as `(1, 1)`
  or `((1, 1), (1, 1))`.

  Args:
    value: The value to compare.
    sequence: Either something directly comparable to `value`, or an
      arbitrarily nested list/tuple whose elements are compared recursively.

  Returns:
    True if `value == sequence`, or if `sequence` is a list/tuple and the
    check holds for each of its elements (trivially True when it is empty).
    False otherwise.
  """
  if value == sequence:
    return True
  if not isinstance(sequence, (list, tuple)):
    return False
  return all(_equal_or_splat_equal(value, item) for item in sequence)
def get_non_default_kwargs(
    layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> List[str]:
  """Lists the optional kwargs that `unit_test_spec` sets to non-default values.

  Args:
    layer: Name of the tf.keras.layers layer the spec configures.
    unit_test_spec: Spec whose `kwargs` are compared against the layer's
      defaults.

  Returns:
    The names of the kwargs, in the order they appear in the spec, that have a
    default in the layer's `__init__` and whose value neither equals nor
    splat-equals that default. Kwargs without a default are never reported.
  """
  defaults = get_default_kwargs_values(layer)
  return [
      kwarg for kwarg, value in unit_test_spec.kwargs.items()
      if kwarg in defaults and
      not _equal_or_splat_equal(value, defaults[kwarg])
  ]
def unsupported_by_tf(layer: str,
                      unit_test_spec: tf_test_utils.UnitTestSpec) -> bool:
  """Reports whether the spec sets a kwarg that TensorFlow cannot run.

  Args:
    layer: Name of the tf.keras.layers layer the spec configures.
    unit_test_spec: Spec whose non-default kwargs are inspected.

  Returns:
    True if any non-default kwarg of the spec is listed for `layer` in
    LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS; False otherwise, including
    when the layer has no entry in that table.
  """
  if layer not in LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS:
    return False
  unsupported = LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS[layer]
  for kwarg in get_non_default_kwargs(layer, unit_test_spec):
    if kwarg in unsupported:
      return True
  return False
def has_mutually_exclusive_kwargs(
    layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> bool:
  """Reports whether the spec sets all of a layer's mutually exclusive kwargs.

  Args:
    layer: Name of the tf.keras.layers layer the spec configures.
    unit_test_spec: Spec whose non-default kwargs are inspected.

  Returns:
    True if every kwarg listed for `layer` in
    LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS has a non-default value in the spec;
    False otherwise, including when the layer has no entry in that table.
  """
  if layer not in LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS:
    return False
  non_default = set(get_non_default_kwargs(layer, unit_test_spec))
  return all(
      kwarg in non_default
      for kwarg in LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS[layer])
# A dictionary mapping tf.keras.layers names to lists of UnitTestSpecs.
# Each unit_test_name will have the tf.keras.layer name prepended to it.
#
# Each layer is required to have a UnitTestSpec with all-default values for
# unrequired kwargs. This allows us to separately test the basic API and the
# full API.
LAYERS_TO_UNIT_TEST_SPECS = {
    # Keys are attribute names of `tf.keras.layers` (they are looked up with
    # getattr elsewhere in this file). Entries that pass only signature shapes
    # supply no kwargs at all; entries with `kwargs_to_values` list the values
    # to try for each kwarg.
    "Activation":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(activation=["relu"])),
    "ActivityRegularization":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(l1=[0.0, 0.1], l2=[0.0, 0.1])),
    "Add":
        tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
    # Three rank-3 inputs (presumably query, value and key -- confirm against
    # the Keras attention API).
    "AdditiveAttention":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE, RANK_3_SHAPE)],
            kwargs_to_values=dict(causal=[False, True])),
    "AlphaDropout":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(rate=[DROPOUT])),
    "Attention":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE, RANK_3_SHAPE)],
            kwargs_to_values=dict(causal=[False, True])),
    "Average":
        tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
    "AveragePooling1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=POOLING_1D_KWARGS_TO_VALUES),
    "AveragePooling2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=POOLING_KWARGS_TO_VALUES),
    "AveragePooling3D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=POOLING_KWARGS_TO_VALUES),
    "BatchNormalization":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(renorm=[False, True])),
    "Concatenate":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(axis=[-1, 0])),
    "Conv1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=CONV_1D_KWARGS_TO_VALUES),
    # NOTE(review): unlike Conv1D this uses CONV_KWARGS_TO_VALUES rather than
    # the 1D-specific kwargs -- presumably intentional; confirm.
    "Conv1DTranspose":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=CONV_KWARGS_TO_VALUES),
    "Conv2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=CONV_KWARGS_TO_VALUES),
    "Conv2DTranspose":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=CONV_KWARGS_TO_VALUES),
    "Conv3D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=CONV_KWARGS_TO_VALUES),
    "Conv3DTranspose":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=CONV_KWARGS_TO_VALUES),
    # Takes the rank-5 CONV_3D shapes as its input signature.
    "ConvLSTM2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(filters=[CONV_FILTERS],
                                  kernel_size=[CONV_KERNEL_SIZE],
                                  return_state=[False, True],
                                  strides=[1, 2],
                                  dilation_rate=[1, 2],
                                  stateful=[False, True])),
    "Cropping1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(cropping=[1, (1, 2)])),
    "Cropping2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(cropping=[0, ((1, 2), (2, 1))])),
    "Cropping3D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(cropping=[1, ((1, 2), (2, 1), (1, 0))])),
    "Dense":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(units=[8])),
    "DepthwiseConv2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(kernel_size=[CONV_KERNEL_SIZE],
                                  strides=[1, 2],
                                  padding=["valid", "same"],
                                  dilation_rate=[1, 2],
                                  depth_multiplier=[1, 2])),
    "Dot":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=BINARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(axes=[(1, 2)])),
    "Dropout":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(rate=[DROPOUT])),
    "ELU":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "Embedding":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(input_dim=[4], output_dim=[2])),
    "Flatten":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "GRU":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=RNN_SIGNATURE_SHAPES,
            kwargs_to_values=RNN_KWARGS_TO_VALUES),
    # NOTE(review): GaussianDropout and GaussianNoise reuse the rank-3
    # RNN_SIGNATURE_SHAPES although they are not listed with the recurrent
    # layers -- presumably just a convenient rank-3 input; confirm.
    "GaussianDropout":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=RNN_SIGNATURE_SHAPES,
            kwargs_to_values=dict(rate=[DROPOUT])),
    "GaussianNoise":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=RNN_SIGNATURE_SHAPES,
            kwargs_to_values=dict(stddev=[1.0])),
    "GlobalAveragePooling1D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
    "GlobalAveragePooling2D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
    "GlobalAveragePooling3D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
    "GlobalMaxPool1D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
    "GlobalMaxPool2D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
    "GlobalMaxPool3D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
    "InputLayer":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "LSTM":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=RNN_SIGNATURE_SHAPES,
            kwargs_to_values=RNN_KWARGS_TO_VALUES),
    # The single `function` value squares its input.
    "Lambda":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(function=[lambda x: x**2])),
    "LayerNormalization":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "LeakyReLU":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "LocallyConnected1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(filters=[CONV_FILTERS],
                                  kernel_size=[CONV_KERNEL_SIZE],
                                  strides=[1, 3],
                                  padding=["valid", "same"],
                                  implementation=[1, 3])),
    "LocallyConnected2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(filters=[CONV_FILTERS],
                                  kernel_size=[CONV_KERNEL_SIZE],
                                  strides=[1, 3],
                                  padding=["valid", "same"],
                                  implementation=[1, 3])),
    "Masking":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "MaxPool1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=POOLING_1D_KWARGS_TO_VALUES),
    "MaxPool2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=POOLING_KWARGS_TO_VALUES),
    "MaxPool3D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=POOLING_KWARGS_TO_VALUES),
    "Maximum":
        tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
    "Minimum":
        tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
    # Two rank-3 inputs; create_wrapped_keras_layer special-cases this layer
    # and passes them as separate positional arguments.
    "MultiHeadAttention":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE)],
            kwargs_to_values=dict(num_heads=[2], key_dim=[3])),
    "Multiply":
        tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
    "PReLU":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "Permute":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=UNARY_SIGNATURE_SHAPES,
            kwargs_to_values=dict(dims=[(3, 1, 4, 2)])),
    "ReLU":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "RepeatVector":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=[((2, 2),)], kwargs_to_values=dict(n=[3])),
    # Input shape (3, 2, 2, 2) includes the leading batch dimension (see
    # get_input); the remaining 8 elements match target_shape (2, 1, 4, 1).
    "Reshape":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=[((3, 2, 2, 2),)],
            kwargs_to_values=dict(target_shape=[(2, 1, 4, 1)])),
    "SeparableConv1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(filters=[CONV_FILTERS],
                                  kernel_size=[CONV_KERNEL_SIZE],
                                  strides=[1, 2],
                                  padding=["valid", "same"],
                                  dilation_rate=[1, 2],
                                  depth_multiplier=[1, 2])),
    "SeparableConv2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(filters=[CONV_FILTERS],
                                  kernel_size=[CONV_KERNEL_SIZE],
                                  strides=[1, 2],
                                  padding=["valid", "same"],
                                  dilation_rate=[1, 2],
                                  depth_multiplier=[1, 2])),
    "SimpleRNN":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=RNN_SIGNATURE_SHAPES,
            kwargs_to_values=RNN_KWARGS_TO_VALUES),
    "Softmax":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "SpatialDropout1D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_1D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(rate=[DROPOUT])),
    "SpatialDropout2D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_2D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(rate=[DROPOUT])),
    "SpatialDropout3D":
        tf_test_utils.unit_test_specs_from_signatures(
            signature_shapes=CONV_3D_SIGNATURE_SHAPES,
            kwargs_to_values=dict(rate=[DROPOUT])),
    "Subtract":
        tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
    "ThresholdedReLU":
        tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
    "UpSampling1D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
    "UpSampling2D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
    "UpSampling3D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
    "ZeroPadding1D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
    "ZeroPadding2D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
    "ZeroPadding3D":
        tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
}
# Post-process the table at import time: prefix every spec name with its layer
# name, then sanity-check each layer's list of specs.
for layer, specs in LAYERS_TO_UNIT_TEST_SPECS.items():
  # Rename through 'with_name' rather than mutating in place, so that
  # UnitTestSpecs shared between layers are not modified.
  specs = [spec.with_name(f"{layer}__{spec.unit_test_name}") for spec in specs]
  # Rebinding an existing key does not change the dict's size, so it is safe
  # to do while iterating over `.items()`.
  LAYERS_TO_UNIT_TEST_SPECS[layer] = specs

  # Every UnitTestSpec of a layer must have a distinct name.
  seen_unit_test_names = set()
  for spec in specs:
    if spec.unit_test_name not in seen_unit_test_names:
      seen_unit_test_names.add(spec.unit_test_name)
      continue
    raise ValueError(
        f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")

  # At least one spec must leave all unrequired kwargs at their default values.
  has_default_unrequired_kwargs = any(
      not get_non_default_kwargs(layer, spec) for spec in specs)
  if not has_default_unrequired_kwargs:
    raise ValueError(
        f"The configuration for '{layer}' did not have a UnitTestSpec with all "
        "default kwargs.")
# Layers that accept a 'dropout' kwarg. In training mode
# create_wrapped_keras_layer sets it to DROPOUT for these layers.
DROPOUT_LAYERS = [
    "AdditiveAttention",
    "Attention",
    "ConvLSTM2D",
    "GRU",
    "LSTM",
    "MultiHeadAttention",
    "SimpleRNN",
]

# Command-line flags controlling which layer is tested and how it is compiled.
flags.DEFINE_string(
    name="layer",
    default=None,
    help=f"One of {list(LAYERS_TO_UNIT_TEST_SPECS)}.")
flags.DEFINE_bool(
    name="dynamic_dims",
    default=False,
    help="Whether or not to compile the layer with a dynamic dimension sizes.")
flags.DEFINE_bool(
    name="training",
    default=False,
    help="Whether or not to compile the layer in training mode.")
flags.DEFINE_bool(
    name="test_default_kwargs_only",
    default=True,
    help=("Whether or not to test multiple layer configurations "
          "using non-required kwargs."))
flags.DEFINE_bool(
    name="list_layers_with_full_api_tests",
    default=False,
    help=("Whether or not to print out all layers with non-default "
          "configurations (and skip running the tests)."))
def get_input(shape: Sequence[int]) -> tf.keras.layers.Input:
  """Builds a tf.keras.Input for a full (batch-inclusive) shape.

  Args:
    shape: Shape whose first entry is the batch size and whose remaining
      entries are passed on as the per-example shape.

  Returns:
    A `tf.keras.layers.Input` whose batch size is `shape[0]`, or None when the
    `dynamic_dims` flag is set.
  """
  if FLAGS.dynamic_dims:
    # Most keras layers are only compatible with dynamic batch sizes.
    batch_size = None
  else:
    batch_size = shape[0]
  return tf.keras.layers.Input(shape=shape[1:], batch_size=batch_size)
def keras_input_normalizer(inputs: Sequence[Any]) -> Union[Any, Sequence[Any]]:
  """Returns the sole element of a length-one `inputs`, else `inputs` itself.

  Args:
    inputs: Sequence of inputs for a layer.

  Returns:
    `inputs[0]` when there is exactly one input; otherwise the same `inputs`
    object, unchanged (this includes the empty sequence).
  """
  if len(inputs) != 1:
    return inputs
  return inputs[0]
def create_wrapped_keras_layer(
    layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.keras.Model:
  """Builds a tf.keras.Model consisting of a single configured keras layer.

  Args:
    layer: Name of the tf.keras.layers layer to instantiate.
    unit_test_spec: Spec providing the layer's kwargs and input signature.

  Returns:
    A `tf.keras.Model` mapping inputs built from the spec's input signature to
    the layer's outputs.
  """
  layer_class = getattr(tf.keras.layers, layer)
  input_signature = unit_test_spec.input_signature

  # Work on a copy so that the spec's own kwargs are left untouched.
  layer_kwargs = copy.deepcopy(unit_test_spec.kwargs)
  if FLAGS.training and layer in DROPOUT_LAYERS:
    layer_kwargs["dropout"] = DROPOUT
  if "dtype" not in unit_test_spec.kwargs:
    # Default the layer's dtype to that of its first input.
    layer_kwargs["dtype"] = input_signature[0].dtype

  inputs = keras_input_normalizer(
      [get_input(input_spec.shape) for input_spec in input_signature])

  keras_layer = layer_class(**layer_kwargs)
  if layer == "MultiHeadAttention":
    # TODO(meadowlark): Remove specialization if API changes.
    outputs = keras_layer(*inputs)
  else:
    outputs = keras_layer(inputs)
  return tf.keras.Model(inputs, outputs)
def create_layer_unit_test(
    model: tf.keras.Model,
    unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
  """Turns a call to `model` into a tf_function_unit_test.

  Args:
    model: The wrapped keras layer to call.
    unit_test_spec: Spec supplying the input signature, input generator,
      input args and unit test name.

  Returns:
    The result of decorating a function that calls `model` (in the mode given
    by the `training` flag) with `tf_test_utils.tf_function_unit_test`.
  """
  static_signature = unit_test_spec.input_signature
  if FLAGS.dynamic_dims:
    dynamic_signature = tf_utils.apply_function(static_signature,
                                                tf_utils.make_dims_dynamic)
  else:
    dynamic_signature = static_signature

  if len(static_signature) > 1:
    # Nest multi-input signatures one level deeper -- presumably because the
    # wrapped model takes all of its inputs as one list argument; confirm.
    static_signature, dynamic_signature = ([static_signature],
                                           [dynamic_signature])

  call = lambda *args: model(*args, training=FLAGS.training)
  as_unit_test = tf_test_utils.tf_function_unit_test(
      input_signature=dynamic_signature,
      static_signature=static_signature,
      input_generator=unit_test_spec.input_generator,
      input_args=unit_test_spec.input_args,
      name=unit_test_spec.unit_test_name)
  return as_unit_test(call)
class KerasLayersModule(tf_test_utils.TestModule):
  """Test module exposing one unit test per selected spec of `FLAGS.layer`."""

  def __init__(self):
    super().__init__()
    self.models = []
    layer = FLAGS.layer
    for unit_test_spec in LAYERS_TO_UNIT_TEST_SPECS[layer]:
      # In default-only mode, skip every UnitTestSpec that sets a non-default
      # value for an unrequired kwarg.
      if (FLAGS.test_default_kwargs_only and
          get_non_default_kwargs(layer, unit_test_spec)):
        continue

      # Drop UnitTestSpecs with kwargs that TensorFlow can't run on CPU or
      # that are mutually exclusive. This lets the table use a full product
      # such as CONV_KWARGS_TO_VALUES and filter out the configurations that
      # particular layers do not support.
      if unsupported_by_tf(layer, unit_test_spec):
        continue
      if has_mutually_exclusive_kwargs(layer, unit_test_spec):
        continue

      keras_model = create_wrapped_keras_layer(layer, unit_test_spec)
      # IREE requires that the models are stored on the module instance.
      self.models.append(keras_model)
      setattr(self, unit_test_spec.unit_test_name,
              create_layer_unit_test(keras_model, unit_test_spec))
def get_relative_artifacts_dir() -> str:
  """Returns the artifacts subdirectory for the current flag settings.

  The path is `tf/keras/layers/<layer>/<api>_<dims>_<mode>`, where the last
  component encodes the `test_default_kwargs_only`, `dynamic_dims` and
  `training` flags.
  """
  if FLAGS.dynamic_dims:
    dims_part = "dynamic"
  else:
    dims_part = "static"

  if FLAGS.training:
    mode_part = "training"
  else:
    mode_part = "non_training"

  if FLAGS.test_default_kwargs_only:
    api_part = "default_api"
  else:
    api_part = "full_api"

  settings = "_".join([api_part, dims_part, mode_part])
  return os.path.join("tf", "keras", "layers", FLAGS.layer, settings)
class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
  """Test case that compiles KerasLayersModule for the selected layer."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    exported_names = KerasLayersModule.get_tf_function_unit_tests()
    artifacts_dir = get_relative_artifacts_dir()
    self._modules = tf_test_utils.compile_tf_module(
        KerasLayersModule,
        exported_names=exported_names,
        relative_artifacts_dir=artifacts_dir)
def main(argv):
  """Lists layers with full-API tests, or runs the tests for `FLAGS.layer`.

  Args:
    argv: Command-line arguments passed by `app.run`; unused.

  Raises:
    ValueError: If `FLAGS.layer` is not a key of LAYERS_TO_UNIT_TEST_SPECS.
  """
  del argv  # Unused.
  if hasattr(tf, "enable_v2_behavior"):
    tf.enable_v2_behavior()

  if FLAGS.list_layers_with_full_api_tests:
    # Print, in sorted order, every layer that has more than one UnitTestSpec,
    # then exit without running any tests.
    for layer in sorted(LAYERS_TO_UNIT_TEST_SPECS):
      if len(LAYERS_TO_UNIT_TEST_SPECS[layer]) > 1:
        print(f' "{layer}",')
    return

  if FLAGS.layer not in LAYERS_TO_UNIT_TEST_SPECS:
    raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'")

  unit_tests = KerasLayersModule.get_tf_function_unit_tests()
  logging.info("Testing the following %s functions: %s", len(unit_tests),
               unit_tests)

  KerasLayersTest.generate_unit_tests(KerasLayersModule)
  tf.test.main()
# Script entry point: hand `main` to absl's app runner when this file is
# executed directly.
if __name__ == "__main__":
  app.run(main)