# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of tf.keras.layer.Layer subclasses."""
import collections
import copy
import os
from typing import Any, Dict, Sequence, Union

from absl import app
from absl import flags
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf

FLAGS = flags.FLAGS

DROPOUT = 0.5
DIM = 4
RANK_2_INPUT = [DIM] * 2
RANK_3_INPUT = [DIM] * 3
RANK_4_INPUT = [DIM] * 4
CONV_1D_INPUT = [2, 8, 3]
CONV_2D_INPUT = [2, 8, 8, 3]
CONV_3D_INPUT = [2, 8, 8, 8, 3]

# Configs are namedtuples storing keyword arguments and shapes with which to
# test a tf.keras.layers.Layer. They are used in two ways:
# 1. To directly specify the kwargs and shapes for a layer's test.
# 2. In 'generate_configs', to describe how to modify a default Config to
#    create a non-default test. In this case, the overriding Config's shapes
#    (if not None) exclusively specify the shapes of the test, and its kwargs
#    extend/update the kwargs of the default Config.
Config = collections.namedtuple('Config', ['kwargs', 'shapes'])
# Use old default API for compatibility with Python 3.6.
Config.__new__.__defaults__ = (dict(), None)
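# For illustration: Config() == Config(kwargs={}, shapes=None), so an override
# Config that only sets kwargs (e.g. Config(dict(padding='same'))) leaves its
# shapes as None and inherits them from the default in 'generate_configs'.

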
def generate_configs(default_config: Config,
override_configs: Dict[str, Config]) -> Dict[str, Config]:
"""Generates a dict of 'Config's based off changes to a default Config."""
configs = {'default': default_config}
for exported_name, config in override_configs.items():
shapes = default_config.shapes if config.shapes is None else config.shapes
# Deep copy to avoid inplace mutation of the default.
kwargs = copy.deepcopy(default_config.kwargs)
kwargs.update(config.kwargs) # Adds new and overwrites old kwargs.
configs[exported_name] = Config(kwargs, shapes)
return configs
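

# For illustration (hypothetical override name): given
#   default = Config(dict(units=4), [RANK_2_INPUT])
# then generate_configs(default, {'big': Config(dict(units=8))}) returns
#   {'default': Config({'units': 4}, [RANK_2_INPUT]),
#    'big': Config({'units': 8}, [RANK_2_INPUT])}
# i.e. the override inherits the default's shapes and updates its kwargs.
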
# A dict mapping tf.keras.layers names to either a single Config (representing
# the kwargs and shapes to use to test a Layer) or a dict mapping exported_names
# to Configs. The latter case is usually automatically generated via
# 'generate_configs', with the 'Config's in 'override_configs' specifying how
# to modify the 'default_config's kwargs and shapes.
#
# Each entry will be normalized to be a dict mapping exported_names to Configs,
# with a default exported_name of 'default'.
LAYER_TO_UNITTEST_CONFIGURATIONS = {
'Activation':
Config(dict(activation='relu'), [RANK_2_INPUT]),
'ActivityRegularization':
Config(dict(l1=0.1, l2=0.1), shapes=[RANK_2_INPUT]),
'Add':
Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
'AdditiveAttention':
generate_configs(
default_config=Config(
shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
override_configs={
'causal': Config(dict(causal=True)),
},
),
'AlphaDropout':
Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
'Attention':
generate_configs(
default_config=Config(
shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
override_configs={
'causal': Config(dict(causal=True)),
},
),
'Average':
Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
'AveragePooling1D':
generate_configs(
default_config=Config(shapes=[CONV_1D_INPUT]),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
},
),
'AveragePooling2D':
generate_configs(
default_config=Config(shapes=[CONV_2D_INPUT]),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: Default AvgPoolingOp only supports NHWC on device type CPU
# 'channels_first': Config(dict(data_format='channels_first')),
},
),
'AveragePooling3D':
generate_configs(
default_config=Config(shapes=[CONV_3D_INPUT]),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
},
),
'BatchNormalization':
generate_configs(
default_config=Config(shapes=[RANK_2_INPUT]),
override_configs={'renorm': Config(dict(renorm=True))},
),
'Concatenate':
Config(shapes=[RANK_4_INPUT, RANK_4_INPUT]),
'Conv1D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_1D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: The Conv2D op currently only supports the NHWC tensor
# format on the CPU.
# 'channels_first': Config(dict(data_format='channels_first')),
'dilation_rate': Config(dict(dilation_rate=3)),
},
),
'Conv1DTranspose':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_1D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: Conv2DCustomBackpropInputOp only supports NHWC
# 'channels_first': Config(dict(data_format='channels_first')),
# TF: Current libxsmm and customized CPU implementations do not
# yet support dilation rates larger than 1.
# 'dilation_rate': Config(dict(dilation_rate=3)),
},
),
'Conv2D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_2D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: The Conv2D op currently only supports the NHWC tensor
# format on the CPU.
# 'channels_first': Config(dict(data_format='channels_first')),
'dilation_rate': Config(dict(dilation_rate=3)),
},
),
'Conv2DTranspose':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_2D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
'dilation_rate': Config(dict(dilation_rate=3)),
},
),
'Conv3D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_3D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: The Conv3D op currently only supports the NHWC tensor
# format on the CPU.
# 'channels_first': Config(dict(data_format='channels_first')),
'dilation_rate': Config(dict(dilation_rate=3)),
},
),
'Conv3DTranspose':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_3D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.
# 'channels_first': Config(dict(data_format='channels_first')),
'dilation_rate': Config(dict(dilation_rate=3)),
},
),
'ConvLSTM2D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3, return_state=True),
shapes=[CONV_3D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
'dilation_rate': Config(dict(dilation_rate=3)),
'go_backwards': Config(dict(go_backwards=True)),
'stateful': Config(dict(stateful=True)),
},
),
'Cropping1D':
Config(dict(cropping=2), [CONV_1D_INPUT]),
'Cropping2D':
Config(dict(cropping=2), [CONV_2D_INPUT]),
'Cropping3D':
Config(dict(cropping=2), [CONV_3D_INPUT]),
'Dense':
Config(dict(units=4), [RANK_2_INPUT]),
'DepthwiseConv2D':
generate_configs(
default_config=Config(
kwargs=dict(kernel_size=3),
shapes=[CONV_2D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
'depth_multiplier': Config(dict(depth_multiplier=2)),
'dilation_rate': Config(dict(dilation_rate=2)),
},
),
'Dot':
Config(dict(axes=(1, 2)), [RANK_3_INPUT, RANK_3_INPUT]),
'Dropout':
Config(dict(rate=DROPOUT), [RANK_3_INPUT]),
'ELU':
Config(shapes=[RANK_2_INPUT]),
'Embedding':
Config(dict(input_dim=4, output_dim=2), [RANK_2_INPUT]),
'Flatten':
Config(shapes=[RANK_2_INPUT]),
'GRU':
generate_configs(
default_config=Config(
kwargs=dict(units=4, return_sequences=True),
shapes=[RANK_3_INPUT],
),
override_configs={
'implementation_1': Config(dict(implementation=1)),
'go_backwards': Config(dict(go_backwards=True)),
'time_major': Config(dict(time_major=True)),
'stateful': Config(dict(stateful=True)),
},
),
'GRUCell':
Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
'GaussianDropout':
Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
'GaussianNoise':
Config(dict(stddev=1.0), [RANK_2_INPUT]),
'GlobalAveragePooling1D':
generate_configs(
default_config=Config(shapes=[CONV_1D_INPUT]),
override_configs={
'channels_first': Config(dict(data_format='channels_first')),
},
),
'GlobalAveragePooling2D':
generate_configs(
default_config=Config(shapes=[CONV_2D_INPUT]),
override_configs={
'channels_first': Config(dict(data_format='channels_first')),
},
),
'GlobalAveragePooling3D':
generate_configs(
default_config=Config(shapes=[CONV_3D_INPUT]),
override_configs={
'channels_first': Config(dict(data_format='channels_first')),
},
),
'GlobalMaxPool1D':
generate_configs(
default_config=Config(shapes=[CONV_1D_INPUT]),
override_configs={
'channels_first': Config(dict(data_format='channels_first')),
},
),
'GlobalMaxPool2D':
generate_configs(
default_config=Config(shapes=[CONV_2D_INPUT]),
override_configs={
'channels_first': Config(dict(data_format='channels_first')),
},
),
'GlobalMaxPool3D':
generate_configs(
default_config=Config(shapes=[CONV_3D_INPUT]),
override_configs={
'channels_first': Config(dict(data_format='channels_first')),
},
),
'InputLayer':
Config(shapes=[RANK_2_INPUT]),
'LSTM':
generate_configs(
default_config=Config(
kwargs=dict(units=4, return_sequences=True),
shapes=[RANK_3_INPUT],
),
override_configs={
'implementation_1': Config(dict(implementation=1)),
'go_backwards': Config(dict(go_backwards=True)),
'time_major': Config(dict(time_major=True)),
'stateful': Config(dict(stateful=True)),
},
),
'LSTMCell':
Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
'Lambda':
Config(dict(function=lambda x: x**2), [RANK_2_INPUT]),
'LayerNormalization':
Config(shapes=[RANK_2_INPUT]),
'LeakyReLU':
Config(shapes=[RANK_2_INPUT]),
'LocallyConnected1D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_1D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same', implementation=2)),
'channels_first': Config(dict(data_format='channels_first')),
'sparse_implementation': Config(dict(implementation=3)),
},
),
'LocallyConnected2D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_2D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same', implementation=2)),
'channels_first': Config(dict(data_format='channels_first')),
'sparse_implementation': Config(dict(implementation=3)),
},
),
'Masking':
Config(shapes=[RANK_2_INPUT]),
'MaxPool1D':
generate_configs(
default_config=Config(shapes=[CONV_1D_INPUT]),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
},
),
'MaxPool2D':
generate_configs(
default_config=Config(shapes=[CONV_2D_INPUT]),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: Default MaxPoolingOp only supports NHWC on device type CPU
# 'channels_first': Config(dict(data_format='channels_first')),
},
),
'MaxPool3D':
generate_configs(
default_config=Config(shapes=[CONV_3D_INPUT]),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
'channels_first': Config(dict(data_format='channels_first')),
},
),
'Maximum':
Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
'Minimum':
Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
'MultiHeadAttention':
Config(dict(num_heads=2, key_dim=3), [RANK_3_INPUT, RANK_3_INPUT]),
'Multiply':
Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
'PReLU':
Config(shapes=[RANK_2_INPUT]),
'Permute':
Config(dict(dims=(3, 1, 2)), [RANK_4_INPUT]),
'ReLU':
Config(shapes=[RANK_2_INPUT]),
'RepeatVector':
Config(dict(n=3), [RANK_2_INPUT]),
'Reshape':
Config(dict(target_shape=[1, 1, 1] + RANK_3_INPUT[1:]), [RANK_3_INPUT]),
'SeparableConv1D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_1D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: Depthwise convolution on CPU is only supported for NHWC
# format
# 'channels_first': Config(dict(data_format='channels_first')),
'depth_multiplier': Config(dict(depth_multiplier=2)),
'dilation_rate': Config(dict(dilation_rate=2)),
},
),
'SeparableConv2D':
generate_configs(
default_config=Config(
kwargs=dict(filters=4, kernel_size=3),
shapes=[CONV_2D_INPUT],
),
override_configs={
'strides': Config(dict(strides=3)),
'padding_same': Config(dict(padding='same')),
# TF: Depthwise convolution on CPU is only supported for NHWC
# format
# 'channels_first': Config(dict(data_format='channels_first')),
'depth_multiplier': Config(dict(depth_multiplier=2)),
'dilation_rate': Config(dict(dilation_rate=2)),
},
),
'SimpleRNN':
generate_configs(
default_config=Config(
kwargs=dict(units=4, return_sequences=True),
shapes=[RANK_3_INPUT],
),
override_configs={
'go_backwards': Config(dict(go_backwards=True)),
'stateful': Config(dict(stateful=True)),
},
),
'SimpleRNNCell':
Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
'Softmax':
Config(shapes=[RANK_2_INPUT]),
'SpatialDropout1D':
Config(dict(rate=DROPOUT), [CONV_1D_INPUT]),
'SpatialDropout2D':
Config(dict(rate=DROPOUT), [CONV_2D_INPUT]),
'SpatialDropout3D':
Config(dict(rate=DROPOUT), [CONV_3D_INPUT]),
'Subtract':
Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
'ThresholdedReLU':
Config(shapes=[RANK_2_INPUT]),
'UpSampling1D':
Config(shapes=[CONV_1D_INPUT]),
'UpSampling2D':
Config(shapes=[CONV_2D_INPUT]),
'UpSampling3D':
Config(shapes=[CONV_3D_INPUT]),
'ZeroPadding1D':
Config(shapes=[CONV_1D_INPUT]),
'ZeroPadding2D':
Config(shapes=[CONV_2D_INPUT]),
'ZeroPadding3D':
Config(shapes=[CONV_3D_INPUT]),
}

# Normalize LAYER_TO_UNITTEST_CONFIGURATIONS
for key, value in LAYER_TO_UNITTEST_CONFIGURATIONS.items():
if isinstance(value, Config):
LAYER_TO_UNITTEST_CONFIGURATIONS[key] = {'default': value}
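
# e.g. after normalization, 'Dense' maps to:
#   {'default': Config(dict(units=4), [RANK_2_INPUT])}
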
# Layers that allow specifying the 'dropout' kwarg.
DROPOUT_LAYERS = [
'AdditiveAttention', 'Attention', 'ConvLSTM2D', 'GRU', 'GRUCell', 'LSTM',
'LSTMCell', 'MultiHeadAttention', 'SimpleRNN', 'SimpleRNNCell'
]

flags.DEFINE_string('layer', 'Dense',
f'One of {list(LAYER_TO_UNITTEST_CONFIGURATIONS.keys())}.')
flags.DEFINE_bool(
'dynamic_batch', False,
'Whether or not to compile the layer with a dynamic batch size.')
flags.DEFINE_bool('training', False,
'Whether or not to compile the layer in training mode.')
flags.DEFINE_bool(
'test_full_api', False,
'Whether or not to test multiple layer configurations using non-required '
'kwargs.')
flags.DEFINE_bool(
'list_layers_with_full_api_tests', False,
'Whether or not to print out all layers with non-default configurations '
    '(and skip running the tests).')


def get_configs() -> Dict[str, Config]:
"""Gets the configs that we want to test for FLAGS.layer."""
configs = LAYER_TO_UNITTEST_CONFIGURATIONS[FLAGS.layer]
if not FLAGS.test_full_api:
return {'default': configs['default']}
  return configs  # pytype: disable=bad-return-type


def get_input(shape: Sequence[int]) -> tf.keras.layers.Input:
"""Gets the input shape(s) that we want to test."""
batch_size = None if FLAGS.dynamic_batch else shape[0]
return tf.keras.layers.Input(batch_size=batch_size, shape=shape[1:])
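
# For example (illustrative): with --dynamic_batch, shape=[2, 8, 8, 3] yields
# tf.keras.layers.Input(batch_size=None, shape=[8, 8, 3]), so the compiled
# layer accepts any batch size.

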
def keras_input_normalizer(inputs: Sequence[Any]) -> Union[Any, Sequence[Any]]:
"""Unpacks inputs if it has length one."""
  return inputs[0] if len(inputs) == 1 else inputs


def keras_arg_wrapper(*args):
"""Wrapper to convert multiple positional args into a list of values."""
return list(args) if isinstance(args, tuple) else args
def create_wrapped_keras_layer(config: Config) -> tf.keras.Model:
"""Wraps a keras layer in a model for compilation."""
layer_class = getattr(tf.keras.layers, FLAGS.layer)
if FLAGS.training and FLAGS.layer in DROPOUT_LAYERS:
config.kwargs['dropout'] = DROPOUT
inputs = keras_input_normalizer([get_input(shape) for shape in config.shapes])
if FLAGS.layer == 'MultiHeadAttention':
# TODO(meadowlark): Remove specialization if API changes.
outputs = layer_class(**config.kwargs)(*inputs)
else:
outputs = layer_class(**config.kwargs)(inputs)
return tf.keras.Model(inputs, outputs)
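
# For example (illustrative): with --layer=Dense and Config(dict(units=4),
# [RANK_2_INPUT]), this builds the equivalent of:
#   inputs = tf.keras.layers.Input(batch_size=4, shape=[4])
#   outputs = tf.keras.layers.Dense(units=4)(inputs)
#   model = tf.keras.Model(inputs, outputs)

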
def create_tf_function_unittest(config: Config, exported_name: str,
model: tf.keras.Model) -> tf.function:
"""Wrap the model's __call__ function in a tf.function for testing."""
input_shapes = config.shapes
if FLAGS.dynamic_batch:
input_shapes = [[None] + shape[1:] for shape in input_shapes]
input_signature = [tf.TensorSpec(shape) for shape in input_shapes]
if len(input_signature) > 1:
input_signature = [input_signature]
call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
return tf_test_utils.tf_function_unittest(input_signature=input_signature,
name=exported_name)(call)
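
# For example (illustrative): with --dynamic_batch and shapes=[RANK_2_INPUT],
# the input_signature is [tf.TensorSpec([None, 4])], and the returned
# tf.function calls the model with training=FLAGS.training.

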
class KerasLayersModule(tf_test_utils.TestModule):

@classmethod
def configure_class(cls):
"""Configure each tf_function_unittest and define it on the cls."""
for i, (exported_name, config) in enumerate(get_configs().items()):
model = create_wrapped_keras_layer(config)
setattr(cls, exported_name,
create_tf_function_unittest(config, exported_name, model))
def __init__(self):
super().__init__()
self.models = []
    for exported_name, config in get_configs().items():
model = create_wrapped_keras_layer(config)
self.models.append(model)
setattr(self, exported_name,
              create_tf_function_unittest(config, exported_name, model))


class KerasLayersTest(tf_test_utils.TracedModuleTestCase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
KerasLayersModule,
        exported_names=KerasLayersModule.get_exported_names())


def main(argv):
del argv # Unused.
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
if FLAGS.layer not in LAYER_TO_UNITTEST_CONFIGURATIONS:
raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'.")
if FLAGS.list_layers_with_full_api_tests:
for layer, configs in sorted(LAYER_TO_UNITTEST_CONFIGURATIONS.items()):
if len(configs) > 1:
print(f' "{layer}",')
return
# Set up name for saving artifacts.
dynamic_batch_str = 'dynamic_batch' if FLAGS.dynamic_batch else 'static_batch'
training_str = 'training' if FLAGS.training else 'non_training'
full_api_str = 'full_api' if FLAGS.test_full_api else 'default_api'
settings_str = f'{full_api_str}_{dynamic_batch_str}_{training_str}'
KerasLayersModule.__name__ = os.path.join('keras_layers', FLAGS.layer,
settings_str)
# Use the configurations for FLAGS.layer to add the tf.functions we wish
# to test to the KerasLayersModule, and then generate unittests for each of
# them.
KerasLayersModule.configure_class()
KerasLayersTest.generate_unittests(KerasLayersModule)
  tf.test.main()


if __name__ == '__main__':
app.run(main)