Refactor e2e integrations BUILD into test suites.
Refactors `integrations/tensorflow/e2e/BUILD` into:
- `integrations/tensorflow/e2e/BUILD`
- `integrations/tensorflow/e2e/keras/BUILD`
Adds `iree_py_test_suite` and `iree_vision_test_suite` macros to simplify the BUILD files and create targetable suites, which are currently named:
- `integrations/tensorflow/e2e:e2e`
- `integrations/tensorflow/e2e/keras:non_vision`
- `integrations/tensorflow/e2e/keras:keras_vision_models`
- `integrations/tensorflow/e2e/keras:keras_vision_models_external`
Naming suggestions are definitely welcome here.
This will allow us to target the keras vision models in internal daily tests.
Closes https://github.com/google/iree/pull/2051
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/2051 from phoenix-meadowlark:e2e-suite 6f54beaa5184b2b84563721f0fa6b206e2a84feb
PiperOrigin-RevId: 314401924
diff --git a/bindings/python/build_defs.oss.bzl b/bindings/python/build_defs.oss.bzl
index 7363c64..d8cf789 100644
--- a/bindings/python/build_defs.oss.bzl
+++ b/bindings/python/build_defs.oss.bzl
@@ -47,8 +47,7 @@
name,
copts = [],
features = [],
- deps = [
- ],
+ deps = [],
**kwargs):
"""Wrapper cc_library for deps that are part of the python bindings."""
cc_library(
diff --git a/build_tools/bazel/iree_py_test_suite.bzl b/build_tools/bazel/iree_py_test_suite.bzl
new file mode 100644
index 0000000..dab7055
--- /dev/null
+++ b/build_tools/bazel/iree_py_test_suite.bzl
@@ -0,0 +1,67 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building python test suites."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+
+def iree_py_test_suite(
+ name,
+ srcs,
+ deps = None,
+ tags = None,
+ size = None,
+ python_version = "PY3",
+ **kwargs):
+ """Creates one iree_py_test per source file and a test suite that bundles them.
+
+ Args:
+ name: name of the generated test suite.
+ srcs: test file sources.
+ deps: test dependencies.
+ tags: tags to apply to the test. Note that as in standard test suites,
+ manual is treated specially and will also apply to the test suite
+ itself.
+ size: size of the tests.
+ python_version: the python version to run the tests with. Uses python3
+ by default.
+ **kwargs: Any additional arguments that will be passed to the underlying
+ tests and test_suite.
+ """
+ tests = []
+ for src in srcs:
+ test_name = "{}_{}".format(name, src[:-3])
+ iree_py_test(
+ name = test_name,
+ main = src,
+ srcs = [src],
+ deps = deps,
+ tags = tags,
+ size = size,
+ python_version = python_version,
+ **kwargs
+ )
+ tests.append(test_name)
+
+ native.test_suite(
+ name = name,
+ tests = tests,
+ # Note that only the manual tag really has any effect here. Others are
+ # used for test suite filtering, but all tests are passed the same tags.
+ tags = tags,
+ # If there are kwargs that need to be passed here which only apply to
+ # the generated tests and not to test_suite, they should be extracted
+ # into separate named arguments.
+ **kwargs
+ )
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 8885741..654e663 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -18,205 +18,34 @@
"NUMPY_DEPS",
"iree_py_test",
)
+load(
+ "//build_tools/bazel:iree_py_test_suite.bzl",
+ "iree_py_test_suite",
+)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
-[
- iree_py_test(
- name = name,
- srcs = [name + ".py"],
- python_version = "PY3",
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for name in [
- "broadcasting_test",
- "batch_norm_test",
- "fill_test",
- "control_flow_test",
- "dynamic_mlp_test",
- "dynamic_mlp_relu_test",
- "depth_conv_test",
- "exported_names_test",
- "gather_test",
- "tensorlist_test",
- "keras_lstm_test",
- "mandelbrot_test",
- "matrix_ops_test",
- "resource_ops_test",
- "ring_buffer_test",
- "sliding_window_test",
- "simple_arithmetic_test",
- "simple_stateful_test",
- "strings_test",
- ]
-]
-
-[
- iree_py_test(
- name = "_".join([
- "keras_model_train",
- optimizer_name,
- backends,
- "test",
- ]),
- srcs = [
- "keras_model_train_test.py",
- ],
- args = [
- "--optimizer_name=%s" % optimizer_name,
- "--override_backends=%s" % backends,
- ],
- main = "keras_model_train_test.py",
- python_version = "PY3",
- tags = [
- "manual",
- "noga",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for optimizer_name, backends in [
- ("sgd", "tf"),
- ("sgd", "tf,iree_vmla"),
- ("adam", "tf,iree_vmla"), # TODO(b/157581521)
- ]
-]
-
-[
- iree_py_test(
- name = "_".join([
- "keras_vision_model",
- data,
- "top",
- str(include_top),
- model_name,
- backends,
- "test",
- ]),
- size = "large",
- srcs = [
- "keras_vision_model_test.py",
- ],
- args = [
- "--model=%s" % model_name,
- "--override_backends=%s" % backends,
- "--data=%s" % data,
- "--include_top=%d" % include_top,
- ],
- main = "keras_vision_model_test.py",
- python_version = "PY3",
- tags = [
- "manual",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- # ResNet50 test is hermetic - does not access any extrnal urls
- # all other tests need real weights loaded from url
- for data, include_top, model_name, backends in [
- # "cifar10" has toy models with input 32x32, is good for debugging
- ("cifar10", 1, "ResNet50", "tf,iree_vmla"),
- ("cifar10", 1, "ResNet50", "tf,iree_llvmjit"),
- ("cifar10", 1, "ResNet50", "tf,iree_vulkan"),
- ]
-]
-
-# it requres access to external URL, so these tests will be run manually
-[
- iree_py_test(
- name = "_".join([
- "keras_vision_model",
- data,
- "top",
- str(include_top),
- model_name,
- backends,
- "test",
- ]),
- srcs = [
- "keras_vision_model_test.py",
- ],
- args = [
- "--model=%s" % model_name,
- "--override_backends=%s" % backends,
- "--data=%s" % data,
- "--include_top=%d" % include_top,
- "--url=https://storage.googleapis.com/iree_models/",
- ],
- main = "keras_vision_model_test.py",
- python_version = "PY3",
- tags = [
- "external",
- "large",
- "manual",
- "noga",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- # "cifar10" has toy models with input 32x32, is good for debugging
- # "imagenet" has real model weights for input 224x224
- for data, include_top, model_name, backends in [
- ("cifar10", 1, "MobileNet", "tf,iree_vmla"),
- ("cifar10", 1, "MobileNet", "tf,iree_vulkan"), # TODO(b/150244105)
- ("cifar10", 1, "MobileNet", "tf,iree_llvmjit"), # TODO(b/150244105)
- ("cifar10", 1, "MobileNetV2", "tf,iree_vmla"),
- ("cifar10", 1, "MobileNetV2", "tf,iree_vulkan"), # TODO(b/150244105)
- ("cifar10", 1, "MobileNetV2", "tf,iree_llvmjit"), # TODO(b/150244105)
- ("imagenet", 1, "ResNet50", "tf,iree_vmla"),
- ("imagenet", 1, "ResNet50", "tf,iree_vulkan"),
- ("imagenet", 1, "ResNet50", "tf,iree_llvmjit"),
- ("imagenet", 1, "MobileNet", "tf,iree_vmla"),
- ("imagenet", 1, "MobileNet", "tf,iree_vulkan"), # TODO(b/150244105)
- ("imagenet", 1, "MobileNet", "tf,iree_llvmjit"), # TODO(b/150244105)
- ("imagenet", 1, "MobileNetV2", "tf,iree_vmla"),
- ("imagenet", 1, "MobileNetV2", "tf,iree_vulkan"), # TODO(b/150244105)
- ("imagenet", 1, "MobileNetV2", "tf,iree_llvmjit"), # TODO(b/150244105)
- ]
-]
-
-[
- iree_py_test(
- name = name,
- srcs = [name + ".py"],
- python_version = "PY3",
- # TODO(b/145815906) Get this running in OSS CI.
- tags = ["noga"],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
- )
- for name in [
- "conv_test",
- "linspace_test",
- "math_test",
- # TODO(GH-1620): Re-enable this after fixing the failure on
- # GitHub Actions.
- "keras_lstm_static_test",
- ]
-]
-
-# It is used to produce weights for keras vision models with input image size 32x32.
-# These models are not optimized for accuracy or latency (they are for debugging only).
-# They have the same neural net topology
-# with keras vision models trained on imagenet data sets
-py_binary(
- name = "train_vision_models_on_cifar",
- srcs = [
- "train_vision_models_on_cifar.py",
- ],
+iree_py_test(
+ # TODO(GH-2082): `linspace_test.py` fails due to an unknown location error.
+ name = "linspace_test",
+ srcs = ["linspace_test.py"],
+ main = "linspace_test.py",
python_version = "PY3",
- srcs_version = "PY2AND3",
+ tags = ["noga"],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_py_test_suite(
+ name = "e2e",
+ srcs = glob(
+ ["*_test.py"],
+ exclude = ["linspace_test.py"],
+ ),
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
new file mode 100644
index 0000000..417c438
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -0,0 +1,145 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+ "iree_py_test",
+)
+load(
+ "//build_tools/bazel:iree_py_test_suite.bzl",
+ "iree_py_test_suite",
+)
+load(
+ "//integrations/tensorflow/e2e/keras:iree_vision_test_suite.bzl",
+ "iree_vision_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_py_test(
+ # TODO(GH-1620): Include after fixing the failure on GitHub Actions.
+ name = "keras_lstm_static_test",
+ srcs = ["keras_lstm_static_test.py"],
+ main = "keras_lstm_static_test.py",
+ python_version = "PY3",
+ tags = ["noga"],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_py_test_suite(
+ name = "non_vision",
+ srcs = glob(
+ ["*_test.py"],
+ exclude = [
+ "keras_vision_model_test.py",
+ "keras_lstm_static_test.py",
+ ],
+ ),
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_vision_test_suite(
+ name = "keras_vision_models",
+ configurations = [
+ # tuples of (dataset, include_top, model_name, backends)
+ # "cifar10" has toy models with input 32x32, is good for debugging
+ ("cifar10", 1, "ResNet50", "tf,iree_vmla"),
+ ("cifar10", 1, "ResNet50", "tf,iree_llvmjit"),
+ ("cifar10", 1, "ResNet50", "tf,iree_vulkan"),
+ ],
+ tags = ["manual"],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_vision_test_suite(
+ name = "keras_vision_models_external",
+ configurations = [
+ # tuples of (dataset, include_top, model_name, backends)
+ # "cifar10" has toy models with 32x32 input which are good for debugging
+ # "imagenet" has real model weights for input 224x224
+ ("cifar10", 1, "MobileNet", "tf,iree_vmla"),
+ ("cifar10", 1, "MobileNetV2", "tf,iree_vmla"),
+ ("imagenet", 1, "ResNet50", "tf,iree_vmla"),
+ ("imagenet", 1, "ResNet50", "tf,iree_vulkan"),
+ ("imagenet", 1, "ResNet50", "tf,iree_llvmjit"),
+ ("imagenet", 1, "MobileNet", "tf,iree_vmla"),
+ ("imagenet", 1, "MobileNetV2", "tf,iree_vmla"),
+ ],
+ external_weights = "https://storage.googleapis.com/iree_models/",
+ tags = [
+ "external",
+ "manual",
+ "noga",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_vision_test_suite(
+ name = "keras_vision_models_external_failing",
+ configurations = [
+ # tuples of (dataset, include_top, model_name, backends)
+ # "cifar10" has toy models with 32x32 input which are good for debugging
+ # "imagenet" has real model weights for input 224x224
+ # TODO(b/150244105): Compiling fails with commands targeting IREE
+ # interpreter and vulkan backends for these tests.
+ # TODO: Combine this suite with keras_vision_models_external once these
+ # tests pass.
+ ("cifar10", 1, "MobileNet", "tf,iree_vulkan"),
+ ("cifar10", 1, "MobileNet", "tf,iree_llvmjit"),
+ ("cifar10", 1, "MobileNetV2", "tf,iree_vulkan"),
+ ("cifar10", 1, "MobileNetV2", "tf,iree_llvmjit"),
+ ("imagenet", 1, "MobileNet", "tf,iree_vulkan"),
+ ("imagenet", 1, "MobileNet", "tf,iree_llvmjit"),
+ ("imagenet", 1, "MobileNetV2", "tf,iree_vulkan"),
+ ("imagenet", 1, "MobileNetV2", "tf,iree_llvmjit"),
+ ],
+ external_weights = "https://storage.googleapis.com/iree_models/",
+ tags = [
+ "external",
+ "manual",
+ "noga",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# It is used to produce weights for keras vision models with input image size
+# 32x32. These models are not optimized for accuracy or latency (they are for
+# debugging only). They have the same neural net topology with keras vision
+# models trained on imagenet data sets
+py_binary(
+ name = "train_vision_models_on_cifar",
+ srcs = ["train_vision_models_on_cifar.py"],
+ python_version = "PY3",
+ srcs_version = "PY2AND3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
new file mode 100644
index 0000000..893bda7
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
@@ -0,0 +1,87 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building e2e keras vision model tests."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+
+def iree_vision_test_suite(
+ name,
+ configurations,
+ external_weights = None,
+ deps = None,
+ tags = None,
+ size = "large",
+ python_version = "PY3",
+ **kwargs):
+ """Creates one iree_py_test per configuration tuple and a test suite that bundles them.
+
+ Args:
+ name: name of the generated test suite.
+ configurations: a list of tuples of (dataset, include_top, model,
+ backends) that specifies which data, model and backends to
+ use for a given test.
+ external_weights: a base url to fetch trained model weights from.
+ tags: tags to apply to the test. Note that as in standard test suites,
+ manual is treated specially and will also apply to the test suite
+ itself.
+ size: size of the tests.
+ python_version: the python version to run the tests with. Uses python3
+ by default.
+ **kwargs: Any additional arguments that will be passed to the underlying
+ tests and test_suite.
+ """
+ tests = []
+ for dataset, include_top, model, backends in configurations:
+ test_name = "{}_{}_top_{}_{}_{}_test".format(
+ name,
+ dataset,
+ include_top,
+ model,
+ backends,
+ )
+ tests.append(test_name)
+
+ args = [
+ "--data={}".format(dataset),
+ "--include_top={}".format(include_top),
+ "--model={}".format(model),
+ "--override_backends={}".format(backends),
+ ]
+ if external_weights:
+ args.append("--url={}".format(external_weights))
+
+ iree_py_test(
+ name = test_name,
+ main = "keras_vision_model_test.py",
+ srcs = ["keras_vision_model_test.py"],
+ args = args,
+ tags = tags,
+ deps = deps,
+ size = size,
+ python_version = python_version,
+ **kwargs
+ )
+
+ native.test_suite(
+ name = name,
+ tests = tests,
+ # Note that only the manual tag really has any effect here. Others are
+ # used for test suite filtering, but all tests are passed the same tags.
+ tags = tags,
+ # If there are kwargs that need to be passed here which only apply to
+ # the generated tests and not to test_suite, they should be extracted
+ # into separate named arguments.
+ **kwargs
+ )
diff --git a/integrations/tensorflow/e2e/keras_lstm_static_test.py b/integrations/tensorflow/e2e/keras/keras_lstm_static_test.py
similarity index 100%
rename from integrations/tensorflow/e2e/keras_lstm_static_test.py
rename to integrations/tensorflow/e2e/keras/keras_lstm_static_test.py
diff --git a/integrations/tensorflow/e2e/keras_lstm_test.py b/integrations/tensorflow/e2e/keras/keras_lstm_test.py
similarity index 100%
rename from integrations/tensorflow/e2e/keras_lstm_test.py
rename to integrations/tensorflow/e2e/keras/keras_lstm_test.py
diff --git a/integrations/tensorflow/e2e/keras_vision_model_test.py b/integrations/tensorflow/e2e/keras/keras_vision_model_test.py
similarity index 100%
rename from integrations/tensorflow/e2e/keras_vision_model_test.py
rename to integrations/tensorflow/e2e/keras/keras_vision_model_test.py
diff --git a/integrations/tensorflow/e2e/keras/train/BUILD b/integrations/tensorflow/e2e/keras/train/BUILD
new file mode 100644
index 0000000..ac428c7
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/BUILD
@@ -0,0 +1,62 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+)
+load(
+ "//integrations/tensorflow/e2e/keras/train:iree_train_test_suite.bzl",
+ "iree_train_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_train_test_suite(
+ name = "keras_model_train",
+ configurations = [
+ # tuples of (optimizer, backends)
+ ("sgd", "tf"),
+ ],
+ tags = [
+ "manual",
+ "noga",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_train_test_suite(
+ name = "keras_model_train_failing",
+ configurations = [
+ # tuples of (optimizer, backends)
+ # TODO: Combine this suite with keras_model_train once these tests pass.
+ ("sgd", "tf,iree_vmla"),
+ ("adam", "tf,iree_vmla"), # TODO(b/157581521)
+ ],
+ tags = [
+ "manual",
+ "noga",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
diff --git a/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl b/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl
new file mode 100644
index 0000000..1150db6
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl
@@ -0,0 +1,73 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building e2e keras vision model tests."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+
+def iree_train_test_suite(
+ name,
+ configurations,
+ deps = None,
+ tags = None,
+ size = None,
+ python_version = "PY3",
+ **kwargs):
+ """Creates one iree_py_test per configuration tuple and a test suite that bundles them.
+
+ Args:
+ name: name of the generated test suite.
+ configurations: a list of tuples of (optimizer, backends).
+ tags: tags to apply to the test. Note that as in standard test suites,
+ manual is treated specially and will also apply to the test suite
+ itself.
+ size: size of the tests.
+ python_version: the python version to run the tests with. Uses python3
+ by default.
+ **kwargs: Any additional arguments that will be passed to the underlying
+ tests and test_suite.
+ """
+ tests = []
+ for optimizer, backends in configurations:
+ test_name = "{}_{}_{}_test".format(name, optimizer, backends)
+ tests.append(test_name)
+
+ args = [
+ "--optimizer_name={}".format(optimizer),
+ "--override_backends={}".format(backends),
+ ]
+
+ iree_py_test(
+ name = test_name,
+ main = "keras_model_train_test.py",
+ srcs = ["keras_model_train_test.py"],
+ args = args,
+ tags = tags,
+ deps = deps,
+ size = size,
+ python_version = python_version,
+ **kwargs
+ )
+
+ native.test_suite(
+ name = name,
+ tests = tests,
+ # Note that only the manual tag really has any effect here. Others are
+ # used for test suite filtering, but all tests are passed the same tags.
+ tags = tags,
+ # If there are kwargs that need to be passed here which only apply to
+ # the generated tests and not to test_suite, they should be extracted
+ # into separate named arguments.
+ **kwargs
+ )
diff --git a/integrations/tensorflow/e2e/keras_model_train_test.py b/integrations/tensorflow/e2e/keras/train/keras_model_train_test.py
similarity index 100%
rename from integrations/tensorflow/e2e/keras_model_train_test.py
rename to integrations/tensorflow/e2e/keras/train/keras_model_train_test.py
diff --git a/integrations/tensorflow/e2e/train_vision_models_on_cifar.py b/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
similarity index 100%
rename from integrations/tensorflow/e2e/train_vision_models_on_cifar.py
rename to integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py