blob: bdebbb7fe7bd06d251909dbbc4d3463f93a1ea55 [file] [log] [blame]
# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of tf.keras.layer.Layer subclasses."""
import os
from typing import Any, Sequence, Union
from absl import app
from absl import flags
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
DROPOUT = 0.5
DIM = 4
RANK_2_INPUT = [DIM] * 2
RANK_3_INPUT = [DIM] * 3
RANK_4_INPUT = [DIM] * 4
CONV_1D_INPUT = [2, 8, 3]
CONV_2D_INPUT = [2, 8, 8, 3]
CONV_3D_INPUT = [2, 8, 8, 8, 3]
LAYER_TO_INPUT_SHAPES = {
'Activation': [RANK_2_INPUT],
'ActivityRegularization': [RANK_2_INPUT],
'Add': [RANK_2_INPUT, RANK_2_INPUT],
'AdditiveAttention': [RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],
'AlphaDropout': [RANK_2_INPUT],
'Attention': [RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],
'Average': [RANK_2_INPUT, RANK_2_INPUT],
'AveragePooling1D': [CONV_1D_INPUT],
'AveragePooling2D': [CONV_2D_INPUT],
'AveragePooling3D': [CONV_3D_INPUT],
'BatchNormalization': [RANK_2_INPUT],
'Concatenate': [RANK_4_INPUT, RANK_4_INPUT],
'Conv1D': [CONV_1D_INPUT],
'Conv1DTranspose': [CONV_1D_INPUT],
'Conv2D': [CONV_2D_INPUT],
'Conv2DTranspose': [CONV_2D_INPUT],
'Conv3D': [CONV_3D_INPUT],
'Conv3DTranspose': [CONV_3D_INPUT],
'ConvLSTM2D': [CONV_3D_INPUT],
'Cropping1D': [CONV_1D_INPUT],
'Cropping2D': [CONV_2D_INPUT],
'Cropping3D': [CONV_3D_INPUT],
'Dense': [RANK_2_INPUT],
'DepthwiseConv2D': [CONV_2D_INPUT],
'Dot': [RANK_3_INPUT, RANK_3_INPUT],
'Dropout': [RANK_3_INPUT],
'ELU': [RANK_2_INPUT],
'Embedding': [RANK_2_INPUT],
'Flatten': [RANK_2_INPUT],
'GRU': [RANK_3_INPUT],
'GRUCell': [RANK_2_INPUT, RANK_2_INPUT],
'GaussianDropout': [RANK_2_INPUT],
'GaussianNoise': [RANK_2_INPUT],
'GlobalAveragePooling1D': [CONV_1D_INPUT],
'GlobalAveragePooling2D': [CONV_2D_INPUT],
'GlobalAveragePooling3D': [CONV_3D_INPUT],
'GlobalMaxPool1D': [CONV_1D_INPUT],
'GlobalMaxPool2D': [CONV_2D_INPUT],
'GlobalMaxPool3D': [CONV_3D_INPUT],
'InputLayer': [RANK_2_INPUT],
'LSTM': [RANK_3_INPUT],
'LSTMCell': [RANK_2_INPUT, RANK_2_INPUT],
'Lambda': [RANK_2_INPUT],
'LayerNormalization': [RANK_2_INPUT],
'LeakyReLU': [RANK_2_INPUT],
'LocallyConnected1D': [CONV_1D_INPUT],
'LocallyConnected2D': [CONV_2D_INPUT],
'Masking': [RANK_2_INPUT],
'MaxPool1D': [CONV_1D_INPUT],
'MaxPool2D': [CONV_2D_INPUT],
'MaxPool3D': [CONV_3D_INPUT],
'Maximum': [RANK_2_INPUT, RANK_2_INPUT],
'Minimum': [RANK_2_INPUT, RANK_2_INPUT],
'MultiHeadAttention': [RANK_3_INPUT, RANK_3_INPUT],
'Multiply': [RANK_2_INPUT, RANK_2_INPUT],
'PReLU': [RANK_2_INPUT],
'Permute': [RANK_4_INPUT],
'ReLU': [RANK_2_INPUT],
'RepeatVector': [RANK_2_INPUT],
'Reshape': [RANK_3_INPUT],
'SeparableConv1D': [CONV_1D_INPUT],
'SeparableConv2D': [CONV_2D_INPUT],
'SimpleRNN': [RANK_3_INPUT],
'SimpleRNNCell': [RANK_2_INPUT, RANK_2_INPUT],
'Softmax': [RANK_2_INPUT],
'SpatialDropout1D': [CONV_1D_INPUT],
'SpatialDropout2D': [CONV_2D_INPUT],
'SpatialDropout3D': [CONV_3D_INPUT],
'Subtract': [RANK_2_INPUT, RANK_2_INPUT],
'ThresholdedReLU': [RANK_2_INPUT],
'UpSampling1D': [CONV_1D_INPUT],
'UpSampling2D': [CONV_2D_INPUT],
'UpSampling3D': [CONV_3D_INPUT],
'ZeroPadding1D': [CONV_1D_INPUT],
'ZeroPadding2D': [CONV_2D_INPUT],
'ZeroPadding3D': [CONV_3D_INPUT],
}
LAYER_TO_KWARGS = {
'Activation': dict(activation='relu'),
'ActivityRegularization': dict(),
'Add': dict(),
'AdditiveAttention': dict(),
'AlphaDropout': dict(rate=DROPOUT),
'Attention': dict(),
'Average': dict(),
'AveragePooling1D': dict(),
'AveragePooling2D': dict(),
'AveragePooling3D': dict(),
'BatchNormalization': dict(),
'Concatenate': dict(),
'Conv1D': dict(filters=4, kernel_size=3),
'Conv1DTranspose': dict(filters=4, kernel_size=3),
'Conv2D': dict(filters=4, kernel_size=3),
'Conv2DTranspose': dict(filters=4, kernel_size=3),
'Conv3D': dict(filters=4, kernel_size=3),
'Conv3DTranspose': dict(filters=4, kernel_size=3),
'ConvLSTM2D': dict(filters=4, kernel_size=3),
'Cropping1D': dict(cropping=2),
'Cropping2D': dict(cropping=2),
'Cropping3D': dict(cropping=2),
'Dense': dict(units=4),
'DepthwiseConv2D': dict(kernel_size=3),
'Dot': dict(axes=(1, 2)),
'Dropout': dict(rate=DROPOUT),
'ELU': dict(),
'Embedding': dict(input_dim=4, output_dim=2),
'Flatten': dict(),
'GRU': dict(units=4, return_sequences=True),
'GRUCell': dict(units=4),
'GaussianDropout': dict(rate=DROPOUT),
'GaussianNoise': dict(stddev=1.0),
'GlobalAveragePooling1D': dict(),
'GlobalAveragePooling2D': dict(),
'GlobalAveragePooling3D': dict(),
'GlobalMaxPool1D': dict(),
'GlobalMaxPool2D': dict(),
'GlobalMaxPool3D': dict(),
'InputLayer': dict(),
'LSTM': dict(units=4, return_sequences=True),
'LSTMCell': dict(units=4),
'Lambda': dict(function=lambda x: x**2),
'Layer': dict(),
'LayerNormalization': dict(),
'LeakyReLU': dict(),
'LocallyConnected1D': dict(filters=4, kernel_size=3),
'LocallyConnected2D': dict(filters=4, kernel_size=3),
'Masking': dict(),
'MaxPool1D': dict(),
'MaxPool2D': dict(),
'MaxPool3D': dict(),
'Maximum': dict(),
'Minimum': dict(),
'MultiHeadAttention': dict(num_heads=2, key_dim=3),
'Multiply': dict(),
'PReLU': dict(),
'Permute': dict(dims=(3, 1, 2)),
'ReLU': dict(),
'RepeatVector': dict(n=3),
'Reshape': dict(target_shape=[1, 1, 1] + RANK_3_INPUT[1:]),
'SeparableConv1D': dict(filters=4, kernel_size=3),
'SeparableConv2D': dict(filters=4, kernel_size=3),
'SimpleRNN': dict(units=4, return_sequences=True),
'SimpleRNNCell': dict(units=4),
'Softmax': dict(),
'SpatialDropout1D': dict(rate=DROPOUT),
'SpatialDropout2D': dict(rate=DROPOUT),
'SpatialDropout3D': dict(rate=DROPOUT),
'Subtract': dict(),
'ThresholdedReLU': dict(),
'UpSampling1D': dict(),
'UpSampling2D': dict(),
'UpSampling3D': dict(),
'ZeroPadding1D': dict(),
'ZeroPadding2D': dict(),
'ZeroPadding3D': dict(),
}
# Layers that allow specifying the 'dropout' kwarg.
DROPOUT_LAYERS = [
'AdditiveAttention', 'Attention', 'ConvLSTM2D', 'GRU', 'GRUCell', 'LSTM',
'LSTMCell', 'MultiHeadAttention', 'SimpleRNN', 'SimpleRNNCell'
]
flags.DEFINE_string('layer', 'Dense',
f'One of {list(LAYER_TO_INPUT_SHAPES.keys())}.')
flags.DEFINE_bool(
'dynamic_batch', False,
'Whether or not to compile the layer with a dynamic batch size.')
flags.DEFINE_bool('training', False,
'Whether or not to compile the layer in training mode.')
def get_input(shape: Sequence[int]) -> tf.keras.layers.Input:
batch_size = None if FLAGS.dynamic_batch else shape[0]
return tf.keras.layers.Input(batch_size=batch_size, shape=shape[1:])
def normalize(inputs: Sequence[Any]) -> Union[Any, Sequence[Any]]:
"""Unpacks inputs if it has length one."""
return inputs[0] if len(inputs) == 1 else inputs
class KerasLayersModule(tf.Module):
def __init__(self):
super().__init__()
layer_class = getattr(tf.keras.layers, FLAGS.layer)
input_shapes = LAYER_TO_INPUT_SHAPES[FLAGS.layer]
kwargs = LAYER_TO_KWARGS[FLAGS.layer]
if FLAGS.training and FLAGS.layer in DROPOUT_LAYERS:
kwargs['dropout'] = DROPOUT
# Create a wrapped keras layer.
inputs = normalize([get_input(shape) for shape in input_shapes])
if FLAGS.layer == 'MultiHeadAttention':
# TODO(meadowlark): Fix this if keras updates their API to be consistent.
outputs = layer_class(**kwargs)(*inputs)
else:
outputs = layer_class(**kwargs)(inputs)
self.m = tf.keras.Model(inputs, outputs)
# Wrap the layer in a tf.function.
if FLAGS.dynamic_batch:
input_shapes = [[None] + shape[1:] for shape in input_shapes]
input_signature = [tf.TensorSpec(shape) for shape in input_shapes]
if len(input_signature) > 1:
input_signature = [input_signature]
self.call = tf.function(input_signature=input_signature)(
lambda x: self.m(x, training=FLAGS.training))
class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(KerasLayersModule,
exported_names=['call'])
def test_call(self):
def call(module):
input_shapes = LAYER_TO_INPUT_SHAPES[FLAGS.layer]
inputs = normalize([tf_utils.uniform(shape) for shape in input_shapes])
module.call(inputs)
self.compare_backends(call, self._modules)
def main(argv):
del argv # Unused.
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
if FLAGS.layer not in LAYER_TO_INPUT_SHAPES:
raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'.")
dynamic_batch_str = 'dynamic_batch' if FLAGS.dynamic_batch else 'static_batch'
training_str = 'training' if FLAGS.training else 'non_training'
settings_str = f'{dynamic_batch_str}_{training_str}'
KerasLayersModule.__name__ = os.path.join('keras_layers', FLAGS.layer,
settings_str)
tf.test.main()
if __name__ == '__main__':
app.run(main)