Test tf vs tf to debug model execution consistency

PiperOrigin-RevId: 307728611
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 0459983..eb0c72e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -29,6 +29,7 @@
 import numpy as np
 from pyiree import rt
 from pyiree.tf import compiler
+import random
 import tensorflow.compat.v2 as tf
 
 flags.DEFINE_string(
@@ -47,6 +48,13 @@
 global_debug_dir = None
 
 
+def set_random_seed(seed=0):
+  """Set random seed for tf, np and random."""
+  tf.random.set_seed(seed)
+  random.seed(seed)
+  np.random.seed(seed)
+
+
 def save_and_compile_tf_module(tf_module, exported_names=(),
                                target_backends=()):
   """Saves and compiles a TensorFlow tf.Module.
@@ -462,6 +470,13 @@
     CompiledModule=TfCompiledModule,
     iree_driver=None,
     iree_compiler_targets=None)
+# tf_also runs the model through TensorFlow a second time so that "tf" can be
+# compared against itself, catching initialization/randomization issues.
+BackendInfo.add(
+    name="tf_also",
+    CompiledModule=TfCompiledModule,
+    iree_driver=None,
+    iree_compiler_targets=None)
 BackendInfo.add(
     name="iree_vmla",
     CompiledModule=IreeCompiledModule,
@@ -583,6 +598,10 @@
           elif backends is None:
             backends = list(BackendInfo.ALL.keys())
           backends = [_resolve(backend) for backend in backends]
+          # if "tf" is specified as a only backend then
+          # we will test it always against "tf" by adding "tf_also".
+          if len(backends) == 1 and "tf" == backends[0].name:
+            backends.append(BackendInfo.ALL["tf_also"])
           available_backends = get_available_backends()
           backends = [
               backend for backend in backends if backend in available_backends
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 12b5a45..c611048 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -38,6 +38,39 @@
     --test_output=errors
 ```
 
+If you specify the same backend multiple times, for example
+`--override_backends=iree_vmla,iree_vmla`, duplicates are grouped, so in this
+example `iree_vmla` runs only once. If you specify `tf,iree_vmla` as backends,
+both are tested and their results are compared against each other. If you
+specify only the `tf` backend, it is additionally tested against itself (by
+adding the internal `tf_also` backend) to catch model
+initialization/randomization issues; this is a special case for debugging. To
+make the unit tests reproducible, we seed `tf`, `numpy` and `random` by
+calling `tf_test_utils.set_random_seed()` before model creation.
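+
+As a minimal sketch of this convention, a test module can seed all three
+sources of randomness in its constructor before creating any variables
+(`SimpleModule` below is a hypothetical module, not part of the test suite):
+
+```python
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class SimpleModule(tf.Module):
+
+  def __init__(self):
+    super().__init__()
+    # Seed tf, numpy and random so that separate runs (e.g. "tf" vs
+    # "tf_also") initialize identical weights.
+    tf_test_utils.set_random_seed()
+    self.w = tf.Variable(tf.random.normal([4, 4]))
+
+  @tf.function(input_signature=[tf.TensorSpec([4, 4], tf.float32)])
+  def matmul(self, x):
+    return tf.matmul(x, self.w)
+```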
+
 ## Debugging tests
 
 If the compiler fails to compile the program, then it will create a crash
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index e7660b0..ec870c1 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -35,6 +35,7 @@
                input_dim=28 * 28,
                classes=10):
     super().__init__()
+    tf_test_utils.set_random_seed()
     self.hidden_1_dim = hidden_1_dim
     self.hidden_2_dim = hidden_2_dim
     self.input_dim = input_dim
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 461ceed..d2dc527 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -31,6 +31,7 @@
                input_dim=28 * 28,
                classes=10):
     super().__init__()
+    tf_test_utils.set_random_seed()
     self.hidden_1_dim = hidden_1_dim
     self.hidden_2_dim = hidden_2_dim
     self.input_dim = input_dim
diff --git a/integrations/tensorflow/e2e/keras_lstm_static_test.py b/integrations/tensorflow/e2e/keras_lstm_static_test.py
index fb894f5..c62629a 100644
--- a/integrations/tensorflow/e2e/keras_lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras_lstm_static_test.py
@@ -16,36 +16,40 @@
 # This test is the same as keras_lstm_test, but all shapes are static.
 # This stresses the TensorList lowering more specifically.
 
+import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
 NUM_UNITS = 10
 NUM_TIMESTEPS = 24
 NUM_BATCH = 7
+INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
 
 
-class Lstm(tf.Module):
-
-  def __init__(self):
-    super(Lstm, self).__init__()
-    self.lstm = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)
-
-  @tf.function(input_signature=[
-      tf.TensorSpec([NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS], tf.float32)
-  ])
-  def predict(self, x):
-    return self.lstm(x)
+def lstm_module():
+  tf_test_utils.set_random_seed()
+  inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
+  outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
+  model = tf.keras.Model(inputs, outputs)
+  module = tf.Module()
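+  # Attach the model to the module so its variables are tracked and saved.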
+  module.m = model
+  module.predict = tf.function(
+      input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+          model.call)
+  return module
 
 
 # TODO(silvasean): Get this test working on other backends.
 @tf_test_utils.compile_modules(
-    backends=["tf", "iree_vmla"], lstm=(Lstm, ["predict"]))
+    backends=["tf", "iree_vmla"], lstm=(lstm_module, ["predict"]))
 class LstmTest(tf_test_utils.SavedModelTestCase):
 
   def test_lstm(self):
     m = self.modules.lstm.all
-    m.predict(tf.constant(0., shape=[NUM_BATCH, NUM_TIMESTEPS,
-                                     NUM_UNITS])).print().assert_all_close()
+    inputs = np.arange(
+        np.prod(INPUT_SHAPE), dtype=np.float32).reshape(INPUT_SHAPE)
+    m.predict(tf.constant(inputs)).print().assert_all_close()
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras_lstm_test.py b/integrations/tensorflow/e2e/keras_lstm_test.py
index 3ae1d1f..475083f 100644
--- a/integrations/tensorflow/e2e/keras_lstm_test.py
+++ b/integrations/tensorflow/e2e/keras_lstm_test.py
@@ -13,35 +13,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
 NUM_UNITS = 10
 NUM_TIMESTEPS = 24
 NUM_BATCH = 7
+INPUT_SHAPE = [None, None, NUM_UNITS]
 
 
-class Lstm(tf.Module):
-
-  def __init__(self):
-    super(Lstm, self).__init__()
-    self.lstm = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)
-
-  @tf.function(
-      input_signature=[tf.TensorSpec([None, None, NUM_UNITS], tf.float32)])
-  def predict(self, x):
-    return self.lstm(x)
+def lstm_module():
+  tf_test_utils.set_random_seed()
+  inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
+  outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
+  model = tf.keras.Model(inputs, outputs)
+  module = tf.Module()
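+  # Attach the model to the module so its variables are tracked and saved.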
+  module.m = model
+  module.predict = tf.function(
+      input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+          model.call)
+  return module
 
 
 # TODO(silvasean): Get this test working on IREE.
 # Needs TensorList with current Keras implementation.
-@tf_test_utils.compile_modules(backends=["tf"], lstm=(Lstm, ["predict"]))
+@tf_test_utils.compile_modules(backends=["tf"], lstm=(lstm_module, ["predict"]))
 class LstmTest(tf_test_utils.SavedModelTestCase):
 
   def test_lstm(self):
     m = self.modules.lstm.all
-    m.predict(tf.constant(0., shape=[NUM_BATCH, NUM_TIMESTEPS,
-                                     NUM_UNITS])).print().assert_all_close()
+    inputs = np.arange(
+        NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
+        dtype=np.float32).reshape([NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS])
+    m.predict(tf.constant(inputs)).print().assert_all_close()
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras_vision_model_test.py b/integrations/tensorflow/e2e/keras_vision_model_test.py
index 02dd7af..dfa6b3b 100644
--- a/integrations/tensorflow/e2e/keras_vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras_vision_model_test.py
@@ -73,7 +73,7 @@
   tf.keras.backend.set_learning_phase(False)
   # TODO(ataei): This should move somewhere in SavedModelTestCase, it should
   # guarantee test is deterministic.
-  tf.random.set_seed(0)
+  tf_test_utils.set_random_seed()
 
   # keras model receives images size as input,
   # where batch size is not specified - by default it is dynamic
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index e54e5b4..3d818dd 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -54,7 +54,7 @@
         [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
          [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
     result = self.modules.strings.all.strings_to_ids(input_ids)
-    result.assert_all_close()
+    result.assert_all_equal()
 
 
 if __name__ == "__main__":