Remove e2e test case decorator (#3304)
Replaces compiling a module with a decorator on a child of
`TracedModuleTestCase` with calling `tf_test_utils.compile_tf_module`
in the overridden `__init__` method of the child class.
This will make supporting alternative compilation paths easier.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 87a4c8a..fa3df03 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -22,6 +22,7 @@
# ref: reference – for the reference CompiledModule
# tar: target - for one of the target CompiledModules
+import collections
import copy
import glob
import inspect
@@ -465,7 +466,7 @@
f"--input_file={compiled_path}",
f"--driver={self.backend_driver}",
f"--inputs={serialized_inputs}",
- f"--entry_function={entry_function}"
+ f"--entry_function={entry_function}",
]
with open(os.path.join(trace_dir, "flagfile"), "w") as f:
f.writelines(line + "\n" for line in flagfile)
@@ -551,7 +552,11 @@
return self._trace_call(module_attr, method_name=attr)
-def compile_module(
+Modules = collections.namedtuple('Modules',
+ ['ref_module', 'tar_modules', 'artifacts_dir'])
+
+
+def compile_tf_module(
module_class: Type[tf.Module], exported_names: Sequence[str] = ()
) -> Callable[[Any], Any]:
"""CompiledModuleTestCase decorator that compiles a tf.Module.
@@ -565,80 +570,38 @@
exported_names: optional iterable of strings representing which of
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
-
- Returns:
- Class decorator function.
"""
- def decorator(cls):
- """Decorator Function."""
- if not issubclass(cls, TracedModuleTestCase):
- logging.exception(
- "The 'compile_module' decorator must be applied to a "
- "TracedModuleTestCase derived class, which %s is not.", cls)
- cls._module_class = module_class
- cls._exported_names = exported_names
- return cls
+ # Setup the directory for saving compilation artifacts and traces.
+ artifacts_dir = _setup_artifacts_dir(module_class.__name__)
- return decorator
+ # Get the backend information for this test.
+ ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
+ tar_backend_infos = get_target_backends()
+ compile_backend = lambda backend_info: backend_info.compile(
+ module_class, exported_names, artifacts_dir)
-# Will be initialized by TracedModuleTestCase.setUpClass
-# Global variables are used because storing the compiler context on the cls
-# causes cleaning up refcounts to fail, and tf.test.TestCase wipes the variables
-# on the class instance (self.*) before each unittest.
-# TODO(#2900): Move these back to class variables when we figure out issues with
-# refcounting.
-_global_ref_module = None
-_global_tar_modules = None
+ ref_module = compile_backend(ref_backend_info)
+ tar_modules = [
+ compile_backend(backend_info) for backend_info in tar_backend_infos
+ ]
+ return Modules(ref_module, tar_modules, artifacts_dir)
class TracedModuleTestCase(tf.test.TestCase):
"""Compiles a tf.Module to multiple backends to test their correctness."""
- # Will be initialized by the @compile_module decorator.
- _module_class = None
- _exported_names = ()
-
- @classmethod
- def _compile(cls, backend_info: tf_utils.BackendInfo):
- return backend_info.compile(cls._module_class, cls._exported_names,
- cls._artifacts_dir)
-
- @classmethod
- def setUpClass(cls) -> None:
- # Ran before any of the unit tests.
- super().setUpClass()
- if cls._module_class is None:
- raise AttributeError(
- "setUpClass was called but no module was specified. Specify a module "
- "to compile via the @tf_test_utils.compile_module decorator.")
-
- # Setup the directory for saving compilation artifacts and traces.
- cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
-
- # Get the backend information for this test.
- ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
- f"{FLAGS.reference_backend}_ref")
- tar_backend_infos = get_target_backends()
-
- global _global_ref_module
- global _global_tar_modules
- _global_ref_module = cls._compile(ref_backend_info)
- _global_tar_modules = [
- cls._compile(backend_info) for backend_info in tar_backend_infos
- ]
def setUp(self) -> None:
# Runs before each unit test.
super().setUp()
- global _global_ref_module
- global _global_tar_modules
- _global_ref_module.reinitialize()
- for module in _global_tar_modules:
+ self._modules.ref_module.reinitialize()
+ for module in self._modules.tar_modules:
module.reinitialize()
- def compare_backends(self, trace_function: Callable[[TracedModule],
- None]) -> None:
+ def compare_backends(self, trace_function: Callable[[TracedModule], None],
+ modules: Modules) -> None:
"""Run the reference and target backends on trace_function and compare them.
Random seeds for tensorflow, numpy and python are set before each invocation
@@ -648,17 +611,17 @@
trace_function: a function accepting a TracedModule as its argument.
"""
# Create Traces for each backend.
- ref_trace = Trace(_global_ref_module, trace_function)
+ ref_trace = Trace(modules.ref_module, trace_function)
tar_traces = [
- Trace(module, trace_function) for module in _global_tar_modules
+ Trace(module, trace_function) for module in modules.tar_modules
]
# Run the traces through trace_function with their associated modules.
tf_utils.set_random_seed()
- trace_function(TracedModule(_global_ref_module, ref_trace))
+ trace_function(TracedModule(modules.ref_module, ref_trace))
if FLAGS.log_all_traces:
logging.info(ref_trace)
- for module, trace in zip(_global_tar_modules, tar_traces):
+ for module, trace in zip(modules.tar_modules, tar_traces):
tf_utils.set_random_seed()
trace_function(TracedModule(module, trace))
if FLAGS.log_all_traces:
@@ -674,11 +637,11 @@
failed_backend_indices.append(i)
# Save the results to disk before validating.
- ref_trace_dir = _get_trace_dir(self._artifacts_dir, ref_trace)
+ ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
ref_trace.serialize(ref_trace_dir)
for tar_trace in tar_traces:
- tar_trace_dir = _get_trace_dir(self._artifacts_dir, tar_trace)
+ tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
tar_trace.serialize(tar_trace_dir)
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index ff95379..7adb281 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -80,11 +80,15 @@
backend as a source of truth. For example:
```python
-# Compile a `tf.Module` named `SimpleArithmeticModule` into a `CompiledModule`.
-@tf_test_utils.compile_module(SimpleArithmeticModule)
# Inherit from `TracedModuleTestCase`.
class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(SimpleArithmeticTest, self).__init__(methodName)
+ # Compile a `tf.Module` named `SimpleArithmeticModule` into
+ # `CompiledModule`s for each reference and target backend.
+ self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
# Unit test.
def test_simple_mul(self):
@@ -103,7 +107,7 @@
# Calls `simple_mul` once for each backend, recording the inputs and outputs
# to `module` and then comparing them.
- self.compare_backends(simple_mul)
+ self.compare_backends(simple_mul, self._modules)
```
## Test Suites
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index e16436d..4e3c0cd 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Batch norm tests."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -39,9 +40,12 @@
variance_epsilon=1e-4)
-@tf_test_utils.compile_module(BatchNormModule)
class BatchNormTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BatchNormTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(BatchNormModule)
+
def test_batch_norm_inference(self):
def batch_norm_inference(module):
@@ -53,10 +57,15 @@
scale = tf_utils.uniform((16,)) * 1e-3
module.batch_norm_inference(x, mean, variance, offset, scale)
- self.compare_backends(batch_norm_inference)
+ self.compare_backends(batch_norm_inference, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
index 161f6ad..ef9dd10 100644
--- a/integrations/tensorflow/e2e/bool_test.py
+++ b/integrations/tensorflow/e2e/bool_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tests for ops in the tf.math module."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -37,22 +38,25 @@
return tf.math.logical_and(x, y)
-@tf_test_utils.compile_module(MathModule)
class BooleanTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BooleanTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MathModule)
+
def test_constant(self):
def constant(module):
module.constant()
- self.compare_backends(constant)
+ self.compare_backends(constant, self._modules)
def test_greater_than(self):
def greater_than(module):
module.greater_than(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(greater_than)
+ self.compare_backends(greater_than, self._modules)
def test_logical_and(self):
@@ -61,10 +65,15 @@
np.array([True, True, False, False], dtype=np.bool),
np.array([True, False, False, True], dtype=np.bool))
- self.compare_backends(logical_and)
+ self.compare_backends(logical_and, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcast_to_test.py b/integrations/tensorflow/e2e/broadcast_to_test.py
index 6d57d6f..dc53f34 100644
--- a/integrations/tensorflow/e2e/broadcast_to_test.py
+++ b/integrations/tensorflow/e2e/broadcast_to_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
return tf.broadcast_to(x, shape)
-@tf_test_utils.compile_module(BroadcastToModule)
class BroadcastToTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BroadcastToTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(BroadcastToModule)
+
def test_scalar_broadcast_to(self):
def scalar_broadcast_to(module):
@@ -40,10 +44,15 @@
shape = np.array([3, 3], dtype=np.int32)
result = module.scalar_broadcast_to(x, shape)
- self.compare_backends(scalar_broadcast_to)
+ self.compare_backends(scalar_broadcast_to, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index c4f8e38..f72f3b3 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test broadcasting support."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -30,9 +31,12 @@
return lhs + rhs
-@tf_test_utils.compile_module(BroadcastingModule)
class BroadcastingTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BroadcastingTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(BroadcastingModule)
+
def test_add_same_shape(self):
def add_same_shape(module):
@@ -40,7 +44,7 @@
rhs = tf_utils.uniform([4])
module.add(lhs, rhs)
- self.compare_backends(add_same_shape)
+ self.compare_backends(add_same_shape, self._modules)
def test_add_broadcast_lhs(self):
@@ -49,7 +53,7 @@
rhs = tf_utils.uniform([4])
module.add(lhs, rhs)
- self.compare_backends(add_broadcast_lhs)
+ self.compare_backends(add_broadcast_lhs, self._modules)
def test_add_broadcast_rhs(self):
@@ -58,10 +62,15 @@
rhs = tf_utils.uniform([1])
module.add(lhs, rhs)
- self.compare_backends(add_broadcast_rhs)
+ self.compare_backends(add_broadcast_rhs, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
index 102396b..4832146 100644
--- a/integrations/tensorflow/e2e/complex_test.py
+++ b/integrations/tensorflow/e2e/complex_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -32,9 +33,12 @@
return tf.math.real(exp)
-@tf_test_utils.compile_module(ComplexModule)
class ComplexTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ComplexTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ComplexModule)
+
def test_complex(self):
def complex_exp(module):
@@ -42,10 +46,15 @@
imag = np.array([-1., 0.4], dtype=np.float32)
module.complex_exp(real, imag)
- self.compare_backends(complex_exp)
+ self.compare_backends(complex_exp, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index 187c6cb..dc0fab8 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test concat op."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -51,9 +52,12 @@
return tf.concat([a, b], axis=2)
-@tf_test_utils.compile_module(ConcatOpsModule)
class ConcatOpsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ConcatOpsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ConcatOpsModule)
+
def test_concat_zero_dim(self):
def concat_zero_dim(module):
@@ -61,7 +65,7 @@
b = tf_utils.uniform([1, 5, 1])
module.concat_zero_dim(a, b)
- self.compare_backends(concat_zero_dim)
+ self.compare_backends(concat_zero_dim, self._modules)
def test_concat0axis(self):
@@ -70,7 +74,7 @@
b = tf_utils.uniform([1, 5, 1])
module.concat0axis(a, b)
- self.compare_backends(concat0axis)
+ self.compare_backends(concat0axis, self._modules)
def test_concat1axis(self):
@@ -79,7 +83,7 @@
b = tf_utils.uniform([1, 5, 1])
module.concat1axis(a, b)
- self.compare_backends(concat1axis)
+ self.compare_backends(concat1axis, self._modules)
def test_concat2axis(self):
@@ -88,10 +92,15 @@
b = tf_utils.uniform([1, 5, 1])
module.concat2axis(a, b)
- self.compare_backends(concat2axis)
+ self.compare_backends(concat2axis, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index bc0a328..4ba691a 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -34,16 +35,19 @@
return i
-@tf_test_utils.compile_module(ControlFlowModule)
class ControlFlowTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ControlFlowTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ControlFlowModule)
+
def test_short_sequence(self):
def short_sequence(module):
input_array = np.array(9., dtype=np.float32)
module.collatz(input_array)
- self.compare_backends(short_sequence)
+ self.compare_backends(short_sequence, self._modules)
def test_long_sequence(self):
@@ -51,10 +55,15 @@
input_array = np.array(178., dtype=np.float32)
module.collatz(input_array)
- self.compare_backends(long_sequence)
+ self.compare_backends(long_sequence, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 9346a52..cff4496 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -99,9 +100,12 @@
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
-@tf_test_utils.compile_module(Conv2dModule)
class ConvTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ConvTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(Conv2dModule)
+
def test_id_batch_size_1(self):
def id_batch_size_1(module):
@@ -109,7 +113,7 @@
k = np.ones([1, 1, 1, 1], dtype=np.float32)
module.conv2d_1451x1111_valid(i, k)
- self.compare_backends(id_batch_size_1)
+ self.compare_backends(id_batch_size_1, self._modules)
def test_id_batch_size_2(self):
@@ -118,7 +122,7 @@
k = np.ones([1, 1, 1, 1], dtype=np.float32)
module.conv2d_2451x1111_valid(i, k)
- self.compare_backends(id_batch_size_2)
+ self.compare_backends(id_batch_size_2, self._modules)
def test_asymmetric_kernel(self):
@@ -128,7 +132,7 @@
dtype=np.float32).reshape(2, 3, 1, 1)
module.conv2d_1451x2311_valid(i, k)
- self.compare_backends(asymmetric_kernel)
+ self.compare_backends(asymmetric_kernel, self._modules)
def test_padding(self):
@@ -138,7 +142,7 @@
dtype=np.float32).reshape(2, 3, 1, 1)
module.conv2d_1451x2311_same(i, k)
- self.compare_backends(padding)
+ self.compare_backends(padding, self._modules)
def test_batched_padding(self):
@@ -148,7 +152,7 @@
dtype=np.float32).reshape(2, 3, 1, 1)
module.conv2d_2451x2311_same(i, k)
- self.compare_backends(batched_padding)
+ self.compare_backends(batched_padding, self._modules)
def test_feature_reduce(self):
@@ -157,7 +161,7 @@
k = np.ones([3, 2, 2, 1], dtype=np.float32)
module.conv2d_1452x3221_same(i, k)
- self.compare_backends(feature_reduce)
+ self.compare_backends(feature_reduce, self._modules)
def test_feature_inflate(self):
@@ -166,7 +170,7 @@
k = tf_utils.ndarange([1, 1, 1, 2])
module.conv2d_1451x1112_same(i, k)
- self.compare_backends(feature_inflate)
+ self.compare_backends(feature_inflate, self._modules)
def test_feature_mix(self):
@@ -175,7 +179,7 @@
k = tf_utils.ndarange([1, 1, 2, 2])
module.conv2d_1452x1122_same(i, k)
- self.compare_backends(feature_mix)
+ self.compare_backends(feature_mix, self._modules)
def test_feature_padded(self):
@@ -184,7 +188,7 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_1452x2223_same(i, k)
- self.compare_backends(feature_padded)
+ self.compare_backends(feature_padded, self._modules)
def test_feature_unpadded(self):
@@ -193,7 +197,7 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_1452x2223_valid(i, k)
- self.compare_backends(feature_unpadded)
+ self.compare_backends(feature_unpadded, self._modules)
def test_batched_feature_unpadded(self):
@@ -202,10 +206,15 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_2452x2223_valid(i, k)
- self.compare_backends(batched_feature_unpadded)
+ self.compare_backends(batched_feature_unpadded, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index e3ba303..c928d5e 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -27,45 +28,58 @@
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_2452x2423_valid(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 1, 1, 1], "VALID", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 1, 1, 1],
+ "VALID",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2423_same(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 1, 1, 1], "SAME", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 1, 1, 1],
+ "SAME",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2423_valid_stride_2(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 2, 2, 1], "VALID", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 2, 2, 1],
+ "VALID",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2423_same_stride_2(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 2, 2, 1], "SAME", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 2, 2, 1],
+ "SAME",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 4], tf.float32),
tf.TensorSpec([2, 4, 4, 1], tf.float32),
])
def conv2d_2453x2441_same_stride_1(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 1, 1, 1], "SAME", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 1, 1, 1],
+ "SAME",
+ name="result")
-@tf_test_utils.compile_module(DepthConv2dModule)
class ConvTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ConvTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(DepthConv2dModule)
+
def test_batched_feature_unpadded(self):
def batched_feature_unpadded(module):
@@ -73,7 +87,7 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_2452x2423_valid(i, k)
- self.compare_backends(batched_feature_unpadded)
+ self.compare_backends(batched_feature_unpadded, self._modules)
def test_batched_feature_unpadded_same(self):
@@ -82,7 +96,7 @@
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2423_same(i, k)
- self.compare_backends(batched_feature_unpadded_same)
+ self.compare_backends(batched_feature_unpadded_same, self._modules)
def test_batched_feature_unpadded_same_stride_2(self):
@@ -91,7 +105,8 @@
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2423_valid_stride_2(i, k)
- self.compare_backends(batched_feature_unpadded_same_stride_2)
+ self.compare_backends(batched_feature_unpadded_same_stride_2,
+ self._modules)
def test_batched_feature_padded_same_stride_2(self):
@@ -100,7 +115,7 @@
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2423_same_stride_2(i, k)
- self.compare_backends(batched_feature_padded_same_stride_2)
+ self.compare_backends(batched_feature_padded_same_stride_2, self._modules)
def test_batched_feature_padded_same_stride_1_output_1(self):
@@ -109,10 +124,16 @@
k = tf_utils.ndarange([2, 4, 4, 1])
module.conv2d_2453x2441_same_stride_1(i, k)
- self.compare_backends(batched_feature_padded_same_stride_1_output_1)
+ self.compare_backends(batched_feature_padded_same_stride_1_output_1,
+ self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 3742c6e..1e5ba08 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -17,6 +17,7 @@
# This uses a relu instead, allowing it to get to the remaining issue
# (unimplemented dynamic dot_general).
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -42,8 +43,8 @@
self.input_dim = input_dim
self.classes = classes
self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
- self.h2_weights = tf.Variable(
- tf.random.normal([hidden_1_dim, hidden_2_dim]))
+ self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+ hidden_2_dim]))
self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -51,8 +52,7 @@
# Compile with dynamic batch dim.
self.predict = tf.function(
- input_signature=[tf.TensorSpec([None, self.input_dim])])(
- self.predict)
+ input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
def mlp(self, x):
layer_1 = tf.nn.relu(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -65,19 +65,28 @@
return tf.nn.softmax(self.mlp(x))
-@tf_test_utils.compile_module(DynamicMlpReluModule, exported_names=["predict"])
class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(DynamicMlpReluTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(DynamicMlpReluModule,
+ exported_names=["predict"])
+
def test_dynamic_batch(self):
def dynamic_batch(module):
x = tf_utils.uniform([3, 28 * 28]) * 1e-3
module.predict(x)
- self.compare_backends(dynamic_batch)
+ self.compare_backends(dynamic_batch, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index ff905a5..a0ecdc5 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -38,8 +39,8 @@
self.input_dim = input_dim
self.classes = classes
self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
- self.h2_weights = tf.Variable(
- tf.random.normal([hidden_1_dim, hidden_2_dim]))
+ self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+ hidden_2_dim]))
self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -47,8 +48,7 @@
# Compile with dynamic batch dim.
self.predict = tf.function(
- input_signature=[tf.TensorSpec([None, self.input_dim])])(
- self.predict)
+ input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
def mlp(self, x):
layer_1 = tf.sigmoid(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -61,19 +61,28 @@
return tf.nn.softmax(self.mlp(x))
-@tf_test_utils.compile_module(DynamicMlpModule, exported_names=["predict"])
class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(DynamicMlpTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(DynamicMlpModule,
+ exported_names=["predict"])
+
def test_dynamic_batch(self):
def dynamic_batch(module):
x = tf_utils.uniform([3, 28 * 28]) * 1e-3
module.predict(x)
- self.compare_backends(dynamic_batch)
+ self.compare_backends(dynamic_batch, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 050eb47..adaeac9 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
return tf.fill(dims, value)
-@tf_test_utils.compile_module(FillModule)
class FillTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(FillTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(FillModule)
+
def test_fill(self):
def fill(module):
@@ -40,10 +44,15 @@
value = np.array(9., dtype=np.float32)
module.fill(dims, value)
- self.compare_backends(fill)
+ self.compare_backends(fill, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
index ff62f3a..1f53406 100644
--- a/integrations/tensorflow/e2e/finite_test.py
+++ b/integrations/tensorflow/e2e/finite_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -24,18 +25,26 @@
return tf.math.is_finite(x)
-@tf_test_utils.compile_module(FiniteModule)
class FiniteTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(FiniteTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(FiniteModule)
+
def test_finite(self):
def finite(module):
module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
- self.compare_backends(finite)
+ self.compare_backends(finite, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index fc0ec13..da71d4b 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -65,9 +66,12 @@
return tf.gather(params, indices, axis=1, batch_dims=1)
-@tf_test_utils.compile_module(GatherModule)
class GatherTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(GatherTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(GatherModule)
+
def test_gather_axis0_scalar(self):
def gather_axis0_scalar(module):
@@ -75,7 +79,7 @@
params = tf_utils.ndarange([4, 8])
module.gather_axis0_scalar(params, indices)
- self.compare_backends(gather_axis0_scalar)
+ self.compare_backends(gather_axis0_scalar, self._modules)
def test_gather_axis0_batch0(self):
@@ -84,7 +88,7 @@
params = tf_utils.ndarange([4, 8])
module.gather_axis0_batch0(params, indices)
- self.compare_backends(gather_axis0_batch0)
+ self.compare_backends(gather_axis0_batch0, self._modules)
def test_gather_axis1_batch0(self):
@@ -93,7 +97,7 @@
params = tf_utils.ndarange([4, 7, 8])
module.gather_axis1_batch0(params, indices)
- self.compare_backends(gather_axis1_batch0)
+ self.compare_backends(gather_axis1_batch0, self._modules)
def test_gather_axis2_batch1(self):
@@ -102,7 +106,7 @@
params = tf_utils.ndarange([4, 7, 8, 2])
module.gather_axis2_batch1(params, indices)
- self.compare_backends(gather_axis2_batch1)
+ self.compare_backends(gather_axis2_batch1, self._modules)
def test_gather_axis1_batch1(self):
@@ -111,7 +115,7 @@
params = tf_utils.ndarange([4, 7, 8, 2])
module.gather_axis1_batch1(params, indices)
- self.compare_backends(gather_axis1_batch1)
+ self.compare_backends(gather_axis1_batch1, self._modules)
def test_gather_axis2_batch2(self):
@@ -120,11 +124,16 @@
values = np.array([[0, 1, 2, 3], [9, 8, 7, 0]], dtype=np.int32)
module.gather_axis2_batch2(values, indices)
- self.compare_backends(gather_axis2_batch2)
+ self.compare_backends(gather_axis2_batch2, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index db8f86e..5cb1827 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -16,6 +16,7 @@
# This test is the same as keras_lstm_test, but all shapes are static.
# This stresses the TensorList lowering more specifically.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -42,19 +43,28 @@
self.m.call)
-@tf_test_utils.compile_module(LstmStaticModule, exported_names=["predict"])
class LstmStaticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LstmStaticTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LstmStaticModule,
+ exported_names=["predict"])
+
def test_lstm(self):
def predict(module):
inputs = tf_utils.ndarange(INPUT_SHAPE)
module.predict(inputs, rtol=1e-5, atol=1e-5)
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 112e32a..747cc6f 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -31,28 +32,36 @@
super(LstmModule, self).__init__()
tf_utils.set_random_seed()
inputs = tf.keras.layers.Input(batch_size=None, shape=DYNAMIC_SHAPE[1:])
- outputs = tf.keras.layers.LSTM(
- units=NUM_UNITS, return_sequences=True)(
- inputs)
+ outputs = tf.keras.layers.LSTM(units=NUM_UNITS,
+ return_sequences=True)(inputs)
self.m = tf.keras.Model(inputs, outputs)
self.predict = tf.function(
- input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(
- self.m.call)
+ input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(self.m.call)
-@tf_test_utils.compile_module(LstmModule, exported_names=["predict"])
class LstmTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LstmTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LstmModule,
+ exported_names=["predict"])
+
def test_lstm(self):
def predict(module):
inputs = tf_utils.ndarange(INPUT_SHAPE)
module.predict(inputs)
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
+
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 75286ad..49a3949 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -75,8 +75,6 @@
return loss_value
-@tf_test_utils.compile_module(
- ModelTrain.CreateModule, exported_names=["train_step"])
class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
def generate_regression_data(self, size=8):
@@ -104,11 +102,12 @@
# Run one iteration of training step.
module.train_step(inputs, targets)
- self.compare_backends(train_step)
+ self.compare_backends(train_step, self._modules)
if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
-
+ tf_test_utils.compile_tf_module(ModelTrain.CreateModule,
+ exported_names=["train_step"])
tf.test.main()
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 10663ec..4e1ca64 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -113,8 +113,9 @@
# an external tf.keras URL.
weights = 'imagenet' if FLAGS.data == 'imagenet' else None
- model = APP_MODELS[FLAGS.model](
- weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
+ model = APP_MODELS[FLAGS.model](weights=weights,
+ include_top=FLAGS.include_top,
+ input_shape=input_shape)
if FLAGS.data == 'cifar10' and FLAGS.url:
model = load_cifar10_weights(model)
@@ -131,19 +132,22 @@
# TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
# Replace input_shape with m.input_shape to make the batch size dynamic.
self.predict = tf.function(
- input_signature=[tf.TensorSpec(get_input_shape())])(
- self.m.predict)
+ input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
-@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
class AppTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(AppTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(VisionModule,
+ exported_names=['predict'])
+
def test_application(self):
def predict(module):
module.predict(tf_utils.uniform(get_input_shape()))
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
def main(argv):
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index b535021..4acc170 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
-class LinSpaceModule(tf.Module):
+class LinspaceModule(tf.Module):
def __init__(self):
pass
@@ -33,9 +34,12 @@
return tf.linspace(start, stop, num)
-@tf_test_utils.compile_module(LinSpaceModule)
class LinspaceTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LinspaceTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LinspaceModule)
+
def test_linspace(self):
def linspace(module):
@@ -43,10 +47,15 @@
stop = np.array(12., dtype=np.float32)
module.linspace(start, stop)
- self.compare_backends(linspace)
+ self.compare_backends(linspace, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
index ea3fd7a..ee83a95 100644
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ b/integrations/tensorflow/e2e/logical_ops_test.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for ops in the tf.math module that specifically handle logical ops."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -46,9 +47,12 @@
return tf.math.logical_not(x)
-@tf_test_utils.compile_module(LogicalOpsModule)
class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LogicalOpsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
+
def test_logical_and(self):
def logical_and(module):
@@ -56,7 +60,7 @@
np.array([1, 1, 0, 0], dtype=np.bool),
np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_and)
+ self.compare_backends(logical_and, self._modules)
def test_logical_or(self):
@@ -65,7 +69,7 @@
np.array([1, 1, 0, 0], dtype=np.bool),
np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_or)
+ self.compare_backends(logical_or, self._modules)
def test_logical_xor(self):
@@ -74,17 +78,22 @@
np.array([1, 1, 0, 0], dtype=np.bool),
np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_xor)
+ self.compare_backends(logical_xor, self._modules)
def test_logical_not(self):
def logical_not(module):
module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_not)
+ self.compare_backends(logical_not, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index b5a3929..51c1509 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -90,9 +91,12 @@
return tf.reshape(in_the_set, shape=[view_pixels, view_pixels])
-@tf_test_utils.compile_module(MandelbrotModule)
class MandelbrotTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MandelbrotTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MandelbrotModule)
+
def test_mandelbrot(self):
def mandelbrot(module):
@@ -101,10 +105,15 @@
# This is a much more detailed view, so more iterations are needed.
module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000)
- self.compare_backends(mandelbrot)
+ self.compare_backends(mandelbrot, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index 48bcbb6..a96193a 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tests for ops in the tf.math module."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -42,46 +43,54 @@
return tf.math.mod(x, 2.0)
-@tf_test_utils.compile_module(MathModule)
class MathTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MathTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MathModule)
+
def test_abs(self):
def abs(module):
module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(abs)
+ self.compare_backends(abs, self._modules)
def test_ceil(self):
def ceil(module):
module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(ceil)
+ self.compare_backends(ceil, self._modules)
def test_cos(self):
def cos(module):
module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(cos)
+ self.compare_backends(cos, self._modules)
def test_log(self):
def log(module):
module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(log)
+ self.compare_backends(log, self._modules)
def test_mod(self):
def mod(module):
module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(mod)
+ self.compare_backends(mod, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
index 9c87b7b..4a3daf7 100644
--- a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test matrix ops."""
+from absl import app
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -43,16 +44,19 @@
return tf.matmul(lhs, rhs)
-@tf_test_utils.compile_module(MatrixOpsDynamicModule)
class MatrixOpsDynamicTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MatrixOpsDynamicTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MatrixOpsDynamicModule)
+
def test_matmul_high_rank_batch(self):
def matmul_high_rank_batch(module):
module.matmul_high_rank_batch(
tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]))
- self.compare_backends(matmul_high_rank_batch)
+ self.compare_backends(matmul_high_rank_batch, self._modules)
def test_matmul_dynamic_matching_batch(self):
@@ -60,7 +64,7 @@
module.matmul_dynamic(
tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]))
- self.compare_backends(matmul_dynamic_matching_batch)
+ self.compare_backends(matmul_dynamic_matching_batch, self._modules)
def test_matmul_dynamic_broadcast_lhs(self):
@@ -68,7 +72,7 @@
module.matmul_dynamic(
tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]))
- self.compare_backends(matmul_dynamic_broadcast_lhs)
+ self.compare_backends(matmul_dynamic_broadcast_lhs, self._modules)
def test_matmul_dynamic_broadcast_rhs(self):
@@ -76,7 +80,7 @@
module.matmul_dynamic(
tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]))
- self.compare_backends(matmul_dynamic_broadcast_rhs)
+ self.compare_backends(matmul_dynamic_broadcast_rhs, self._modules)
def test_matmul_dynamic_rank_broadcasting(self):
@@ -84,10 +88,15 @@
module.matmul_dynamic_lhs_batch(
tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]))
- self.compare_backends(matmul_dynamic_rank_broadcasting)
+ self.compare_backends(matmul_dynamic_rank_broadcasting, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_static_test.py b/integrations/tensorflow/e2e/matrix_ops_static_test.py
index 82fa482..3a7fe47 100644
--- a/integrations/tensorflow/e2e/matrix_ops_static_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_static_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test matrix ops."""
+from absl import app
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -55,16 +56,19 @@
return tf.matmul(lhs, rhs)
-@tf_test_utils.compile_module(MatrixOpsStaticModule)
class MatrixOpsStaticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MatrixOpsStaticTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MatrixOpsStaticModule)
+
def test_basic_matmul(self):
def basic_matmul(module):
module.basic_matmul(tf_utils.uniform([LEFT_DIM, INNER_DIM]),
tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
- self.compare_backends(basic_matmul)
+ self.compare_backends(basic_matmul, self._modules)
def test_matmul_lhs_batch(self):
@@ -73,7 +77,7 @@
tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]),
tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
- self.compare_backends(matmul_lhs_batch)
+ self.compare_backends(matmul_lhs_batch, self._modules)
def test_matmul_rhs_batch(self):
@@ -82,7 +86,7 @@
tf_utils.uniform([LEFT_DIM, INNER_DIM]),
tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
- self.compare_backends(matmul_rhs_batch)
+ self.compare_backends(matmul_rhs_batch, self._modules)
def test_matmul_broadcast_singleton_dimension(self):
@@ -91,10 +95,15 @@
tf_utils.uniform([1, LEFT_DIM, INNER_DIM]),
tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
- self.compare_backends(matmul_broadcast_singleton_dimension)
+ self.compare_backends(matmul_broadcast_singleton_dimension, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index e8d3997..4f8daf5 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -81,10 +81,14 @@
return self.inference_func(**inputs)
-@tf_test_utils.compile_module(MobileBertSquad, exported_names=["predict"])
class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
"""Tests of MobileBertSquad."""
+ def __init__(self, methodName="runTest"):
+ super(MobileBertSquadTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
+ exported_names=["predict"])
+
def test_predict(self):
def predict(module):
@@ -94,11 +98,15 @@
module.predict(input_ids, input_mask, segment_ids, atol=1e0)
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
-
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/range_test.py b/integrations/tensorflow/e2e/range_test.py
index f1e093c..f4e3e97 100644
--- a/integrations/tensorflow/e2e/range_test.py
+++ b/integrations/tensorflow/e2e/range_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -31,9 +32,12 @@
return tf.range(start, stop, delta)
-@tf_test_utils.compile_module(RangeModule)
class RangeTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(RangeTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(RangeModule)
+
def test_range(self):
def range(module):
@@ -42,10 +46,15 @@
delta = np.array(3, dtype=np.float32)
result = module.range(start, stop, delta)
- self.compare_backends(range)
+ self.compare_backends(range, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index dd5ad6d..d387e86 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -28,18 +29,26 @@
return self.counter.assign_add(value)
-@tf_test_utils.compile_module(ResourcesOpsModule)
class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ResourcesOpsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule)
+
def test_add_assign(self):
def add_assign(module):
module.add_assign(np.array(9., dtype=np.float32))
- self.compare_backends(add_assign)
+ self.compare_backends(add_assign, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 8e437ea..c56bf31 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -31,16 +32,20 @@
# buffer has size [buffer_size, dims]
# only the first dimension is used for updating buffer in a ring manner
- self._buffer = tf.Variable(
- tf.zeros((self._buffer_size,) + dims, dtype=dtype),
- trainable=False,
- name="RingBuffer")
+ self._buffer = tf.Variable(tf.zeros((self._buffer_size,) + dims,
+ dtype=dtype),
+ trainable=False,
+ name="RingBuffer")
# Size of the data available for reading
- self._data_size = tf.Variable(
- 0, trainable=False, dtype=tf.int32, name="FramerBuffer/Size")
+ self._data_size = tf.Variable(0,
+ trainable=False,
+ dtype=tf.int32,
+ name="FramerBuffer/Size")
# The index pointing to the head of the data available for reading
- self._read_head = tf.Variable(
- 0, trainable=False, dtype=tf.int32, name="FramerBuffer/Head")
+ self._read_head = tf.Variable(0,
+ trainable=False,
+ dtype=tf.int32,
+ name="FramerBuffer/Head")
@property
def dtype(self):
@@ -82,8 +87,8 @@
start = tf.math.floormod(
self._read_head.read_value() + self._data_size.read_value(),
self._buffer_size)
- indices = tf.math.floormod(
- tf.range(start, limit=start + elements_size), self._buffer_size)
+ indices = tf.math.floormod(tf.range(start, limit=start + elements_size),
+ self._buffer_size)
tf.compat.v1.scatter_update(self._buffer, indices, elements)
@@ -118,8 +123,8 @@
Tensor of elements with shape [length, dims...].
"""
start = self._read_head + offset
- indices = tf.math.floormod(
- tf.range(start, limit=start + length), self._buffer_size)
+ indices = tf.math.floormod(tf.range(start, limit=start + length),
+ self._buffer_size)
result = tf.gather(self._buffer, indices)
if consume:
self.consume(length, offset)
@@ -148,8 +153,9 @@
def build(self, input_shape):
super(StatefulRingBuffer, self).build(input_shape)
buffer_size = self.state_shape[1]
- self.rb = RingBuffer(
- buffer_size=buffer_size, dims=(self.state_shape[2],), dtype=tf.float32)
+ self.rb = RingBuffer(buffer_size=buffer_size,
+ dims=(self.state_shape[2],),
+ dtype=tf.float32)
def call(self, inputs):
self.rb.write(inputs)
@@ -177,10 +183,13 @@
return self.rb(x)
-@tf_test_utils.compile_module(
- StatefulRingBufferModule, exported_names=["predict"])
class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(StatefulRingBufferTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(StatefulRingBufferModule,
+ exported_names=["predict"])
+
def test_stateful_ringbuffer(self):
def stateful_ringbuffer(module):
@@ -198,10 +207,15 @@
module.predict(input3)
# output = np.array([[3.0, 4.0]], dtype=np.float32)
- self.compare_backends(stateful_ringbuffer)
+ self.compare_backends(stateful_ringbuffer, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index 8b43e3a..8f34002 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Test scatter update behavior for tensorflow."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
-"""Test scatter update behavior for tensorflow."""
class ScatterUpdateModule(tf.Module):
@@ -48,9 +49,12 @@
return tf.tensor_scatter_nd_update(tensor, indices, updates)
-@tf_test_utils.compile_module(ScatterUpdateModule)
class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ScatterUpdateTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ScatterUpdateModule)
+
def test_scatter_update_1D(self):
def scatter_update_1D(module):
@@ -59,7 +63,7 @@
updates = np.array([9, 10, 11], dtype=np.int32)
module.scatter_update_1D(tensor, indices, updates)
- self.compare_backends(scatter_update_1D)
+ self.compare_backends(scatter_update_1D, self._modules)
def test_scatter_update_2D(self):
@@ -69,7 +73,7 @@
updates = np.array([2, 5, 8], dtype=np.int32)
module.scatter_update_2D(tensor, indices, updates)
- self.compare_backends(scatter_update_2D)
+ self.compare_backends(scatter_update_2D, self._modules)
def test_scatter_update_2D_slice(self):
@@ -79,10 +83,15 @@
updates = np.array([[2, 3, 4]], dtype=np.int32)
module.scatter_update_2D_slice(tensor, indices, updates)
- self.compare_backends(scatter_update_2D_slice)
+ self.compare_backends(scatter_update_2D_slice, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index aaec578..cc9ca09 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Several baseline e2e simple arithmetic tests."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -37,9 +38,12 @@
return tf.matmul(a, b)
-@tf_test_utils.compile_module(SimpleArithmeticModule)
class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(SimpleArithmeticTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
def test_simple_mul(self):
def simple_mul(module):
@@ -48,7 +52,7 @@
c = module.simple_mul(a, b)
module.simple_mul(a, c)
- self.compare_backends(simple_mul)
+ self.compare_backends(simple_mul, self._modules)
def test_simple_matmul(self):
@@ -58,10 +62,15 @@
b = tf_utils.uniform((3072, 256)) * 1e-3
module.simple_matmul(a, b)
- self.compare_backends(simple_matmul)
+ self.compare_backends(simple_matmul, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 1120a4f..e4140d4 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -33,19 +34,27 @@
return self.counter
-@tf_test_utils.compile_module(SimpleStatefulModule)
class StatefulTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(StatefulTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule)
+
def test_stateful(self):
def get_state(module):
module.inc_by(np.array(1., dtype=np.float32))
module.get_state()
- self.compare_backends(get_state)
+ self.compare_backends(get_state, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index 513aa97..413ffb2 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -75,9 +76,13 @@
return self.sw(x)
-@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
class SlidingWindowTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(SlidingWindowTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SlidingWindowModule,
+ exported_names=["predict"])
+
def test_sliding_window(self):
def sliding_window(module):
@@ -89,10 +94,15 @@
result2 = module.predict(input2)
# output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
- self.compare_backends(sliding_window)
+ self.compare_backends(sliding_window, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index 206b33c..70f6468 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import string
@@ -40,9 +41,12 @@
return tf.strings.reduce_join(wps, 1)
-@tf_test_utils.compile_module(StringsModule)
class StringsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(StringsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(StringsModule)
+
def test_print_ids(self):
def print_ids(module):
@@ -51,7 +55,7 @@
[13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
module.print_ids(input_ids)
- self.compare_backends(print_ids)
+ self.compare_backends(print_ids, self._modules)
def test_strings_to_ids(self):
@@ -61,10 +65,15 @@
[13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
module.strings_to_ids(input_ids)
- self.compare_backends(strings_to_ids)
+ self.compare_backends(strings_to_ids, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 440bd43..62babc5 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -50,8 +51,9 @@
@tf.function(
input_signature=[tf.TensorSpec([STATIC_SIZE, STATIC_SIZE], tf.float32)])
def slice_first_element_with_from_tensor_high_rank(self, t):
- ta = tf.TensorArray(
- dtype=tf.float32, size=STATIC_SIZE, element_shape=[STATIC_SIZE])
+ ta = tf.TensorArray(dtype=tf.float32,
+ size=STATIC_SIZE,
+ element_shape=[STATIC_SIZE])
ta = ta.unstack(t)
return ta.read(0)
@@ -66,23 +68,26 @@
return ta.stack()
-@tf_test_utils.compile_module(TensorListModule)
class TensorListTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(TensorListTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(TensorListModule)
+
def test_identity_through_tensorlist(self):
def identity_through_tensorlist(module):
module.identity_through_tensorlist(np.array(42., dtype=np.float32))
- self.compare_backends(identity_through_tensorlist)
+ self.compare_backends(identity_through_tensorlist, self._modules)
def test_add_through_tensorlist(self):
def add_through_tensorlist(module):
- module.add_through_tensorlist(
- np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+ module.add_through_tensorlist(np.array(42., dtype=np.float32),
+ np.array(43., dtype=np.float32))
- self.compare_backends(add_through_tensorlist)
+ self.compare_backends(add_through_tensorlist, self._modules)
def test_slice_first_element_with_from_tensor(self):
@@ -90,7 +95,7 @@
module.slice_first_element_with_from_tensor(
np.arange(STATIC_SIZE, dtype=np.float32))
- self.compare_backends(slice_first_element_with_from_tensor)
+ self.compare_backends(slice_first_element_with_from_tensor, self._modules)
def test_slice_first_element_with_from_tensor_high_rank(self):
@@ -98,18 +103,24 @@
module.slice_first_element_with_from_tensor_high_rank(
tf_utils.ndarange([STATIC_SIZE, STATIC_SIZE]))
- self.compare_backends(slice_first_element_with_from_tensor_high_rank)
+ self.compare_backends(slice_first_element_with_from_tensor_high_rank,
+ self._modules)
def test_concat_with_tensorlist_stack(self):
def concat_with_tensorlist_stack(module):
- module.concat_with_tensorlist_stack(
- np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+ module.concat_with_tensorlist_stack(np.array(42., dtype=np.float32),
+ np.array(43., dtype=np.float32))
- self.compare_backends(concat_with_tensorlist_stack)
+ self.compare_backends(concat_with_tensorlist_stack, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)