Open Source KWS Streaming Tests (#3478)

Open-sources tests of [kws_streaming](https://github.com/google-research/google-research/tree/master/kws_streaming) models by adding a Bazel `http_archive` and `BUILD.overlay`.

The longest of these tests completes in ~60s, so they shouldn't add a significant amount of time to the CI.
diff --git a/WORKSPACE b/WORKSPACE
index a6db513..25c8a0e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -299,5 +299,15 @@
     path = "third_party/cpuinfo",
 )
 
+GOOGLE_RESEARCH_COMMIT = "1dbf7f4af77ac032ddcf68a7978cc056897015a7"
+
+http_archive(
+    name = "kws_streaming",
+    build_file = "@//:build_tools/third_party/kws_streaming/BUILD.overlay",
+    sha256 = "cdb0b71914999a9cb11b5a80eb16769687c9714d9ac706e6c1cf081c3afbd976",
+    strip_prefix = "google-research-{}/kws_streaming".format(GOOGLE_RESEARCH_COMMIT),
+    url = "https://github.com/google-research/google-research/archive/{}.tar.gz".format(GOOGLE_RESEARCH_COMMIT),
+)
+
 # Bootstrap TensorFlow deps last so that ours can take precedence.
 tf_repositories()
diff --git a/build_tools/third_party/kws_streaming/BUILD.overlay b/build_tools/third_party/kws_streaming/BUILD.overlay
new file mode 100644
index 0000000..3fb647c
--- /dev/null
+++ b/build_tools/third_party/kws_streaming/BUILD.overlay
@@ -0,0 +1,119 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+py_library(
+    name = "train_lib",
+    srcs = [
+        "train/base_parser.py",
+        "train/model_flags.py",
+        "train/test.py",
+        "train/train.py",
+    ],
+    srcs_version = "PY3",
+    deps = [
+        ":input_data_lib",
+        ":models_lib",
+    ],
+)
+
+py_library(
+    name = "models_lib",
+    srcs = [
+        "models/att_mh_rnn.py",
+        "models/att_rnn.py",
+        "models/cnn.py",
+        "models/crnn.py",
+        "models/dnn.py",
+        "models/dnn_raw.py",
+        "models/ds_cnn.py",
+        "models/ds_tc_resnet.py",
+        "models/gru.py",
+        "models/inception.py",
+        "models/inception_resnet.py",
+        "models/lstm.py",
+        "models/mobilenet.py",
+        "models/mobilenet_v2.py",
+        "models/model_params.py",
+        "models/models.py",
+        "models/svdf.py",
+        "models/svdf_resnet.py",
+        "models/tc_resnet.py",
+        "models/utils.py",
+        "models/xception.py",
+    ],
+    srcs_version = "PY3",
+    deps = [
+        ":layers_compat",
+        ":layers_lib",
+    ],
+)
+
+py_library(
+    name = "input_data_lib",
+    srcs = ["data/input_data.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":layers_lib",
+    ],
+)
+
+py_library(
+    name = "layers_lib",
+    srcs = [
+        "layers/contrib_conv2d.py",
+        "layers/conv1d_transpose.py",
+        "layers/data_frame.py",
+        "layers/dct.py",
+        "layers/delay.py",
+        "layers/depthwise_conv1d.py",
+        "layers/gru.py",
+        "layers/lstm.py",
+        "layers/magnitude_rdft.py",
+        "layers/magnitude_rdft_mel.py",
+        "layers/mel_spectrogram.py",
+        "layers/mel_table.py",
+        "layers/modes.py",
+        "layers/non_scaling_dropout.py",
+        "layers/normalizer.py",
+        "layers/preemphasis.py",
+        "layers/random_shift.py",
+        "layers/random_stretch_squeeze.py",
+        "layers/spectrogram_augment.py",
+        "layers/spectrogram_cutout.py",
+        "layers/speech_features.py",
+        "layers/stream.py",
+        "layers/svdf.py",
+        "layers/temporal_padding.py",
+        "layers/windowing.py",
+    ],
+    srcs_version = "PY3",
+    deps = [
+        ":layers_compat",
+        "@absl_py//absl/logging",
+    ],
+)
+
+py_library(
+    name = "layers_compat",
+    srcs = ["layers/compat.py"],
+    srcs_version = "PY3",
+    deps = [
+        "@absl_py//absl/flags",
+        "@absl_py//absl/logging",
+    ],
+)
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index f93e9d1..9875dde 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -80,6 +80,7 @@
 ]
 
 SPECIAL_CASES = [
+    "keyword_spotting_streaming_test.py",
     "vision_model_test.py",
 ]
 
@@ -372,3 +373,121 @@
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
+
+# Keyword Spotting Tests:
+KEYWORD_SPOTTING_MODELS = [
+    "svdf",
+    "svdf_resnet",
+    "ds_cnn",
+    "gru",
+    "lstm",
+    "cnn_stride",
+    "cnn",
+    "tc_resnet",
+    "crnn",
+    "dnn",
+    "att_rnn",
+    "att_mh_rnn",
+    "mobilenet",
+    "mobilenet_v2",
+    "xception",
+    "inception",
+    "inception_resnet",
+    "ds_tc_resnet",
+]
+
+iree_e2e_cartesian_product_test_suite(
+    name = "keyword_spotting_tests",
+    srcs = ["keyword_spotting_streaming_test.py"],
+    failing_configurations = [
+        {
+            # Failing on IREE:
+            "model": [
+                "att_mh_rnn",  # b/147824465
+                "att_rnn",  # b/147824465
+                "crnn",  # b/162067867
+                "ds_tc_resnet",
+                "gru",  # b/162067867
+                "lstm",  # b/162067867
+                "svdf_resnet",  # b/171512071
+                "xception",  # b/171512071
+            ],
+            "target_backends": [
+                "iree_vmla",
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "mode": "non_streaming",
+        "model": KEYWORD_SPOTTING_MODELS,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "keyword_spotting_streaming_test.py",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        "@kws_streaming//:models_lib",
+        "@kws_streaming//:train_lib",
+    ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+    name = "keyword_spotting_internal_streaming_tests",
+    srcs = ["keyword_spotting_streaming_test.py"],
+    failing_configurations = [
+        {
+            # TFLite cannot compile variables.
+            "target_backends": "tflite",
+        },
+        {
+            # These models do not currently support streaming.
+            "model": [
+                "att_mh_rnn",
+                "att_rnn",
+                "ds_cnn",
+                "inception",
+                "inception_resnet",
+                "mobilenet",
+                "mobilenet_v2",
+                "svdf_resnet",
+                "tc_resnet",
+                "xception",
+            ],
+        },
+        {
+            # Failing on IREE:
+            "model": "ds_tc_resnet",
+            "target_backends": [
+                "iree_vmla",
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "mode": "internal_streaming",
+        "model": KEYWORD_SPOTTING_MODELS,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "keyword_spotting_streaming_test.py",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        "@kws_streaming//:models_lib",
+        "@kws_streaming//:train_lib",
+    ],
+)
diff --git a/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
new file mode 100644
index 0000000..ef9a6f0
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
@@ -0,0 +1,117 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests of streamable Keyword Spotting models implemented in Keras."""
+
+import os
+import sys
+import pathlib
+
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+from kws_streaming.layers import modes
+from kws_streaming.models import model_params
+from kws_streaming.models import models
+from kws_streaming.models import utils
+from kws_streaming.train import model_flags
+
+FLAGS = flags.FLAGS
+
+ALL_MODELS = list(model_params.HOTWORD_MODEL_PARAMS.keys())
+MODELS_HELP = [f"'{name}'" for name in ALL_MODELS]
+MODELS_HELP = f'{", ".join(MODELS_HELP[:-1])}, or {MODELS_HELP[-1]}'
+
+flags.DEFINE_string(
+    'model', 'svdf', f'Name of the model to compile. Either {MODELS_HELP}.\n'
+    'See https://github.com/google-research/google-research/blob/master/kws_streaming/models/models.py#L38-L58'
+)
+flags.DEFINE_enum('mode', 'non_streaming',
+                  ['non_streaming', 'internal_streaming'],
+                  'Mode to execute the model in.')
+
+MODE_ENUM_TO_MODE = {
+    'non_streaming': modes.Modes.NON_STREAM_INFERENCE,
+    'internal_streaming': modes.Modes.STREAM_INTERNAL_STATE_INFERENCE,
+}
+MODE_TO_INPUT_SHAPE = {
+    'non_streaming': (1, 16000),
+    'internal_streaming': (1, 320),
+}
+
+
+def get_input_shape():
+  return MODE_TO_INPUT_SHAPE[FLAGS.mode]
+
+
+def initialize_model():
+  params = model_params.HOTWORD_MODEL_PARAMS[FLAGS.model]
+  params = model_flags.update_flags(params)
+  model = models.MODELS[params.model_name](params)
+
+  if FLAGS.mode == 'internal_streaming':
+    mode = MODE_ENUM_TO_MODE[FLAGS.mode]
+    input_shape = get_input_shape()
+    params.batch_size = input_shape[0]
+    params.desired_samples = input_shape[1]
+    model = utils.to_streaming_inference(model, flags=params, mode=mode)
+
+  return model
+
+
+class KeywordSpottingModule(tf.Module):
+
+  def __init__(self):
+    super().__init__()
+    self.m = initialize_model()
+    self.m.predict = lambda x: self.m.call(x, training=False)
+    self.predict = tf.function(
+        input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
+
+
+class KeywordSpottingTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(KeywordSpottingModule,
+                                                    exported_names=['predict'])
+
+  def test_predict(self):
+
+    def predict(module):
+      module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5)
+
+    self.compare_backends(predict, self._modules)
+
+
+def main(argv):
+  del argv  # Unused.
+  if hasattr(tf, 'enable_v2_behavior'):
+    tf.enable_v2_behavior()
+
+  if FLAGS.model not in ALL_MODELS:
+    raise ValueError(f'Unsupported model: {FLAGS.model}.\n'
+                     f'Expected one of {MODELS_HELP}.')
+  KeywordSpottingModule.__name__ = f'kws_{FLAGS.model}'
+
+  tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 35732c6..11d8d52 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -37,6 +37,10 @@
     ('iree_vulkan', 'vulkan-spirv'),
 ])
 
+KWS_LINK = (
+    'https://github.com/google-research/google-research/tree/master/kws_streaming'
+)
+KWS_LINK = f'[Keyword Spotting Streaming]({KWS_LINK})'
 TEST_SUITES_TO_HEADERS = {
     '//integrations/tensorflow/e2e:e2e_tests':
         'End to end TensorFlow tests',
@@ -44,6 +48,10 @@
         'End to end test of MobileBert on SQuAD',
     '//integrations/tensorflow/e2e/keras:keras_tests':
         'End to end tests written using tf.keras',
+    '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
+        f'End to end tests of {KWS_LINK} models',
+    '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
+        f'End to end tests of {KWS_LINK} models in internal streaming mode',
     '//integrations/tensorflow/e2e/keras:imagenet_external_tests':
         'End to end tests of tf.keras.applications vision models on Imagenet',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -53,6 +61,10 @@
 # Key to use as the name of the rows in the left column for each test in the
 # suite.
 TEST_SUITE_TO_ROW_ID_KEY = {
+    '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
+        'model',
+    '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
+        'model',
     '//integrations/tensorflow/e2e/keras:imagenet_external_tests':
         'model',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -62,6 +74,10 @@
 # Some test suites are generated from a single source. This allows us to point
 # to the right test file when generating test URLs.
 SINGLE_SOURCE_SUITES = {
+    '//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
+        'keyword_spotting_streaming_test',
+    '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
+        'keyword_spotting_streaming_test',
     '//integrations/tensorflow/e2e/keras:imagenet_external_tests':
         'vision_model_test',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -71,7 +87,6 @@
 TARGET_EXCLUSION_FILTERS = [
     r'mobilenet_v1_.*',  # Slim vision MobileNetV1.
     r'mobilenet_v2_.*',  # Slim vision MobileNetV2.
-    r'amoebanet_a_n18_f448',  # SavedModelV2 not available.
 ]
 
 # The symbols to show in the table if the operation is supported or not.
@@ -194,6 +209,11 @@
   # Generate the coverage table as a 2D array.
   rows = [first_row, second_row]
   for row_id, backends in sorted(table.items()):
+    # If the reference backend is failing then there is no reason to show the
+    # coverage of the other backends.
+    if not backends[ordered_backends.index(REFERENCE_BACKEND)]:
+      continue
+
     # Skip any rows defined in the TARGET_EXCLUSION_FILTERS.
     if any(re.match(pattern, row_id) for pattern in TARGET_EXCLUSION_FILTERS):
       continue