Implement test utils for doing comparisons and benchmarks between TF and IREE.

* Also re-organizes some of the python namespace for better ergonomics.

PiperOrigin-RevId: 279344889
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index a85ac9e..c4880ba 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -62,9 +62,11 @@
     srcs_version = "PY3",
     deps = [
         ":binding",
+        ":compiler",
         "//bindings/python:pathsetup",  # build_cleaner: keep
     ] + select({
         "//iree:enable_tensorflow": [
+            "//bindings/python/pyiree/tf_interop:test_utils",
             "//bindings/python/pyiree/tf_interop:tf_test_driver",
         ],
         "//conditions:default": [
@@ -72,6 +74,16 @@
     }),
 )
 
+py_library(
+    name = "compiler",
+    srcs = ["compiler.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":binding",
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+    ],
+)
+
 cc_library(
     name = "base",
     srcs = [
diff --git a/bindings/python/pyiree/__init__.py b/bindings/python/pyiree/__init__.py
index 90f7e50..46739ee 100644
--- a/bindings/python/pyiree/__init__.py
+++ b/bindings/python/pyiree/__init__.py
@@ -18,19 +18,18 @@
 # pylint: disable=g-import-not-at-top
 # pylint: disable=g-bad-import-order
 
-# Always make the low-level native bindings accessible.
+# Top-level modules that are imported verbatim.
 from . import binding
+from . import compiler
+from .binding import tracing
 
 # Alias public compiler symbols.
+# TODO(laurenzo): Remove these aliases.
 from .binding.compiler import CompilerContext
 from .binding.compiler import CompilerModule
 
-# Alias tracing symbols.
-from .binding import tracing
-
-# Alias symbols from the native tf_interop module.
-if hasattr(binding, "tf_interop"):
-  from .binding.tf_interop import load_saved_model as tf_load_saved_model
+# Alias specific native functions.
+from .binding.vm import create_module_from_blob
 
 ### Load non-native py_library deps here ###
 ### Order matters because these typically have a back-reference on this
@@ -40,5 +39,6 @@
 # in).
 try:
   from .tf_interop import tf_test_driver
+  from .tf_interop import test_utils as tf_test_utils
 except ImportError:
   pass
diff --git a/bindings/python/pyiree/compiler.py b/bindings/python/pyiree/compiler.py
new file mode 100644
index 0000000..f7ce73f
--- /dev/null
+++ b/bindings/python/pyiree/compiler.py
@@ -0,0 +1,98 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""High level compiler API.
+
+This imports parts of the native bindings as appropriate.
+"""
+
+from typing import Collection, Optional, Sequence
+
+from . import binding as _binding
+
+# Native aliases.
+Context = _binding.compiler.CompilerContext
+Module = _binding.compiler.CompilerModule
+
+# Conditionally import TensorFlow interop aliases.
+HAS_TENSORFLOW = hasattr(_binding, "tf_interop")
+
+if HAS_TENSORFLOW:
+  # Pass pipeline that should run to lower a TF saved_model to a form suitable
+  # for input to the IREE compiler.
+  TF_IMPORT_PASS_PIPELINE = (
+      "tf-executor-graph-pruning",
+      "tf-standard-pipeline",
+      "canonicalize",
+      "xla-legalize-tf",
+  )
+
+  def tf_load_saved_model(
+      saved_model_dir: str,
+      compiler_context: Optional[Context] = None,
+      exported_names: Collection[str] = (),
+      pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE) -> Module:
+    """Loads a TensorFlow saved model from its persistent representation.
+
+    See also tf_compile_saved_model() for a one-shot API to load and compile.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      compiler_context: The pyiree.compiler.Context() backing the module.
+      exported_names: Optional tuple of strings representing the exported names
+        to keep.
+      pass_pipeline: Passes to run on the imported module prior to returning.
+        Defaults to TF_IMPORT_PASS_PIPELINE.
+
+    Returns:
+      An MLIR Module suitable for compilation by the IREE compiler.
+      This can be further compiled to an IREE blob by calling
+      .compile_to_sequencer_blob.
+    """
+    if not compiler_context:
+      compiler_context = Context()
+    input_module = _binding.tf_interop.load_saved_model(
+        compiler_context, saved_model_dir, exported_names=exported_names)
+    if pass_pipeline:
+      input_module.run_pass_pipeline(pass_pipeline)
+    return input_module
+
+  def tf_compile_saved_model(
+      saved_model_dir: str,
+      compiler_context: Optional[Context] = None,
+      exported_names: Collection[str] = (),
+      pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+      print_mlir: bool = False,
+      target_backends: Collection[str] = ()
+  ) -> _binding.OpaqueBlob:
+    """Loads and compiles a TensorFlow saved model in one shot.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      compiler_context: The pyiree.compiler.Context() backing the module.
+      exported_names: Optional tuple of strings representing the exported names
+        to keep.
+      pass_pipeline: Passes to run on the imported module prior to returning.
+        Defaults to TF_IMPORT_PASS_PIPELINE.
+      print_mlir: Whether to print intermediate MLIR after each pass.
+      target_backends: The specific target backends to compile for (defaults to
+        all compiled in targets).
+
+    Returns:
+      An OpaqueBlob representing the compiled module.
+    """
+    input_module = tf_load_saved_model(saved_model_dir, compiler_context,
+                                       exported_names, pass_pipeline)
+    return input_module.compile_to_sequencer_blob(
+        print_mlir=print_mlir, target_backends=target_backends)
diff --git a/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py
index 7c8f22d..db989dc 100644
--- a/bindings/python/pyiree/compiler_test.py
+++ b/bindings/python/pyiree/compiler_test.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+# Lint as: python3
 
 from absl.testing import absltest
 import pyiree
@@ -23,12 +21,12 @@
 class CompilerTest(absltest.TestCase):
 
   def testParseError(self):
-    ctx = pyiree.CompilerContext()
+    ctx = pyiree.compiler.Context()
     with self.assertRaisesRegex(ValueError, "custom op 'FOOBAR' is unknown"):
       ctx.parse_asm("""FOOBAR: I SHOULD NOT PARSE""")
 
   def testParseAndCompileToSequencer(self):
-    ctx = pyiree.CompilerContext()
+    ctx = pyiree.compiler.Context()
     input_module = ctx.parse_asm("""
       func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
             attributes { iree.module.export } {
diff --git a/bindings/python/pyiree/tf_interop/BUILD b/bindings/python/pyiree/tf_interop/BUILD
index b2180f2..6094507 100644
--- a/bindings/python/pyiree/tf_interop/BUILD
+++ b/bindings/python/pyiree/tf_interop/BUILD
@@ -122,7 +122,17 @@
 py_library(
     name = "tf_test_driver",
     srcs = ["tf_test_driver.py"],
-    deps = INTREE_TENSORFLOW_PY_DEPS,
+    deps = INTREE_TENSORFLOW_PY_DEPS + [
+        "//bindings/python/pyiree:binding",
+    ],
+)
+
+py_library(
+    name = "test_utils",
+    srcs = ["test_utils.py"],
+    deps = INTREE_TENSORFLOW_PY_DEPS + [
+        "//bindings/python/pyiree:binding",
+    ],
 )
 
 py_test(
diff --git a/bindings/python/pyiree/tf_interop/saved_model_test.py b/bindings/python/pyiree/tf_interop/saved_model_test.py
index 89a9267..e0c4293 100644
--- a/bindings/python/pyiree/tf_interop/saved_model_test.py
+++ b/bindings/python/pyiree/tf_interop/saved_model_test.py
@@ -71,30 +71,7 @@
       tf.saved_model.save(my_module, sm_dir, options=options)
 
       # Load it up.
-      ctx = pyiree.CompilerContext()
-      input_module = pyiree.tf_load_saved_model(ctx, sm_dir)
-      input_asm = input_module.to_asm()
-      print("LOADED ASM:\n", input_asm)
-      # Should have out exported name and have executor islands.
-      self.assertRegex(input_asm,
-                       r"""tf_saved_model.exported_names = \["add"\]""")
-      self.assertRegex(input_asm, r"""tf_executor\.island""")
-
-      # Run the necessary lowering passes. Makes sure that these are linked in.
-      input_module.run_pass_pipeline([
-          "tf-executor-graph-pruning",
-          "tf-standard-pipeline",
-          "canonicalize",
-      ])
-      lowered_asm = input_module.to_asm()
-      print("LOWERED ASM:\n", lowered_asm)
-      # Should have collapsed all executor islands.
-      self.assertNotRegex(lowered_asm, r"""tf_executor\.island""")
-
-      # And legalize to XLA.
-      input_module.run_pass_pipeline([
-          "xla-legalize-tf",
-      ])
+      input_module = pyiree.compiler.tf_load_saved_model(sm_dir)
       xla_asm = input_module.to_asm()
       print("XLA ASM:", xla_asm)
       self.assertRegex(xla_asm, "xla_hlo.tanh")
diff --git a/bindings/python/pyiree/tf_interop/test_utils.py b/bindings/python/pyiree/tf_interop/test_utils.py
new file mode 100644
index 0000000..0d8f81a
--- /dev/null
+++ b/bindings/python/pyiree/tf_interop/test_utils.py
@@ -0,0 +1,215 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test utilities interop with TensorFlow."""
+
+import os
+import tempfile
+import timeit
+
+from .. import binding
+from .. import compiler
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+
+def save_and_compile_tf_module(tf_module):
+  with tempfile.TemporaryDirectory() as sm_path:
+    options = tf.saved_model.SaveOptions(save_debug_info=True)
+    tf.saved_model.save(tf_module, sm_path, options=options)
+    return compiler.tf_compile_saved_model(sm_path)
+
+
+def dump_iree_module(m):
+  print("Loaded module:", m.name)
+  i = 0
+  while True:
+    f = m.lookup_function_by_ordinal(i)
+    if not f:
+      break
+    print("  Export:", f.name, "-> args(", f.signature.argument_count,
+          "), results(", f.signature.result_count, ")")
+    i += 1
+
+
+def get_default_test_backends():
+  backends_env = os.environ.get("IREE_TEST_BACKENDS")
+  if backends_env:
+    return backends_env.split(",")
+  else:
+    return ("tf", "iree.interpreter")
+
+
+class _TfBackend(object):
+  """Backend for running directly on the TF module."""
+
+  def __init__(self, test_case, backend_name, fn_name):
+    self.backend_name = backend_name
+    self.module_f = getattr(test_case.tf_module, fn_name)
+
+  def __call__(self, *args):
+    return self.module_f(*args)
+
+  def postprocess(self, results):
+    # Handle single result (technically ambiguous with return of a tuple).
+    if not isinstance(results, tuple):
+      results = (results,)
+    # TODO(laurenzo): Handle structure mapping, etc.
+    return [r.numpy() for r in results]
+
+
+class _IreeBackend(object):
+  """Backend for running on an IREE driver."""
+
+  def __init__(self, test_case, backend_name, fn_name):
+    self.backend_name = backend_name
+    driver_name = backend_name.split(".")[-1]
+    self.policy = binding.rt.Policy()
+    instance = binding.rt.Instance(driver_name=driver_name)
+    self.context = binding.rt.Context(instance=instance, policy=self.policy)
+    self.context.register_module(test_case.iree_vm_module)
+    self.f = self.context.resolve_function("module." + fn_name)
+
+  def __call__(self, *args):
+    args = [self.context.wrap_for_input(arg) for arg in args]
+    # Invoke the function and wait for completion.
+    inv = self.context.invoke(self.f, self.policy, args)
+    inv.await_ready()
+    # Get results as a numpy array.
+    results = [np.array(r.map(), copy=False) for r in inv.results]
+    return results
+
+  def postprocess(self, results):
+    return results
+
+
+_ALL_BACKENDS = {
+    "tf": _TfBackend,
+    "iree.interpreter": _IreeBackend,
+    "iree.vulkan": _IreeBackend,
+}
+
+
+def _wrap_per_backend_fn(saved_model_test_case, fn_name, iterations=100):
+  """Generates a wrapper function for a backend fn name."""
+
+  def invoke_fn(*args):
+    """Lambda that invokes the function on all backends."""
+
+    backend_names = saved_model_test_case.BACKENDS
+    if not backend_names:
+      backend_names = get_default_test_backends()
+
+    backends = [
+        _ALL_BACKENDS[b](saved_model_test_case, b, fn_name)
+        for b in backend_names
+    ]
+    test_id = saved_model_test_case.id().split(".")[-1]
+
+    per_backend_results = []
+    binding.tracing.enable_thread()
+    for backend in backends:
+      # pylint: disable=cell-var-from-loop
+      print(":INVOKE %s:%s on %s" % (test_id, fn_name, backend.backend_name))
+      event = binding.tracing.ScopedEvent(
+          "%s_%s#%s" % (test_id, fn_name, backend.backend_name))
+
+      def run_iteration():
+        with event:
+          return backend(*args)
+
+      # Run one for correctness.
+      results = backend.postprocess(run_iteration())
+      per_backend_results.append((backend.backend_name, results))
+      # Then time it.
+      backend_time_ms = timeit.timeit(run_iteration, number=iterations) * 1000
+      iteration_time_ms = backend_time_ms / iterations
+      print(":BENCHMARK %s:%s on %s: time=%rms" %
+            (test_id, fn_name, backend.backend_name, iteration_time_ms))
+      # pylint: enable=cell-var-from-loop
+
+    # Verify results.
+    ref_backend_name, ref_results = per_backend_results[0]
+    print(":REF RESULTS %s:%s %s:" % (test_id, fn_name, ref_backend_name),
+          ref_results)
+    for backend_name, results in per_backend_results[1:]:
+      print(":COMPARE %s:%s %s vs %s" %
+            (test_id, fn_name, ref_backend_name, backend_name))
+      print("  :", results)
+      for ref_result, result in zip(ref_results, results):
+        saved_model_test_case.assertAllClose(
+            ref_result,
+            result,
+            msg="Result mismatch %s vs %s" % (ref_backend_name, backend_name))
+
+    return ref_results
+
+  return invoke_fn
+
+
+def per_backend_test(*fn_names):
+  """Wraps a SavedModelTestCase test method to run per backend tests.
+
+  Args:
+    *fn_names: Names of functions to run tests against. These will be converted
+      to python functions that invoke all of the backends and passed to the test
+      case method.
+
+  Returns:
+    A decorated function.
+  """
+
+  def decorator(f):
+
+    def replacement(self):
+      fns = [_wrap_per_backend_fn(self, fn_name) for fn_name in fn_names]
+      f(self, *fns)
+
+    replacement.__name__ = f.__name__
+    return replacement
+
+  return decorator
+
+
+class SavedModelTestCase(tf.test.TestCase):
+  """Tests against a SavedModel.
+
+  Use this by subclassing and then defining a TF_MODULE_CONSTRUCTOR member.
+  """
+
+  TF_MODULE_CONSTRUCTOR = None
+  TRACE_FILE_NAME = None
+  BACKENDS = None
+
+  @classmethod
+  def tearDownClass(cls):
+    trace_file_name = cls.TRACE_FILE_NAME
+    if not trace_file_name:
+      trace_file_name = cls.__name__ + ".wtf-trace"
+    trace_file = os.path.join(tempfile.gettempdir(), trace_file_name)
+    print("Flushing trace file to:", trace_file)
+    binding.tracing.flush(trace_file)
+    print("Flush complete")
+    super().tearDownClass()
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    if cls.TF_MODULE_CONSTRUCTOR is None:
+      raise ValueError("Expected a class level TF_MODULE_CONSTRUCTOR")
+    # Compile the module. We do this once.
+    cls.tf_module = cls.TF_MODULE_CONSTRUCTOR()  # pylint: disable=not-callable
+    cls.iree_blob = save_and_compile_tf_module(cls.tf_module)
+    cls.iree_vm_module = binding.vm.create_module_from_blob(cls.iree_blob)
+    dump_iree_module(cls.iree_vm_module)
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index b9765e0..45a278f 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -16,6 +16,7 @@
     "//iree:build_defs.bzl",
     "INTREE_TENSORFLOW_PY_DEPS",
     "NUMPY_DEPS",
+    "PLATFORM_VULKAN_DEPS",
 )
 
 package(
@@ -27,7 +28,8 @@
     name = "simple_arithmetic_test",
     srcs = ["simple_arithmetic_test.py"],
     python_version = "PY3",
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + PLATFORM_VULKAN_DEPS + [
         "//bindings/python/pyiree",
+        "//iree/hal/vulkan:vulkan_driver_module",
     ],
 )
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
new file mode 100644
index 0000000..cea7df1
--- /dev/null
+++ b/integrations/tensorflow/e2e/README.md
@@ -0,0 +1,46 @@
+# TensorFlow e2e tests
+
+This is a collection of e2e tests that, in various fashions, save a TensorFlow
+model, compile it with IREE, and run/evaluate it on all backends.
+
+## Pre-Requisites
+
+You will need a TensorFlow 2.0+ nightly installed in your Python environment:
+the Python binary in `$PYTHON_BIN` should be able to `import tensorflow` and
+that TensorFlow should be version 2.0+. This can be checked with
+`tensorflow.version.VERSION`.
+
+See [Install TensorFlow with pip](https://www.tensorflow.org/install/pip) for
+instructions.
+
+## Vulkan setup
+
+By default, tests run on TensorFlow and the IREE CPU interpreter, as neither
+needs additional environment setup. If you have your environment set up to use
+IREE with Vulkan (see [the doc](../../../docs/vulkan_and_spirv.md)), then you
+can enable the Vulkan backend as well by setting the environment variable
+`IREE_TEST_BACKENDS=tf,iree.interpreter,iree.vulkan`.
+
+## Running tests
+
+```shell
+# Run all tests with defaults and output on failure.
+bazel test ... --test_output=errors
+
+# Run an individual test interactively.
+bazel test simple_arithmetic_test --test_output=streamed
+
+# Run tests with an altered list of backends.
+bazel test ... --test_env=IREE_TEST_BACKENDS=tf,iree.interpreter,iree.vulkan \
+    --test_output=errors
+```
+
+## Test harnesses
+
+### Simple function tests
+
+See `simple_arithmetic_test.py` for some examples of single-function tests.
+These extend `tf_test_utils.SavedModelTestCase`, set `TF_MODULE_CONSTRUCTOR`,
+and annotate individual test methods with
+`@tf_test_utils.per_backend_test("function_name")` to get a function that runs
+on all backends and compares the results.
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index ae23ece..8759ccc 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -14,16 +14,8 @@
 # limitations under the License.
 """Several baseline e2e simple arithmetic tests."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import tempfile
-import timeit
-
 import numpy as np
-import pyiree
+from pyiree import tf_test_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -44,238 +36,26 @@
     return tf.matmul(a, b)
 
 
-def save_and_load_tf_module(tf_module):
-  with tempfile.TemporaryDirectory() as sm_path:
-    options = tf.saved_model.SaveOptions(save_debug_info=True)
-    tf.saved_model.save(tf_module, sm_path, options=options)
-    ctx = pyiree.CompilerContext()
-    input_module = pyiree.tf_load_saved_model(ctx, sm_path)
-  return input_module
+class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):
 
+  TF_MODULE_CONSTRUCTOR = SimpleArithmeticModule
 
-def dump_iree_module(m):
-  print("Loaded module:", m.name)
-  i = 0
-  while True:
-    f = m.lookup_function_by_ordinal(i)
-    if not f:
-      break
-    print("  Export:", f.name, "-> args(", f.signature.argument_count,
-          "), results(", f.signature.result_count, ")")
-    i += 1
-
-
-class SimpleArithmeticTest(tf.test.TestCase):
-
-  @classmethod
-  def tearDownClass(cls):
-    super().tearDownClass()
-    trace_file = os.path.join(tempfile.gettempdir(),
-                              "simple_arithmetic_test.wtf-trace")
-    print("Flushing trace file to:", trace_file)
-    pyiree.tracing.flush(trace_file)
-    print("Flush complete")
-
-  @classmethod
-  def setUpClass(cls):
-    super().setUpClass()
-    # Compile the module. We do this once.
-    cls.tf_module = SimpleArithmeticModule()
-    cls.mlir_input_module = save_and_load_tf_module(cls.tf_module)
-    print("LOADED ASM:",
-          cls.mlir_input_module.to_asm(debug_info=True, pretty=True))
-
-    # Canonicalize the TF import.
-    cls.mlir_input_module.run_pass_pipeline([
-        "tf-executor-graph-pruning",
-        "tf-standard-pipeline",
-        "canonicalize",
-    ])
-    print("CANONICAL TF ASM:",
-          cls.mlir_input_module.to_asm(debug_info=True, pretty=True))
-
-    # Legalize to XLA (high-level).
-    cls.mlir_input_module.run_pass_pipeline([
-        "xla-legalize-tf",
-    ])
-    print("XLA ASM:",
-          cls.mlir_input_module.to_asm(debug_info=True, pretty=True))
-
-    # Compile the module with IREE.
-    cls.iree_blob = cls.mlir_input_module.compile_to_sequencer_blob(
-        print_mlir=True)
-    cls.iree_vm_module = pyiree.binding.vm.create_module_from_blob(
-        cls.iree_blob)
-    dump_iree_module(cls.iree_vm_module)
-
-  def test_simple_matmul(self):
-    pyiree.tracing.enable_thread()
-    # Initialize the runtime and register the module.
-    # Use the CPU interpreter driver (which has the most implementation done):
-    driver_name = "interpreter"
-
-    # Live on the edge and give the vulkan driver a try:
-    # driver_name = "vulkan"
-
-    policy = pyiree.binding.rt.Policy()
-    instance = pyiree.binding.rt.Instance(driver_name=driver_name)
-    context = pyiree.binding.rt.Context(instance=instance, policy=policy)
-    context.register_module(self.iree_vm_module)
-
-    f = context.resolve_function("module.simple_matmul")
-    tf_f = self.tf_module.simple_matmul
+  @tf_test_utils.per_backend_test("simple_matmul")
+  def test_simple_matmul(self, simple_matmul):
     np.random.seed(12345)
     # Note: scaling by a small value to increase numerical stability.
     a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
     b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
+    simple_matmul(a, b)
 
-    iree_event = pyiree.tracing.ScopedEvent(
-        "SimpleArithmeticTest#simple_matmul")
-
-    def invoke_iree():
-      with iree_event:
-        arg0 = context.wrap_for_input(a)
-        arg1 = context.wrap_for_input(b)
-
-        # Invoke the function and wait for completion.
-        inv = context.invoke(f, policy, [arg0, arg1])
-        inv.await_ready()
-
-        # Get the result as a numpy array and print.
-        results = inv.results
-        result = results[0].map()
-        result_ary = np.array(result, copy=False)
-        return result_ary
-
-    def invoke_tf():
-      arg0 = a
-      arg1 = b
-      result = tf_f(arg0, arg1)
-      return result.numpy()
-
-    # Check that results are equal.
-    self.assertAllClose(invoke_iree(), invoke_tf())
-    # Quick benchmark.
-    iterations = 100
-    print("+++BM simple_matmul:")
-    iree_time = timeit.timeit(invoke_iree, number=iterations)
-    print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
-    tf_time = timeit.timeit(invoke_tf, number=iterations)
-    print("TF   -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
-    tf_vs_iree_factor = tf_time / iree_time
-    print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)
-
-  def test_simple_scalar_mul(self):
-    pyiree.tracing.enable_thread()
-    # Initialize the runtime and register the module.
-    # Use the CPU interpreter driver (which has the most implementation done):
-    driver_name = "interpreter"
-
-    # Live on the edge and give the vulkan driver a try:
-    # driver_name = "vulkan"
-
-    policy = pyiree.binding.rt.Policy()
-    instance = pyiree.binding.rt.Instance(driver_name=driver_name)
-    context = pyiree.binding.rt.Context(instance=instance, policy=policy)
-    context.register_module(self.iree_vm_module)
-
-    f = context.resolve_function("module.simple_mul")
-    tf_f = self.tf_module.simple_mul
+  @tf_test_utils.per_backend_test("simple_mul")
+  def test_simple_scalar_mul(self, simple_mul):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
     b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
-    iree_event = pyiree.tracing.ScopedEvent("SimpleArithmeticTest#simple_mul")
-
-    def invoke_iree():
-      with iree_event:
-        arg0 = context.wrap_for_input(a)
-        arg1 = context.wrap_for_input(b)
-
-        # Invoke the function and wait for completion.
-        inv = context.invoke(f, policy, [arg0, arg1])
-        inv.await_ready()
-
-        # Get the result as a numpy array and print.
-        results = inv.results
-        result = results[0].map()
-        result_ary = np.array(result, copy=False)
-        return result_ary
-
-    def invoke_tf():
-      arg0 = a
-      arg1 = b
-      result = tf_f(arg0, arg1)
-      return result.numpy()
-
-    # Check that results are equal.
-    self.assertAllEqual(invoke_iree(), invoke_tf())
-    # Quick benchmark.
-    iterations = 1000
-    print("+++BM simple_mul:")
-    iree_time = timeit.timeit(invoke_iree, number=iterations)
-    print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
-    tf_time = timeit.timeit(invoke_tf, number=iterations)
-    print("TF   -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
-    tf_vs_iree_factor = tf_time / iree_time
-    print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)
-
-  def test_simple_scalar_mul_streamed(self):
-    pyiree.tracing.enable_thread()
-    # Initialize the runtime and register the module.
-    # Use the CPU interpreter driver (which has the most implementation done):
-    driver_name = "interpreter"
-
-    # Live on the edge and give the vulkan driver a try:
-    # driver_name = "vulkan"
-
-    policy = pyiree.binding.rt.Policy()
-    instance = pyiree.binding.rt.Instance(driver_name=driver_name)
-    context = pyiree.binding.rt.Context(instance=instance, policy=policy)
-    context.register_module(self.iree_vm_module)
-
-    f = context.resolve_function("module.simple_mul")
-    tf_f = self.tf_module.simple_mul
-    a = np.array([1., 2., 3., 4.], dtype=np.float32)
-    b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
-    iree_dispatch_event = pyiree.tracing.ScopedEvent(
-        "SimpleArithmeticTest#simple_mul_dispatch")
-    iree_await_event = pyiree.tracing.ScopedEvent(
-        "SimpleArithmeticTest#simple_mul_await")
-
-    invocations = []
-
-    def invoke_iree():
-      with iree_dispatch_event:
-        arg0 = context.wrap_for_input(a)
-        arg1 = context.wrap_for_input(b)
-
-        # Invoke the function and wait for completion.
-        inv = context.invoke(f, policy, [arg0, arg1])
-        invocations.append(inv)
-
-    def await_all():
-      with iree_await_event:
-        invocations[-1].await_ready()
-
-    def invoke_tf():
-      arg0 = a
-      arg1 = b
-      result = tf_f(arg0, arg1)
-      return result.numpy()
-
-    # Quick benchmark.
-    iterations = 1000
-    print("+++BM simple_mul_streamed:")
-    iree_time = timeit.timeit(invoke_iree, number=iterations)
-    iree_time += timeit.timeit(await_all, number=1)
-    print("IREE -> TIME/ITERATION =", (iree_time / iterations) * 1000, "ms")
-    tf_time = timeit.timeit(invoke_tf, number=iterations)
-    print("TF   -> TIME/ITERATION =", (tf_time / iterations) * 1000, "ms")
-    tf_vs_iree_factor = tf_time / iree_time
-    print("IREE VS TF SPEEDUP FACTOR =", tf_vs_iree_factor)
+    simple_mul(a, b)
 
 
 if __name__ == "__main__":
-  tf.enable_v2_behavior()
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
   tf.test.main()