Test tf vs tf for debugging model consistency execution
PiperOrigin-RevId: 307728611
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 0459983..eb0c72e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -29,6 +29,7 @@
import numpy as np
from pyiree import rt
from pyiree.tf import compiler
+import random
import tensorflow.compat.v2 as tf
flags.DEFINE_string(
@@ -47,6 +48,13 @@
global_debug_dir = None
+def set_random_seed(seed=0):
+ """Set random seed for tf, np and random."""
+ tf.random.set_seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+
def save_and_compile_tf_module(tf_module, exported_names=(),
target_backends=()):
"""Saves and compiles a TensorFlow tf.Module.
@@ -462,6 +470,13 @@
CompiledModule=TfCompiledModule,
iree_driver=None,
iree_compiler_targets=None)
+# tf_also is used for checking test consistency
+# to catch any initialization/randomization issues between model runs
+BackendInfo.add(
+ name="tf_also",
+ CompiledModule=TfCompiledModule,
+ iree_driver=None,
+ iree_compiler_targets=None)
BackendInfo.add(
name="iree_vmla",
CompiledModule=IreeCompiledModule,
@@ -583,6 +598,10 @@
elif backends is None:
backends = list(BackendInfo.ALL.keys())
backends = [_resolve(backend) for backend in backends]
+ # if "tf" is specified as a only backend then
+ # we will test it always against "tf" by adding "tf_also".
+ if len(backends) == 1 and "tf" == backends[0].name:
+ backends.append(BackendInfo.ALL["tf_also"])
available_backends = get_available_backends()
backends = [
backend for backend in backends if backend in available_backends
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 12b5a45..c611048 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -38,6 +38,15 @@
--test_output=errors
```
+If you specify the same backend multiple times, for example
+--override_backends=iree_vmla,iree_vmla. The same backends are grouped and in
+this example iree_vmla will run once. If you specify tf,iree_vmla as backends,
+then we will test both backends and compare them with each other. If you specify
+tf backend only, then we will also test tf vs tf to capture any model
+initialization/randomization issues (it is a special case for debug purpose).
+For reproducibility of the unit tests we set random seed of tf and numpy by
+calling tf_test_utils.set_random_seed() before model creation.
+
## Debugging tests
If the compiler fails to compile the program, then it will create a crash
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index e7660b0..ec870c1 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -35,6 +35,7 @@
input_dim=28 * 28,
classes=10):
super().__init__()
+ tf_test_utils.set_random_seed()
self.hidden_1_dim = hidden_1_dim
self.hidden_2_dim = hidden_2_dim
self.input_dim = input_dim
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 461ceed..d2dc527 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -31,6 +31,7 @@
input_dim=28 * 28,
classes=10):
super().__init__()
+ tf_test_utils.set_random_seed()
self.hidden_1_dim = hidden_1_dim
self.hidden_2_dim = hidden_2_dim
self.input_dim = input_dim
diff --git a/integrations/tensorflow/e2e/keras_lstm_static_test.py b/integrations/tensorflow/e2e/keras_lstm_static_test.py
index fb894f5..c62629a 100644
--- a/integrations/tensorflow/e2e/keras_lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras_lstm_static_test.py
@@ -16,36 +16,43 @@
# This test is the same as keras_lstm_test, but all shapes are static.
# This stresses the TensorList lowering more specifically.
+import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
NUM_UNITS = 10
NUM_TIMESTEPS = 24
NUM_BATCH = 7
+INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
-class Lstm(tf.Module):
-
- def __init__(self):
- super(Lstm, self).__init__()
- self.lstm = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)
-
- @tf.function(input_signature=[
- tf.TensorSpec([NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS], tf.float32)
- ])
- def predict(self, x):
- return self.lstm(x)
+def lstm_module():
+ tf_test_utils.set_random_seed()
+ inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
+ outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
+ model = tf.keras.Model(inputs, outputs)
+ module = tf.Module()
+ module.m = model
+ module.predict = tf.function(
+ input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+ model.call)
+ return module
# TODO(silvasean): Get this test working on other backends.
@tf_test_utils.compile_modules(
- backends=["tf", "iree_vmla"], lstm=(Lstm, ["predict"]))
+ backends=["tf", "iree_vmla"], lstm=(lstm_module, ["predict"]))
class LstmTest(tf_test_utils.SavedModelTestCase):
def test_lstm(self):
m = self.modules.lstm.all
- m.predict(tf.constant(0., shape=[NUM_BATCH, NUM_TIMESTEPS,
- NUM_UNITS])).print().assert_all_close()
+ m.predict(
+ tf.constant(
+ np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
+ dtype=np.float32).reshape(
+ [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]),
+ shape=[NUM_BATCH, NUM_TIMESTEPS,
+ NUM_UNITS])).print().assert_all_close()
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras_lstm_test.py b/integrations/tensorflow/e2e/keras_lstm_test.py
index 3ae1d1f..475083f 100644
--- a/integrations/tensorflow/e2e/keras_lstm_test.py
+++ b/integrations/tensorflow/e2e/keras_lstm_test.py
@@ -13,35 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
NUM_UNITS = 10
NUM_TIMESTEPS = 24
NUM_BATCH = 7
+INPUT_SHAPE = [None, None, NUM_UNITS]
-class Lstm(tf.Module):
-
- def __init__(self):
- super(Lstm, self).__init__()
- self.lstm = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)
-
- @tf.function(
- input_signature=[tf.TensorSpec([None, None, NUM_UNITS], tf.float32)])
- def predict(self, x):
- return self.lstm(x)
+def lstm_module():
+ tf_test_utils.set_random_seed()
+ inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
+ outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
+ model = tf.keras.Model(inputs, outputs)
+ module = tf.Module()
+ module.m = model
+ module.predict = tf.function(
+ input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+ model.call)
+ return module
# TODO(silvasean): Get this test working on IREE.
# Needs TensorList with current Keras implementation.
-@tf_test_utils.compile_modules(backends=["tf"], lstm=(Lstm, ["predict"]))
+@tf_test_utils.compile_modules(backends=["tf"], lstm=(lstm_module, ["predict"]))
class LstmTest(tf_test_utils.SavedModelTestCase):
def test_lstm(self):
m = self.modules.lstm.all
- m.predict(tf.constant(0., shape=[NUM_BATCH, NUM_TIMESTEPS,
- NUM_UNITS])).print().assert_all_close()
+ m.predict(
+ tf.constant(
+ np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
+ dtype=np.float32).reshape(
+ [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]),
+ shape=[NUM_BATCH, NUM_TIMESTEPS,
+ NUM_UNITS])).print().assert_all_close()
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras_vision_model_test.py b/integrations/tensorflow/e2e/keras_vision_model_test.py
index 02dd7af..dfa6b3b 100644
--- a/integrations/tensorflow/e2e/keras_vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras_vision_model_test.py
@@ -73,7 +73,7 @@
tf.keras.backend.set_learning_phase(False)
# TODO(ataei): This should move somewhere in SavedModelTestCase, it should
# guarantee test is deterministic.
- tf.random.set_seed(0)
+ tf_test_utils.set_random_seed()
# keras model receives images size as input,
# where batch size is not specified - by default it is dynamic
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index e54e5b4..3d818dd 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -54,7 +54,7 @@
[[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
[13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
result = self.modules.strings.all.strings_to_ids(input_ids)
- result.assert_all_close()
+ result.assert_all_equal()
if __name__ == "__main__":