Enable SGD training tests (#4474)
- Explicitly convert inputs to `f32`.
- Compare model weights and biases across backends.
- Add CMake configuration for testing in OSS.
I am adding these tests in OSS before we figure out the long-term structure
for the e2e CMake build since they cannot be enabled internally at this time.
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
index 4c6799b..e71c31f 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-swiftshader/build.sh
@@ -68,4 +68,4 @@
ninja
echo "Testing with CTest"
-ctest -R 'tensorflow_e2e|bindings/python'
+ctest -R 'tensorflow_e2e|bindings/python|integrations/tensorflow/'
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh
index 7201905..cfe3e06 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh
@@ -68,4 +68,4 @@
ninja
echo "Testing with CTest"
-ctest -R 'tensorflow_e2e|bindings/python'
+ctest -R 'tensorflow_e2e|bindings/python|integrations/tensorflow/'
diff --git a/integrations/tensorflow/e2e/CMakeLists.txt b/integrations/tensorflow/e2e/CMakeLists.txt
index 7587332..3790193 100644
--- a/integrations/tensorflow/e2e/CMakeLists.txt
+++ b/integrations/tensorflow/e2e/CMakeLists.txt
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+add_subdirectory(keras)
+
# Special cases to exclude from automatically expanding targets for all
# backends.
set(SPECIAL_CASES
diff --git a/integrations/tensorflow/e2e/keras/CMakeLists.txt b/integrations/tensorflow/e2e/keras/CMakeLists.txt
new file mode 100644
index 0000000..4be3e81
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_subdirectory(train)
diff --git a/integrations/tensorflow/e2e/keras/train/CMakeLists.txt b/integrations/tensorflow/e2e/keras/train/CMakeLists.txt
new file mode 100644
index 0000000..c4b9bb6
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/CMakeLists.txt
@@ -0,0 +1,47 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(REFERENCE_BACKEND tf)
+
+function(add_e2e_test filename target_backend labels)
+ iree_package_ns(_PACKAGE_NS)
+ string(REPLACE "::" "/" _PACKAGE_PATH ${_PACKAGE_NS})
+ string(REPLACE ".py" "" _name ${filename})
+
+ set(_name "${_PACKAGE_PATH}/${_name}__${target_backend}")
+ add_test(
+ NAME
+ ${_name}
+ WORKING_DIRECTORY
+ "${CMAKE_CURRENT_BINARY_DIR}"
+ COMMAND
+ "${Python3_EXECUTABLE}" -B
+ "${CMAKE_CURRENT_SOURCE_DIR}/${filename}"
+ "--reference_backend=${REFERENCE_BACKEND}"
+ "--target_backends=${target_backend}"
+ )
+ set_property(TEST ${_name} PROPERTY LABELS "${labels}")
+ set_property(TEST ${_name} PROPERTY ENVIRONMENT
+ "PYTHONPATH=${CMAKE_BINARY_DIR}/bindings/python")
+endfunction()
+
+add_e2e_test(regression_training_test.py tf "")
+add_e2e_test(regression_training_test.py iree_vmla "")
+add_e2e_test(regression_training_test.py iree_llvmaot "driver=dylib")
+add_e2e_test(regression_training_test.py iree_vulkan "driver=vulkan")
+
+add_e2e_test(classification_training_test.py tf "")
+add_e2e_test(classification_training_test.py iree_vmla "")
+add_e2e_test(classification_training_test.py iree_llvmaot "driver=dylib")
+add_e2e_test(classification_training_test.py iree_vulkan "driver=vulkan")
diff --git a/integrations/tensorflow/e2e/keras/train/classification_training_test.py b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
index a73205b..2565906 100644
--- a/integrations/tensorflow/e2e/keras/train/classification_training_test.py
+++ b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
@@ -18,10 +18,8 @@
from absl import app
from absl import flags
-from absl import logging
import numpy as np
from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
import tensorflow as tf
FLAGS = flags.FLAGS
@@ -60,7 +58,7 @@
inputs = inputs[index]
labels = labels[index]
- return inputs, labels
+ return inputs.astype(np.float32), labels.astype(np.float32)
class ClassificationTrainingModule(tf.Module):
@@ -88,6 +86,14 @@
self.optimizer.apply_gradients(zip(gradients, variables))
return loss
+ @tf.function(input_signature=[])
+ def get_weights(self):
+ return self.model.weights[0]
+
+ @tf.function(input_signature=[])
+ def get_bias(self):
+ return self.model.weights[1]
+
class ClassificationTrainingTest(tf_test_utils.TracedModuleTestCase):
@@ -95,7 +101,7 @@
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
ClassificationTrainingModule,
- exported_names=["train_on_batch"],
+ exported_names=["train_on_batch", "get_weights", "get_bias"],
relative_artifacts_dir=os.path.join(
ClassificationTrainingModule.__name__, FLAGS.optimizer))
@@ -104,6 +110,9 @@
def train_on_batch(module):
inputs, labels = get_spiral_dataset(SAMPLES_PER_SPIRAL, noise_scale=0.05)
module.train_on_batch(inputs, labels)
+ # Ensures the weights are identical.
+ module.get_weights()
+ module.get_bias()
self.compare_backends(train_on_batch, self._modules)
diff --git a/integrations/tensorflow/e2e/keras/train/regression_training_test.py b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
index 0f9aa72..9d9ac79 100644
--- a/integrations/tensorflow/e2e/keras/train/regression_training_test.py
+++ b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
@@ -20,7 +20,6 @@
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
import tensorflow as tf
FLAGS = flags.FLAGS
@@ -40,7 +39,7 @@
def get_linear_data():
x = np.random.uniform(-1, 1, size=(BATCH_SIZE, INPUT_DIM))
y = np.dot(x, WEIGHTS) + BIASES
- return x, y
+ return x.astype(np.float32), y.astype(np.float32)
class RegressionTrainingModule(tf.Module):
@@ -67,6 +66,14 @@
self.optimizer.apply_gradients(zip(gradients, variables))
return loss
+ @tf.function(input_signature=[])
+ def get_weights(self):
+ return self.model.weights[0]
+
+ @tf.function(input_signature=[])
+ def get_bias(self):
+ return self.model.weights[1]
+
class RegressionTrainingTest(tf_test_utils.TracedModuleTestCase):
@@ -74,7 +81,7 @@
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
RegressionTrainingModule,
- exported_names=["train_on_batch"],
+ exported_names=["train_on_batch", "get_weights", "get_bias"],
relative_artifacts_dir=os.path.join(RegressionTrainingModule.__name__,
FLAGS.optimizer))
@@ -83,6 +90,9 @@
def train_on_batch(module):
x, y = get_linear_data()
module.train_on_batch(x, y)
+ # Ensures the weights are identical.
+ module.get_weights()
+ module.get_bias()
self.compare_backends(train_on_batch, self._modules)