blob: 6536cb06006350e320b14f409a9b0cb592e26cae [file] [log] [blame]
# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# ***THIS FILE DOES NOT BUILD WITH BAZEL***
#
# It is open sourced to enable Bazel->CMake conversion to maintain test coverage
# of our integration tests in open source while we figure out a long term plan
# for our integration testing.
load(
"@iree//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
"iree_e2e_cartesian_product_test_suite",
)
# Package-wide defaults for every target declared in this file: public
# visibility, the `layering_check` feature enabled, and a "notice"
# (Apache 2.0) license.
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
# Declare a standalone `<name>_manual` py_binary (main = the test file itself)
# for every `*_test.py` in this package.
# keyword_spotting_streaming_test.py is excluded because it additionally needs
# the kws_streaming model/train libraries; its binary is declared explicitly
# further down in this file.
[
    py_binary(
        # Strip only the trailing ".py" (guaranteed by the glob pattern).
        # str.replace would also rewrite any ".py" appearing earlier in the
        # filename.
        name = src[:-len(".py")] + "_manual",
        srcs = [src],
        main = src,
        python_version = "PY3",
        deps = [
            "//third_party/py/absl:app",
            "//third_party/py/absl/flags",
            "//third_party/py/iree:pylib_tf_support",
            "//third_party/py/numpy",
            "//third_party/py/tensorflow",
            "//util/debuginfo:signalsafe_addr2line_installer",
        ],
    )
    for src in glob(
        ["*_test.py"],
        exclude = ["keyword_spotting_streaming_test.py"],
    )
]
# Keyword Spotting Tests:
# Every keyword spotting model architecture exercised by the suites below.
KEYWORD_SPOTTING_MODELS = [
    "svdf",
    "svdf_resnet",
    "ds_cnn",
    "gru",
    "lstm",
    "cnn_stride",
    "cnn",
    "tc_resnet",
    "crnn",
    "dnn",
    "att_rnn",
    "att_mh_rnn",
    "mobilenet",
    "mobilenet_v2",
    "xception",
    "inception",
    "inception_resnet",
    "ds_tc_resnet",
]

# Subset of KEYWORD_SPOTTING_MODELS that does not currently support streaming.
# Both streaming suites list these models in their failing_configurations.
NON_STREAMING_KEYWORD_SPOTTING_MODELS = [
    "att_mh_rnn",
    "att_rnn",
    "ds_cnn",
    "inception",
    "inception_resnet",
    "mobilenet",
    "mobilenet_v2",
    "svdf_resnet",
    "tc_resnet",
    "xception",
]

# Target backends used by every keyword spotting suite's matrix. Defined once
# so the suites (and the failing configuration that names all of them) cannot
# drift apart.
KEYWORD_SPOTTING_TARGET_BACKENDS = [
    "tf",
    "tflite",
    "iree_llvmaot",
    "iree_vulkan",
]

# Dependencies shared by every keyword spotting target: the three test suites
# and the manual binary at the end of this section.
KEYWORD_SPOTTING_DEPS = [
    "//third_party/google_research/google_research/kws_streaming/models:models_lib",
    "//third_party/google_research/google_research/kws_streaming/train:train_lib",
    "//third_party/py/absl:app",
    "//third_party/py/absl/flags",
    "//third_party/py/iree:pylib_tf_support",
    "//third_party/py/numpy",
    "//third_party/py/tensorflow",
    "//util/debuginfo:signalsafe_addr2line_installer",
]

# Non-streaming mode: every model on every target backend, with "tf" as the
# reference backend.
iree_e2e_cartesian_product_test_suite(
    name = "keyword_spotting_tests",
    failing_configurations = [
        {
            "model": [
                # unrolling True: "Unrolling requires a fixed number of timesteps."
                # unrolling False: "error: 'tf.BatchMatMulV2' op : unlegalized TensorFlow op still exists"
                "att_mh_rnn",
                "att_rnn",
            ],
            "target_backends": [
                "iree_vulkan",
                "iree_llvmaot",
            ],
        },
    ],
    matrix = {
        "src": "keyword_spotting_streaming_test.py",
        "reference_backend": "tf",
        "mode": "non_streaming",
        "model": KEYWORD_SPOTTING_MODELS,
        "target_backends": KEYWORD_SPOTTING_TARGET_BACKENDS,
    },
    deps = KEYWORD_SPOTTING_DEPS,
)

# Internal streaming mode: same model/backend matrix as above.
iree_e2e_cartesian_product_test_suite(
    name = "keyword_spotting_internal_streaming_tests",
    failing_configurations = [
        {
            # TFLite cannot compile variables.
            "target_backends": "tflite",
        },
        {
            # These models do not currently support streaming.
            "model": NON_STREAMING_KEYWORD_SPOTTING_MODELS,
        },
        {
            "model": [
                "crnn",  # TODO(b/188221333): Get this test working.
            ],
            "target_backends": "iree_vulkan",
        },
    ],
    matrix = {
        "src": "keyword_spotting_streaming_test.py",
        "reference_backend": "tf",
        "mode": "internal_streaming",
        "model": KEYWORD_SPOTTING_MODELS,
        "target_backends": KEYWORD_SPOTTING_TARGET_BACKENDS,
    },
    deps = KEYWORD_SPOTTING_DEPS,
)

# External streaming mode: same model/backend matrix as above.
iree_e2e_cartesian_product_test_suite(
    name = "keyword_spotting_external_streaming_tests",
    failing_configurations = [
        {
            # A bug in Keras causes the external streaming conversion to fail
            # when TensorFlow 2.x is used, so every target backend is listed.
            "target_backends": KEYWORD_SPOTTING_TARGET_BACKENDS,
        },
        {
            # These models do not currently support streaming. This entry is
            # currently subsumed by the all-backends entry above; it is kept
            # so it still applies once that Keras bug is resolved.
            "model": NON_STREAMING_KEYWORD_SPOTTING_MODELS,
        },
    ],
    matrix = {
        "src": "keyword_spotting_streaming_test.py",
        "reference_backend": "tf",
        "mode": "external_streaming",
        "model": KEYWORD_SPOTTING_MODELS,
        "target_backends": KEYWORD_SPOTTING_TARGET_BACKENDS,
    },
    deps = KEYWORD_SPOTTING_DEPS,
)

# Standalone binary for keyword_spotting_streaming_test.py. It is declared by
# hand because the glob-generated *_manual binaries exclude this file, which
# also needs the kws_streaming libraries.
py_binary(
    name = "keyword_spotting_streaming_test_manual",
    srcs = ["keyword_spotting_streaming_test.py"],
    main = "keyword_spotting_streaming_test.py",
    python_version = "PY3",
    deps = KEYWORD_SPOTTING_DEPS,
)