Open Source KWS Streaming Tests (#3478)
Open sources tests of [kws_steaming](https://github.com/google-research/google-research/tree/master/kws_streaming) models by adding a Bazel `http_archive` and `BUILD.overlay`.
The longest of these tests complete in ~60s, so they shouldn't add a significant amount of time to the CI.
diff --git a/WORKSPACE b/WORKSPACE
index a6db513..25c8a0e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -299,5 +299,15 @@
path = "third_party/cpuinfo",
)
+GOOGLE_RESEARCH_COMMIT = "1dbf7f4af77ac032ddcf68a7978cc056897015a7"
+
+http_archive(
+ name = "kws_streaming",
+ build_file = "@//:build_tools/third_party/kws_streaming/BUILD.overlay",
+ sha256 = "cdb0b71914999a9cb11b5a80eb16769687c9714d9ac706e6c1cf081c3afbd976",
+ strip_prefix = "google-research-{}/kws_streaming".format(GOOGLE_RESEARCH_COMMIT),
+ url = "https://github.com/google-research/google-research/archive/{}.tar.gz".format(GOOGLE_RESEARCH_COMMIT),
+)
+
# Bootstrap TensorFlow deps last so that ours can take precedence.
tf_repositories()
diff --git a/build_tools/third_party/kws_streaming/BUILD.overlay b/build_tools/third_party/kws_streaming/BUILD.overlay
new file mode 100644
index 0000000..3fb647c
--- /dev/null
+++ b/build_tools/third_party/kws_streaming/BUILD.overlay
@@ -0,0 +1,119 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "train_lib",
+ srcs = [
+ "train/base_parser.py",
+ "train/model_flags.py",
+ "train/test.py",
+ "train/train.py",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ ":input_data_lib",
+ ":models_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = [
+ "models/att_mh_rnn.py",
+ "models/att_rnn.py",
+ "models/cnn.py",
+ "models/crnn.py",
+ "models/dnn.py",
+ "models/dnn_raw.py",
+ "models/ds_cnn.py",
+ "models/ds_tc_resnet.py",
+ "models/gru.py",
+ "models/inception.py",
+ "models/inception_resnet.py",
+ "models/lstm.py",
+ "models/mobilenet.py",
+ "models/mobilenet_v2.py",
+ "models/model_params.py",
+ "models/models.py",
+ "models/svdf.py",
+ "models/svdf_resnet.py",
+ "models/tc_resnet.py",
+ "models/utils.py",
+ "models/xception.py",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ ":layers_compat",
+ ":layers_lib",
+ ],
+)
+
+py_library(
+ name = "input_data_lib",
+ srcs = ["data/input_data.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":layers_lib",
+ ],
+)
+
+py_library(
+ name = "layers_lib",
+ srcs = [
+ "layers/contrib_conv2d.py",
+ "layers/conv1d_transpose.py",
+ "layers/data_frame.py",
+ "layers/dct.py",
+ "layers/delay.py",
+ "layers/depthwise_conv1d.py",
+ "layers/gru.py",
+ "layers/lstm.py",
+ "layers/magnitude_rdft.py",
+ "layers/magnitude_rdft_mel.py",
+ "layers/mel_spectrogram.py",
+ "layers/mel_table.py",
+ "layers/modes.py",
+ "layers/non_scaling_dropout.py",
+ "layers/normalizer.py",
+ "layers/preemphasis.py",
+ "layers/random_shift.py",
+ "layers/random_stretch_squeeze.py",
+ "layers/spectrogram_augment.py",
+ "layers/spectrogram_cutout.py",
+ "layers/speech_features.py",
+ "layers/stream.py",
+ "layers/svdf.py",
+ "layers/temporal_padding.py",
+ "layers/windowing.py",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ ":layers_compat",
+ "@absl_py//absl/logging",
+ ],
+)
+
+py_library(
+ name = "layers_compat",
+ srcs = ["layers/compat.py"],
+ srcs_version = "PY3",
+ deps = [
+ "@absl_py//absl/flags",
+ "@absl_py//absl/logging",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index f93e9d1..9875dde 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -80,6 +80,7 @@
]
SPECIAL_CASES = [
+ "keyword_spotting_streaming_test.py",
"vision_model_test.py",
]
@@ -372,3 +373,121 @@
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
)
+
+# Keyword Spotting Tests:
+KEYWORD_SPOTTING_MODELS = [
+ "svdf",
+ "svdf_resnet",
+ "ds_cnn",
+ "gru",
+ "lstm",
+ "cnn_stride",
+ "cnn",
+ "tc_resnet",
+ "crnn",
+ "dnn",
+ "att_rnn",
+ "att_mh_rnn",
+ "mobilenet",
+ "mobilenet_v2",
+ "xception",
+ "inception",
+ "inception_resnet",
+ "ds_tc_resnet",
+]
+
+iree_e2e_cartesian_product_test_suite(
+ name = "keyword_spotting_tests",
+ srcs = ["keyword_spotting_streaming_test.py"],
+ failing_configurations = [
+ {
+ # Failing on IREE:
+ "model": [
+ "att_mh_rnn", # b/147824465
+ "att_rnn", # b/147824465
+ "crnn", # b/162067867
+ "ds_tc_resnet",
+ "gru", # b/162067867
+ "lstm", # b/162067867
+ "svdf_resnet", # b/171512071
+ "xception", # b/171512071
+ ],
+ "target_backends": [
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "mode": "non_streaming",
+ "model": KEYWORD_SPOTTING_MODELS,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "keyword_spotting_streaming_test.py",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ "@kws_streaming//:models_lib",
+ "@kws_streaming//:train_lib",
+ ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+ name = "keyword_spotting_internal_streaming_tests",
+ srcs = ["keyword_spotting_streaming_test.py"],
+ failing_configurations = [
+ {
+ # TFLite cannot compile variables.
+ "target_backends": "tflite",
+ },
+ {
+ # These models do not currently support streaming.
+ "model": [
+ "att_mh_rnn",
+ "att_rnn",
+ "ds_cnn",
+ "inception",
+ "inception_resnet",
+ "mobilenet",
+ "mobilenet_v2",
+ "svdf_resnet",
+ "tc_resnet",
+ "xception",
+ ],
+ },
+ {
+ # Failing on IREE:
+ "model": "ds_tc_resnet",
+ "target_backends": [
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "mode": "internal_streaming",
+ "model": KEYWORD_SPOTTING_MODELS,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "keyword_spotting_streaming_test.py",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ "@kws_streaming//:models_lib",
+ "@kws_streaming//:train_lib",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
new file mode 100644
index 0000000..ef9a6f0
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
@@ -0,0 +1,117 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests of streamable Keyword Spotting models implemented in Keras."""
+
+import os
+import sys
+import pathlib
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+from kws_streaming.layers import modes
+from kws_streaming.models import model_params
+from kws_streaming.models import models
+from kws_streaming.models import utils
+from kws_streaming.train import model_flags
+
+FLAGS = flags.FLAGS
+
+ALL_MODELS = list(model_params.HOTWORD_MODEL_PARAMS.keys())
+MODELS_HELP = [f"'{name}'" for name in ALL_MODELS]
+MODELS_HELP = f'{", ".join(MODELS_HELP[:-1])}, or {MODELS_HELP[-1]}'
+
+flags.DEFINE_string(
+ 'model', 'svdf', f'Name of the model to compile. Either {MODELS_HELP}.\n'
+ 'See https://github.com/google-research/google-research/blob/master/kws_streaming/models/models.py#L38-L58'
+)
+flags.DEFINE_enum('mode', 'non_streaming',
+ ['non_streaming', 'internal_streaming'],
+ 'Mode to execute the model in.')
+
+MODE_ENUM_TO_MODE = {
+ 'non_streaming': modes.Modes.NON_STREAM_INFERENCE,
+ 'internal_streaming': modes.Modes.STREAM_INTERNAL_STATE_INFERENCE,
+}
+MODE_TO_INPUT_SHAPE = {
+ 'non_streaming': (1, 16000),
+ 'internal_streaming': (1, 320),
+}
+
+
+def get_input_shape():
+ return MODE_TO_INPUT_SHAPE[FLAGS.mode]
+
+
+def initialize_model():
+ params = model_params.HOTWORD_MODEL_PARAMS[FLAGS.model]
+ params = model_flags.update_flags(params)
+ model = models.MODELS[params.model_name](params)
+
+ if FLAGS.mode == 'internal_streaming':
+ mode = MODE_ENUM_TO_MODE[FLAGS.mode]
+ input_shape = get_input_shape()
+ params.batch_size = input_shape[0]
+ params.desired_samples = input_shape[1]
+ model = utils.to_streaming_inference(model, flags=params, mode=mode)
+
+ return model
+
+
+class KeywordSpottingModule(tf.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.m = initialize_model()
+ self.m.predict = lambda x: self.m.call(x, training=False)
+ self.predict = tf.function(
+ input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
+
+
+class KeywordSpottingTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(KeywordSpottingModule,
+ exported_names=['predict'])
+
+ def test_predict(self):
+
+ def predict(module):
+ module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5)
+
+ self.compare_backends(predict, self._modules)
+
+
+def main(argv):
+ del argv # Unused.
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+
+ if FLAGS.model not in ALL_MODELS:
+ raise ValueError(f'Unsupported model: {FLAGS.model}.\n'
+ f'Expected one of {MODELS_HELP}.')
+ KeywordSpottingModule.__name__ = f'kws_{FLAGS.model}'
+
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 35732c6..11d8d52 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -37,6 +37,10 @@
('iree_vulkan', 'vulkan-spirv'),
])
+KWS_LINK = (
+ 'https://github.com/google-research/google-research/tree/master/kws_streaming'
+)
+KWS_LINK = f'[Keyword Spotting Streaming]({KWS_LINK})'
TEST_SUITES_TO_HEADERS = {
'//integrations/tensorflow/e2e:e2e_tests':
'End to end TensorFlow tests',
@@ -44,6 +48,10 @@
'End to end test of MobileBert on SQuAD',
'//integrations/tensorflow/e2e/keras:keras_tests':
'End to end tests written using tf.keras',
+ '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
+ f'End to end tests of {KWS_LINK} models',
+ '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
+ f'End to end tests of {KWS_LINK} models in internal streaming mode',
'//integrations/tensorflow/e2e/keras:imagenet_external_tests':
'End to end tests of tf.keras.applications vision models on Imagenet',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -53,6 +61,10 @@
# Key to use as the name of the rows in the left column for each test in the
# suite.
TEST_SUITE_TO_ROW_ID_KEY = {
+ '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
+ 'model',
+ '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
+ 'model',
'//integrations/tensorflow/e2e/keras:imagenet_external_tests':
'model',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -62,6 +74,10 @@
# Some test suites are generated from a single source. This allows us to point
# to the right test file when generating test URLs.
SINGLE_SOURCE_SUITES = {
+ '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
+ 'keyword_spotting_streaming_test',
+ '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
+ 'keyword_spotting_streaming_test',
'//integrations/tensorflow/e2e/keras:imagenet_external_tests':
'vision_model_test',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -71,7 +87,6 @@
TARGET_EXCLUSION_FILTERS = [
r'mobilenet_v1_.*', # Slim vision MobileNetV1.
r'mobilenet_v2_.*', # Slim vision MobileNetV2.
- r'amoebanet_a_n18_f448', # SavedModelV2 not available.
]
# The symbols to show in the table if the operation is supported or not.
@@ -194,6 +209,11 @@
# Generate the coverage table as a 2D array.
rows = [first_row, second_row]
for row_id, backends in sorted(table.items()):
+ # If the reference backend is failing then there is no reason to show the
+ # coverage of the other backends.
+ if not backends[ordered_backends.index(REFERENCE_BACKEND)]:
+ continue
+
# Skip any rows defined in the TARGET_EXCLUSION_FILTERS.
if any(re.match(pattern, row_id) for pattern in TARGET_EXCLUSION_FILTERS):
continue