Merge main -> google

* f8b466e2 Bumping tracy version and fixing API usage. (#3328)
* c2bd3b9d Refactor tf bindings compilation API (#3326)
* 9c698256 [vulkan] Create a GUI application to run an IREE module (#3274)
* 61cb3a6e Remove e2e test case decorator (#3304)
* e5953855 Canonicalizes mhlo.dot_general to a rank-3 mhlo.dot_general or mhlo.dot (#3225)
* f7cf2195 Merge google -> main (#3321)
* f25eef3c Update android benchmarking titles to match.. (#3322)

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/3334 from hanhanW:main-to-google 62c8861c803837bfb3b67ac102b6afe4e4c299b0
PiperOrigin-RevId: 335072918

diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 006a608..f18bcd0 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -15,6 +15,6 @@
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
 57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
 1454ee0907ee3e71030d7123595fa5487f23a56d third_party/tensorflow
-864d86e8b6d21449474db5e9313dbff90aa9c24f third_party/tracy
+a9a09ab0940408898fccfdcfe2bb8dc19b50f13c third_party/tracy
 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
 909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/docs/developing_iree/e2e_benchmarking.md b/docs/developing_iree/e2e_benchmarking.md
index 2187536..a712d2f 100644
--- a/docs/developing_iree/e2e_benchmarking.md
+++ b/docs/developing_iree/e2e_benchmarking.md
@@ -247,7 +247,7 @@
 changes. The flagfile can still take care of specifying the input data, driver
 and entry function however.
 
-## 5. Benchmark the model on Android with TFLite
+## 5. Benchmarking TFLite on Android
 
 ### 5.1 Prepare the benchmarking tools
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 982242e..774412d 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -28,14 +28,16 @@
     "OutputFormat",
     # TensorFlow
     "TF_IMPORT_PASS_PIPELINE",
-    "tf_load_saved_model",
-    "tf_load_signature_def_saved_model",
-    "tf_compile_saved_model",
+    "tf_saved_model_to_compiler_module",
+    "tf_signature_def_saved_model_to_compiler_module",
     "tf_module_to_compiler_module",
+    "compile_tf_saved_model",
+    "compile_tf_signature_def_saved_model",
+    "compile_tf_module",
 ]
 
 import tempfile
-from typing import Collection, Optional, Sequence
+from typing import Collection, Optional, Sequence, Set
 
 from . import binding as binding
 import tensorflow as tf
@@ -95,74 +97,14 @@
 )
 
 
-def tf_load_saved_model(saved_model_dir: str,
-                        exported_names: Collection[str] = (),
-                        pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
-                        compiler_context: Optional[Context] = None) -> Module:
-  """Loads a TensorFlow saved model from its persistent representation.
-
-  See also tf_compile_saved_model() for a one-shot API to load and compile.
-
-  Args:
-    saved_model_dir: Directory of the saved model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    pass_pipeline: Passes to run on the imported module prior to returning.
-      Defaults to TF_IMPORT_PASS_PIPELINE.
-    compiler_context: The pyiree.compiler.Context() backing the module.
-
-  Returns:
-    An MLIR Module suitable for compilation by the IREE compiler.
-    This can be further compiled to an IREE blob by calling
-    .compile_to_sequencer_blob.
-  """
-  if not compiler_context:
-    compiler_context = Context()
-  input_module = binding.load_saved_model(
-      compiler_context, saved_model_dir, exported_names=exported_names)
-  if pass_pipeline:
-    input_module.run_pass_pipeline(pass_pipeline)
-  return input_module
-
-
-def tf_load_signature_def_saved_model(
-    saved_model_dir: str,
-    tags: Collection[str] = set(),
-    exported_names: Collection[str] = [],
-    pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
-    compiler_context: Optional[Context] = None) -> Module:
-  """Loads a TensorFlow SignatureDef saved model from persistent representation.
-
-  Args:
-    saved_model_dir: Directory of the saved model.
-    tags: Optional tuple of tags to use when loading the model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    pass_pipeline: Passes to run on the imported module prior to returning.
-      Defaults to TF_IMPORT_PASS_PIPELINE.
-    compiler_context: The pyiree.compiler.Context() backing the module.
-
-  Returns:
-    An MLIR Module suitable for compilation by the IREE compiler.
-    This can be further compiled to an IREE blob by calling
-    .compile_to_sequencer_blob.
-  """
-  if not compiler_context:
-    compiler_context = Context()
-  input_module = binding.load_signature_def_saved_model(
-      compiler_context, saved_model_dir, tags, exported_names=exported_names)
-  if pass_pipeline:
-    input_module.run_pass_pipeline(pass_pipeline)
-  return input_module
-
-
-def tf_compile_saved_model(
+def tf_saved_model_to_compiler_module(
     saved_model_dir: str,
     exported_names: Collection[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
-    target_backends: Collection[str] = (),
-    compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
-  """Loads and compiles a TensorFlow saved model in one shot.
+    compiler_context: Optional[Context] = None) -> Module:
+  """Converts a TensorFlow SavedModel into a MLIR module.
+
+  See also compile_tf_saved_model() for a one-shot API to load and compile.
 
   Args:
     saved_model_dir: Directory of the saved model.
@@ -170,31 +112,64 @@
       keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
+    compiler_context: The pyiree.compiler.Context() backing the module.
+
+  Returns:
+    An MLIR Module suitable for compilation by the IREE compiler.
+    This can be further compiled to an IREE blob by calling
+    .compile_to_sequencer_blob.
+  """
+  if not compiler_context:
+    compiler_context = Context()
+  compiler_module = binding.load_saved_model(compiler_context,
+                                             saved_model_dir,
+                                             exported_names=exported_names)
+  if pass_pipeline:
+    compiler_module.run_pass_pipeline(pass_pipeline)
+  return compiler_module
+
+
+def compile_tf_saved_model(
+    saved_model_dir: str,
+    exported_names: Collection[str] = (),
+    target_backends: Collection[str] = (),
+    pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+    compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
+  """Compiles a TensorFlow SavedModel to IREE in one shot.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    exported_names: Optional tuple of strings representing the exported names to
+      keep.
     target_backends: The specific target backends to compile for (defaults to
       all compiled in targets).
+    pass_pipeline: Passes to run on the imported module prior to returning.
+      Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
 
   Returns:
     An OpaqueBlob representing the compiled module.
   """
-  input_module = tf_load_saved_model(saved_model_dir, exported_names,
-                                     pass_pipeline, compiler_context)
-  return input_module.compile(target_backends=target_backends)
+  compiler_module = tf_saved_model_to_compiler_module(saved_model_dir,
+                                                      exported_names,
+                                                      pass_pipeline,
+                                                      compiler_context)
+  return compiler_module.compile(target_backends=target_backends)
 
 
-def tf_module_to_compiler_module(
-    module: tf.Module,
-    exported_names: Collection[str] = (),
-    sm_path: str = None,
+def tf_signature_def_saved_model_to_compiler_module(
+    saved_model_dir: str,
+    saved_model_tags: Set[str] = set(),
+    exported_names: Collection[str] = [],
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> Module:
-  """Converts a tf.Module into a MLIR module.
+  """Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
 
   Args:
-    module: The tf.Module instance to convert to MLIR
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
     exported_names: Optional tuple of strings representing the exported names to
       keep.
-    sm_path: the path to save the tf.Module to, if any. Defaults to None.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -204,16 +179,110 @@
     This can be further compiled to an IREE blob by calling
     .compile_to_sequencer_blob.
   """
-
-  def _convert(sm_path):
-    options = tf.saved_model.SaveOptions(save_debug_info=True)
-    tf.saved_model.save(module, sm_path, options=options)
-    return tf_load_saved_model(sm_path, exported_names, pass_pipeline,
-                               compiler_context)
-
-  if sm_path is None:
-    with tempfile.TemporaryDirectory() as sm_path:
-      compiler_module = _convert(sm_path)
-  else:
-    compiler_module = _convert(sm_path)
+  if not compiler_context:
+    compiler_context = Context()
+  compiler_module = binding.load_signature_def_saved_model(
+      compiler_context,
+      saved_model_dir,
+      saved_model_tags,
+      exported_names=exported_names)
+  if pass_pipeline:
+    compiler_module.run_pass_pipeline(pass_pipeline)
   return compiler_module
+
+
+def compile_tf_signature_def_saved_model(
+    saved_model_dir: str,
+    saved_model_tags: Set[str] = set(),
+    exported_names: Collection[str] = (),
+    target_backends: Collection[str] = (),
+    pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+    compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
+  """Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    exported_names: Optional tuple of strings representing the exported names to
+      keep.
+    target_backends: The specific target backends to compile for (defaults to
+      all compiled in targets).
+    pass_pipeline: Passes to run on the imported module prior to returning.
+      Defaults to TF_IMPORT_PASS_PIPELINE.
+    compiler_context: The pyiree.compiler.Context() backing the module.
+
+  Returns:
+    An OpaqueBlob representing the compiled module.
+  """
+  compiler_module = tf_signature_def_saved_model_to_compiler_module(
+      saved_model_dir, saved_model_tags, exported_names, pass_pipeline,
+      compiler_context)
+  return compiler_module.compile(target_backends=target_backends)
+
+
+def tf_module_to_compiler_module(
+    module: tf.Module,
+    exported_names: Collection[str] = (),
+    pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+    compiler_context: Optional[Context] = None,
+    saved_model_dir: str = None) -> Module:
+  """Converts a tf.Module instance into a MLIR module.
+
+  Args:
+    module: The tf.Module instance to convert to MLIR
+    exported_names: Optional tuple of strings representing the exported names to
+      keep.
+    pass_pipeline: Passes to run on the imported module prior to returning.
+      Defaults to TF_IMPORT_PASS_PIPELINE.
+    compiler_context: The pyiree.compiler.Context() backing the module.
+    saved_model_dir: Optional path to save the tf.Module to. If not provided,
+      a temporary directory is used and deleted once conversion finishes.
+
+  Returns:
+    An MLIR Module suitable for compilation by the IREE compiler.
+    This can be further compiled to an IREE blob by calling
+    .compile_to_sequencer_blob.
+  """
+
+  def _convert(saved_model_dir):
+    options = tf.saved_model.SaveOptions(save_debug_info=True)
+    tf.saved_model.save(module, saved_model_dir, options=options)
+    return tf_saved_model_to_compiler_module(saved_model_dir, exported_names,
+                                             pass_pipeline, compiler_context)
+
+  if saved_model_dir is None:
+    with tempfile.TemporaryDirectory() as saved_model_dir:
+      compiler_module = _convert(saved_model_dir)
+  else:
+    compiler_module = _convert(saved_model_dir)
+  return compiler_module
+
+
+def compile_tf_module(module: tf.Module,
+                      exported_names: Collection[str] = (),
+                      target_backends: Collection[str] = (),
+                      pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+                      compiler_context: Optional[Context] = None,
+                      saved_model_dir: str = None):
+  """Compiles a tf.Module to IREE in one shot.
+
+  Args:
+    module: The tf.Module instance to convert to MLIR
+    exported_names: Optional tuple of strings representing the exported names to
+      keep.
+    target_backends: The specific target backends to compile for (defaults to
+      all compiled in targets).
+    pass_pipeline: Passes to run on the imported module prior to returning.
+      Defaults to TF_IMPORT_PASS_PIPELINE.
+    compiler_context: The pyiree.compiler.Context() backing the module.
+    saved_model_dir: Optional path to save the tf.Module to. If not provided,
+      a temporary directory is used and deleted once conversion finishes.
+
+  Returns:
+    An OpaqueBlob representing the compiled module.
+  """
+  compiler_module = tf_module_to_compiler_module(module, exported_names,
+                                                 pass_pipeline,
+                                                 compiler_context,
+                                                 saved_model_dir)
+  return compiler_module.compile(target_backends=target_backends)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
index a7aba66..10e55c3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
@@ -62,7 +62,7 @@
       tf.saved_model.save(my_module, sm_dir, options=options)
 
       # Load it up.
-      input_module = compiler.tf_load_saved_model(sm_dir)
+      input_module = compiler.tf_saved_model_to_compiler_module(sm_dir)
       xla_asm = input_module.to_asm()
       print("XLA ASM:", xla_asm)
       self.assertRegex(xla_asm, "mhlo.tanh")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
index 8a2e1cb..9764b13 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
@@ -55,8 +55,8 @@
               sess, ["bar"], {"baz": sig}, strip_default_attrs=True)
           builder.save()
 
-      module = compiler.tf_load_signature_def_saved_model(
-          sm_dir, tags=set(["bar"]), exported_names=["baz"])
+      module = compiler.tf_signature_def_saved_model_to_compiler_module(
+          sm_dir, saved_model_tags=set(["bar"]), exported_names=["baz"])
 
       module_asm = module.to_asm(large_element_limit=100)
       self.assertRegexpMatches(module_asm, "flow.variable @[^ ]* dense<10>")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 87a4c8a..fa3df03 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -22,6 +22,7 @@
 #   ref: reference – for the reference CompiledModule
 #   tar: target - for one of the target CompiledModules
 
+import collections
 import copy
 import glob
 import inspect
@@ -465,7 +466,7 @@
             f"--input_file={compiled_path}",
             f"--driver={self.backend_driver}",
             f"--inputs={serialized_inputs}",
-            f"--entry_function={entry_function}"
+            f"--entry_function={entry_function}",
         ]
         with open(os.path.join(trace_dir, "flagfile"), "w") as f:
           f.writelines(line + "\n" for line in flagfile)
@@ -551,7 +552,11 @@
       return self._trace_call(module_attr, method_name=attr)
 
 
-def compile_module(
+Modules = collections.namedtuple('Modules',
+                                 ['ref_module', 'tar_modules', 'artifacts_dir'])
+
+
+def compile_tf_module(
     module_class: Type[tf.Module], exported_names: Sequence[str] = ()
 ) -> Callable[[Any], Any]:
   """CompiledModuleTestCase decorator that compiles a tf.Module.
@@ -565,80 +570,38 @@
     exported_names: optional iterable of strings representing which of
       module_class's functions to compile. If exported_names is empty all
       functions will be compiled.
-
-  Returns:
-    Class decorator function.
   """
 
-  def decorator(cls):
-    """Decorator Function."""
-    if not issubclass(cls, TracedModuleTestCase):
-      logging.exception(
-          "The 'compile_module' decorator must be applied to a "
-          "TracedModuleTestCase derived class, which %s is not.", cls)
-    cls._module_class = module_class
-    cls._exported_names = exported_names
-    return cls
+  # Setup the directory for saving compilation artifacts and traces.
+  artifacts_dir = _setup_artifacts_dir(module_class.__name__)
 
-  return decorator
+  # Get the backend information for this test.
+  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+                                          f"{FLAGS.reference_backend}_ref")
+  tar_backend_infos = get_target_backends()
 
+  compile_backend = lambda backend_info: backend_info.compile(
+      module_class, exported_names, artifacts_dir)
 
-# Will be initialized by TracedModuleTestCase.setUpClass
-# Global variables are used because storing the compiler context on the cls
-# causes cleaning up refcounts to fail, and tf.test.TestCase wipes the variables
-# on the class instance (self.*) before each unittest.
-# TODO(#2900): Move these back to class variables when we figure out issues with
-# refcounting.
-_global_ref_module = None
-_global_tar_modules = None
+  ref_module = compile_backend(ref_backend_info)
+  tar_modules = [
+      compile_backend(backend_info) for backend_info in tar_backend_infos
+  ]
+  return Modules(ref_module, tar_modules, artifacts_dir)
 
 
 class TracedModuleTestCase(tf.test.TestCase):
   """Compiles a tf.Module to multiple backends to test their correctness."""
-  # Will be initialized by the @compile_module decorator.
-  _module_class = None
-  _exported_names = ()
-
-  @classmethod
-  def _compile(cls, backend_info: tf_utils.BackendInfo):
-    return backend_info.compile(cls._module_class, cls._exported_names,
-                                cls._artifacts_dir)
-
-  @classmethod
-  def setUpClass(cls) -> None:
-    # Ran before any of the unit tests.
-    super().setUpClass()
-    if cls._module_class is None:
-      raise AttributeError(
-          "setUpClass was called but no module was specified. Specify a module "
-          "to compile via the @tf_test_utils.compile_module decorator.")
-
-    # Setup the directory for saving compilation artifacts and traces.
-    cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
-
-    # Get the backend information for this test.
-    ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
-                                            f"{FLAGS.reference_backend}_ref")
-    tar_backend_infos = get_target_backends()
-
-    global _global_ref_module
-    global _global_tar_modules
-    _global_ref_module = cls._compile(ref_backend_info)
-    _global_tar_modules = [
-        cls._compile(backend_info) for backend_info in tar_backend_infos
-    ]
 
   def setUp(self) -> None:
     # Runs before each unit test.
     super().setUp()
-    global _global_ref_module
-    global _global_tar_modules
-    _global_ref_module.reinitialize()
-    for module in _global_tar_modules:
+    self._modules.ref_module.reinitialize()
+    for module in self._modules.tar_modules:
       module.reinitialize()
 
-  def compare_backends(self, trace_function: Callable[[TracedModule],
-                                                      None]) -> None:
+  def compare_backends(self, trace_function: Callable[[TracedModule], None],
+                       modules: Modules) -> None:
     """Run the reference and target backends on trace_function and compare them.
 
     Random seeds for tensorflow, numpy and python are set before each invocation
@@ -648,17 +611,17 @@
       trace_function: a function accepting a TracedModule as its argument.
     """
     # Create Traces for each backend.
-    ref_trace = Trace(_global_ref_module, trace_function)
+    ref_trace = Trace(modules.ref_module, trace_function)
     tar_traces = [
-        Trace(module, trace_function) for module in _global_tar_modules
+        Trace(module, trace_function) for module in modules.tar_modules
     ]
 
     # Run the traces through trace_function with their associated modules.
     tf_utils.set_random_seed()
-    trace_function(TracedModule(_global_ref_module, ref_trace))
+    trace_function(TracedModule(modules.ref_module, ref_trace))
     if FLAGS.log_all_traces:
       logging.info(ref_trace)
-    for module, trace in zip(_global_tar_modules, tar_traces):
+    for module, trace in zip(modules.tar_modules, tar_traces):
       tf_utils.set_random_seed()
       trace_function(TracedModule(module, trace))
       if FLAGS.log_all_traces:
@@ -674,11 +637,11 @@
         failed_backend_indices.append(i)
 
     # Save the results to disk before validating.
-    ref_trace_dir = _get_trace_dir(self._artifacts_dir, ref_trace)
+    ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
     ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
     ref_trace.serialize(ref_trace_dir)
     for tar_trace in tar_traces:
-      tar_trace_dir = _get_trace_dir(self._artifacts_dir, tar_trace)
+      tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
       tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
       tar_trace.serialize(tar_trace_dir)
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 01c6e0f..fc8bab6 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -88,40 +88,47 @@
   return result
 
 
-def backends_to_str(backend_infos: Sequence["BackendInfo"]) -> str:
-  """Creates a normalized string representing the provided backends."""
-  normalized_names = []
-  for backend_info in backend_infos:
-    # Remove unusual characters and ensure names don't end or start in "_".
-    name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
-    normalized_names.append(name.strip("_"))
-  return "__".join(normalized_names)
+def _setup_mlir_crash_reproducer(
+    function: Callable[[Any], Any],
+    artifacts_dir: str,
+    backend_name: str,
+) -> Callable[[Any], Any]:
+  """Wraps `function` so that a MLIR crash reproducer is saved if it crashes.
+
+  Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+
+  Args:
+    function: The callable to decorate.
+    artifacts_dir: The directory to write the reproducer to.
+    backend_name: The name of the backend `function` compiles to.
+
+  Returns:
+    A function with the same API as the passed function.
+  """
+
+  def decorator(*args, **kwargs):
+    # Set up a crash reproducer for debugging.
+    if artifacts_dir is not None:
+      compiler.Context.default_crash_reproducer_path = os.path.join(
+          artifacts_dir, f"reproducer__{backend_name}.mlir")
+    try:
+      results = function(*args, **kwargs)
+    except Exception:  # pylint: disable=broad-except
+      # Disable the crash reproducer (to avoid inadvertently overwriting it).
+      if artifacts_dir is not None:
+        compiler.Context.default_crash_reproducer_path = None
+      raise
+    return results
+
+  return decorator
 
 
-def _get_backends_path(artifact_name: str,
-                       backend_infos: Sequence["BackendInfo"],
-                       artifacts_dir: str) -> str:
-  """Gets the path to save artifact_name under for the specified backend(s)."""
-  backends_string = backends_to_str(backend_infos)
-  # Put the artifact in a directory if there's only one backend.
-  if len(backend_infos) == 1:
-    backend_dir = os.path.join(artifacts_dir, backends_string)
-    os.makedirs(backend_dir, exist_ok=True)
-    return os.path.join(artifacts_dir, backends_string, artifact_name)
-  else:
-    return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
-
-
-def compile_tf_module(
-    tf_module: Type[tf.Module],
-    backend_infos: Sequence["BackendInfo"] = (),
-    exported_names: Sequence[str] = (),
-    artifacts_dir: str = None
+def _incrementally_lower_compiler_module(
+    compiler_module: compiler.Module,
+    backend_info: "BackendInfo",
+    artifacts_dir: str,
 ) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
-  """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
-
-  The artifact this creates is not callable. See IreeCompiledModule for an API
-  that returns a module that can be called without any further steps.
+  """Lowers a MLIR compiler module incrementally and saves its outputs.
 
   If artifacts_dir is provided then the following artifacts will be saved:
     tf_input.mlir:
@@ -129,75 +136,84 @@
     iree_input.mlir:
       The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
     backend_name/compiled.vmfb:
-      A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
-
-  If multiple backends are specified, then instead of saving compiled 'vmfb'
-  under 'backend_name/', it will be saved as follows:
-    - 'compiled__{backends}.vmfb'
-  where 'backends' is a '__' delimited list (e.g. iree_vmla__iree_llvmjit).
+      A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
 
   Args:
-    tf_module: A tf.Module.
-    backend_infos: Iterable of BackendInfo names to compile for.
+    compiler_module: A compiler.Module to lower.
+    backend_info: BackendInfo with the details for lowering compiler_module to
+      IREE.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved. No compilation artifacts will be saved if this is not
+      provided.
+  """
+  if artifacts_dir is not None:
+    os.makedirs(artifacts_dir, exist_ok=True)
+    tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+    logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+    with open(tf_mlir_path, "w") as f:
+      f.write(compiler_module.to_asm())
+
+  # Manually run the passes that tf_module_to_compiler_module usually would.
+  compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+  if artifacts_dir is not None:
+    iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+    logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+    with open(iree_mlir_path, "w") as f:
+      f.write(compiler_module.to_asm())
+
+  compiled_module = compiler_module.compile(
+      target_backends=backend_info.compiler_targets)
+
+  compiled_path = None
+  if artifacts_dir is not None:
+    backend_dir = os.path.join(artifacts_dir, backend_info.name)
+    os.makedirs(backend_dir, exist_ok=True)
+    compiled_path = os.path.join(backend_dir, "compiled.vmfb")
+    logging.info("Saving compiled IREE module to: %s", compiled_path)
+    with open(compiled_path, "wb") as f:
+      f.write(compiled_module)
+  return compiled_module, compiled_path
+
+
+def _incrementally_compile_tf_module(
+    module: Type[tf.Module],
+    backend_info: "BackendInfo",
+    exported_names: Sequence[str] = (),
+    artifacts_dir: str = None
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+  """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+
+  The module blob this creates is not callable. See IreeCompiledModule for an
+  API that returns a module that can be called without any further steps.
+
+  See _incrementally_lower_compiler_module's docstring for details about which
+  artifacts will be saved.
+
+  Args:
+    module: A tf.Module.
+    backend_info: BackendInfo with the details for compiling module to IREE.
     exported_names: Iterable of dotted function names to consider for
       compilation.
     artifacts_dir: An optional string pointing to where compilation artifacts
-      should be saved.
+      should be saved. No compilation artifacts will be saved if this is not
+      provided.
 
   Returns:
     A compiled IREE module blob and the path to the compiled VM FlatBuffer if
     artifacts_dir is provided.
   """
 
-  if artifacts_dir is not None:
-    # Set up a crash reproducer for debugging.
-    backends_string = backends_to_str(backend_infos)
-    compiler.Context.default_crash_reproducer_path = os.path.join(
-        artifacts_dir, f"reproducer__{backends_string}.mlir")
-
-  try:
-    # Convert the tf_module into raw TF input MLIR.
-    compiler_module = compiler.tf_module_to_compiler_module(tf_module,
+  def _compile_module(module, exported_names, backend_info, artifacts_dir):
+    compiler_module = compiler.tf_module_to_compiler_module(module,
                                                             exported_names,
                                                             pass_pipeline=())
+    return _incrementally_lower_compiler_module(compiler_module, backend_info,
+                                                artifacts_dir)
 
-    if artifacts_dir is not None:
-      tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
-      logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
-      with open(tf_mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
-
-    # Now run the passes manually that tf_module_to_compiler_module would
-    # usually do.
-    compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
-    if artifacts_dir is not None:
-      iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
-      logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
-      with open(iree_mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
-
-    target_backends = []
-    for backend_info in backend_infos:
-      target_backends.extend(backend_info.compiler_targets)
-    compiled_module = compiler_module.compile(target_backends=target_backends)
-
-    compiled_path = None
-    if artifacts_dir is not None:
-      compiled_path = _get_backends_path("compiled", backend_infos,
-                                         artifacts_dir)
-      compiled_path = f"{compiled_path}.vmfb"
-      logging.info("Saving compiled IREE module to: %s", compiled_path)
-      with open(compiled_path, "wb") as f:
-        f.write(compiled_module)
-
-  except Exception:  # pylint: disable=broad-except
-    if artifacts_dir is not None:
-      # Disable the crash reproducer (to avoid inadvertently overwriting it).
-      compiler.Context.default_crash_reproducer_path = None
-    raise
-
-  return compiled_module, compiled_path
+  _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+                                                 backend_info.name)
+  return _compile_module(module, exported_names, backend_info, artifacts_dir)
 
 
 class CompiledModule(object):
@@ -280,9 +296,9 @@
     super().__init__(module_class, backend_info, exported_names, artifacts_dir)
 
     set_random_seed()
-    self._module_blob, compiled_path = compile_tf_module(
-        tf_module=module_class(),
-        backend_infos=[backend_info],
+    self._module_blob, compiled_path = _incrementally_compile_tf_module(
+        module=module_class(),
+        backend_info=backend_info,
         exported_names=exported_names,
         artifacts_dir=artifacts_dir)
     self._module = rt.VmModule.from_flatbuffer(self._module_blob)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 43ec6cf..1a441d0 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -44,6 +44,7 @@
   def get_count(self):
     return self.count
 
+
 class RandomInitModule(tf.Module):
 
   def __init__(self):
@@ -56,25 +57,15 @@
 
 class UtilsTests(tf.test.TestCase, parameterized.TestCase):
 
-  @parameterized.named_parameters([
-      {
-          'testcase_name': 'single_backend',
-          'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
-      },
-      {
-          'testcase_name':
-              'multiple_backends',
-          'backend_infos': [
-              tf_utils.BackendInfo('iree_vmla'),
-              tf_utils.BackendInfo('iree_llvmjit')
-          ],
-      },
-  ])
-  def test_artifact_saving(self, backend_infos):
+  def test_artifact_saving(self):
+    backend_info = tf_utils.BackendInfo('iree_vmla')
     with tempfile.TemporaryDirectory() as artifacts_dir:
       tf_module = ConstantModule()
-      iree_compiled_module, compiled_path = tf_utils.compile_tf_module(
-          tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
+      iree_compiled_module, compiled_path = (
+          tf_utils._incrementally_compile_tf_module(
+              tf_module,
+              backend_info=backend_info,
+              artifacts_dir=artifacts_dir))
 
       artifacts_to_check = [
           'tf_input.mlir',
@@ -121,7 +112,6 @@
     inputs = [np.array([1, 2], dtype=np.float32)]
     self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
 
-
   @parameterized.named_parameters([
       {
           'testcase_name': 'tensorflow',
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index ff95379..7adb281 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -80,11 +80,15 @@
 backend as a source of truth. For example:
 
 ```python
-# Compile a `tf.Module` named `SimpleArithmeticModule` into a `CompiledModule`.
-@tf_test_utils.compile_module(SimpleArithmeticModule)
 # Inherit from `TracedModuleTestCase`.
 class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(SimpleArithmeticTest, self).__init__(methodName)
+    # Compile a `tf.Module` named `SimpleArithmeticModule` into
+    # `CompiledModule`s for each reference and target backend.
+    self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
   # Unit test.
   def test_simple_mul(self):
 
@@ -103,7 +107,7 @@
 
     # Calls `simple_mul` once for each backend, recording the inputs and outputs
     # to `module` and then comparing them.
-    self.compare_backends(simple_mul)
+    self.compare_backends(simple_mul, self._modules)
 ```
 
 ## Test Suites
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index e16436d..4e3c0cd 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Batch norm tests."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -39,9 +40,12 @@
         variance_epsilon=1e-4)
 
 
-@tf_test_utils.compile_module(BatchNormModule)
 class BatchNormTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BatchNormTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(BatchNormModule)
+
   def test_batch_norm_inference(self):
 
     def batch_norm_inference(module):
@@ -53,10 +57,15 @@
       scale = tf_utils.uniform((16,)) * 1e-3
       module.batch_norm_inference(x, mean, variance, offset, scale)
 
-    self.compare_backends(batch_norm_inference)
+    self.compare_backends(batch_norm_inference, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
index 161f6ad..ef9dd10 100644
--- a/integrations/tensorflow/e2e/bool_test.py
+++ b/integrations/tensorflow/e2e/bool_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tests for ops in the tf.math module."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -37,22 +38,25 @@
     return tf.math.logical_and(x, y)
 
 
-@tf_test_utils.compile_module(MathModule)
 class BooleanTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BooleanTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MathModule)
+
   def test_constant(self):
 
     def constant(module):
       module.constant()
 
-    self.compare_backends(constant)
+    self.compare_backends(constant, self._modules)
 
   def test_greater_than(self):
 
     def greater_than(module):
       module.greater_than(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
 
-    self.compare_backends(greater_than)
+    self.compare_backends(greater_than, self._modules)
 
   def test_logical_and(self):
 
@@ -61,10 +65,15 @@
           np.array([True, True, False, False], dtype=np.bool),
           np.array([True, False, False, True], dtype=np.bool))
 
-    self.compare_backends(logical_and)
+    self.compare_backends(logical_and, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcast_to_test.py b/integrations/tensorflow/e2e/broadcast_to_test.py
index 6d57d6f..dc53f34 100644
--- a/integrations/tensorflow/e2e/broadcast_to_test.py
+++ b/integrations/tensorflow/e2e/broadcast_to_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
     return tf.broadcast_to(x, shape)
 
 
-@tf_test_utils.compile_module(BroadcastToModule)
 class BroadcastToTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BroadcastToTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(BroadcastToModule)
+
   def test_scalar_broadcast_to(self):
 
     def scalar_broadcast_to(module):
@@ -40,10 +44,15 @@
       shape = np.array([3, 3], dtype=np.int32)
       result = module.scalar_broadcast_to(x, shape)
 
-    self.compare_backends(scalar_broadcast_to)
+    self.compare_backends(scalar_broadcast_to, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index c4f8e38..f72f3b3 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test broadcasting support."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -30,9 +31,12 @@
     return lhs + rhs
 
 
-@tf_test_utils.compile_module(BroadcastingModule)
 class BroadcastingTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(BroadcastingTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(BroadcastingModule)
+
   def test_add_same_shape(self):
 
     def add_same_shape(module):
@@ -40,7 +44,7 @@
       rhs = tf_utils.uniform([4])
       module.add(lhs, rhs)
 
-    self.compare_backends(add_same_shape)
+    self.compare_backends(add_same_shape, self._modules)
 
   def test_add_broadcast_lhs(self):
 
@@ -49,7 +53,7 @@
       rhs = tf_utils.uniform([4])
       module.add(lhs, rhs)
 
-    self.compare_backends(add_broadcast_lhs)
+    self.compare_backends(add_broadcast_lhs, self._modules)
 
   def test_add_broadcast_rhs(self):
 
@@ -58,10 +62,15 @@
       rhs = tf_utils.uniform([1])
       module.add(lhs, rhs)
 
-    self.compare_backends(add_broadcast_rhs)
+    self.compare_backends(add_broadcast_rhs, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
index 102396b..4832146 100644
--- a/integrations/tensorflow/e2e/complex_test.py
+++ b/integrations/tensorflow/e2e/complex_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -32,9 +33,12 @@
     return tf.math.real(exp)
 
 
-@tf_test_utils.compile_module(ComplexModule)
 class ComplexTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ComplexTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ComplexModule)
+
   def test_complex(self):
 
     def complex_exp(module):
@@ -42,10 +46,15 @@
       imag = np.array([-1., 0.4], dtype=np.float32)
       module.complex_exp(real, imag)
 
-    self.compare_backends(complex_exp)
+    self.compare_backends(complex_exp, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index 187c6cb..dc0fab8 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test concat op."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -51,9 +52,12 @@
     return tf.concat([a, b], axis=2)
 
 
-@tf_test_utils.compile_module(ConcatOpsModule)
 class ConcatOpsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ConcatOpsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ConcatOpsModule)
+
   def test_concat_zero_dim(self):
 
     def concat_zero_dim(module):
@@ -61,7 +65,7 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat_zero_dim(a, b)
 
-    self.compare_backends(concat_zero_dim)
+    self.compare_backends(concat_zero_dim, self._modules)
 
   def test_concat0axis(self):
 
@@ -70,7 +74,7 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat0axis(a, b)
 
-    self.compare_backends(concat0axis)
+    self.compare_backends(concat0axis, self._modules)
 
   def test_concat1axis(self):
 
@@ -79,7 +83,7 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat1axis(a, b)
 
-    self.compare_backends(concat1axis)
+    self.compare_backends(concat1axis, self._modules)
 
   def test_concat2axis(self):
 
@@ -88,10 +92,15 @@
       b = tf_utils.uniform([1, 5, 1])
       module.concat2axis(a, b)
 
-    self.compare_backends(concat2axis)
+    self.compare_backends(concat2axis, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index bc0a328..4ba691a 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -34,16 +35,19 @@
     return i
 
 
-@tf_test_utils.compile_module(ControlFlowModule)
 class ControlFlowTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ControlFlowTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ControlFlowModule)
+
   def test_short_sequence(self):
 
     def short_sequence(module):
       input_array = np.array(9., dtype=np.float32)
       module.collatz(input_array)
 
-    self.compare_backends(short_sequence)
+    self.compare_backends(short_sequence, self._modules)
 
   def test_long_sequence(self):
 
@@ -51,10 +55,15 @@
       input_array = np.array(178., dtype=np.float32)
       module.collatz(input_array)
 
-    self.compare_backends(long_sequence)
+    self.compare_backends(long_sequence, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 9346a52..cff4496 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -99,9 +100,12 @@
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
 
-@tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ConvTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(Conv2dModule)
+
   def test_id_batch_size_1(self):
 
     def id_batch_size_1(module):
@@ -109,7 +113,7 @@
       k = np.ones([1, 1, 1, 1], dtype=np.float32)
       module.conv2d_1451x1111_valid(i, k)
 
-    self.compare_backends(id_batch_size_1)
+    self.compare_backends(id_batch_size_1, self._modules)
 
   def test_id_batch_size_2(self):
 
@@ -118,7 +122,7 @@
       k = np.ones([1, 1, 1, 1], dtype=np.float32)
       module.conv2d_2451x1111_valid(i, k)
 
-    self.compare_backends(id_batch_size_2)
+    self.compare_backends(id_batch_size_2, self._modules)
 
   def test_asymmetric_kernel(self):
 
@@ -128,7 +132,7 @@
                    dtype=np.float32).reshape(2, 3, 1, 1)
       module.conv2d_1451x2311_valid(i, k)
 
-    self.compare_backends(asymmetric_kernel)
+    self.compare_backends(asymmetric_kernel, self._modules)
 
   def test_padding(self):
 
@@ -138,7 +142,7 @@
                    dtype=np.float32).reshape(2, 3, 1, 1)
       module.conv2d_1451x2311_same(i, k)
 
-    self.compare_backends(padding)
+    self.compare_backends(padding, self._modules)
 
   def test_batched_padding(self):
 
@@ -148,7 +152,7 @@
                    dtype=np.float32).reshape(2, 3, 1, 1)
       module.conv2d_2451x2311_same(i, k)
 
-    self.compare_backends(batched_padding)
+    self.compare_backends(batched_padding, self._modules)
 
   def test_feature_reduce(self):
 
@@ -157,7 +161,7 @@
       k = np.ones([3, 2, 2, 1], dtype=np.float32)
       module.conv2d_1452x3221_same(i, k)
 
-    self.compare_backends(feature_reduce)
+    self.compare_backends(feature_reduce, self._modules)
 
   def test_feature_inflate(self):
 
@@ -166,7 +170,7 @@
       k = tf_utils.ndarange([1, 1, 1, 2])
       module.conv2d_1451x1112_same(i, k)
 
-    self.compare_backends(feature_inflate)
+    self.compare_backends(feature_inflate, self._modules)
 
   def test_feature_mix(self):
 
@@ -175,7 +179,7 @@
       k = tf_utils.ndarange([1, 1, 2, 2])
       module.conv2d_1452x1122_same(i, k)
 
-    self.compare_backends(feature_mix)
+    self.compare_backends(feature_mix, self._modules)
 
   def test_feature_padded(self):
 
@@ -184,7 +188,7 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_1452x2223_same(i, k)
 
-    self.compare_backends(feature_padded)
+    self.compare_backends(feature_padded, self._modules)
 
   def test_feature_unpadded(self):
 
@@ -193,7 +197,7 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_1452x2223_valid(i, k)
 
-    self.compare_backends(feature_unpadded)
+    self.compare_backends(feature_unpadded, self._modules)
 
   def test_batched_feature_unpadded(self):
 
@@ -202,10 +206,15 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_2452x2223_valid(i, k)
 
-    self.compare_backends(batched_feature_unpadded)
+    self.compare_backends(batched_feature_unpadded, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index e3ba303..c928d5e 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -27,45 +28,58 @@
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_valid(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 1, 1, 1], "VALID", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 1, 1, 1],
+                                  "VALID",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_same(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 1, 1, 1], "SAME", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 1, 1, 1],
+                                  "SAME",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_valid_stride_2(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 2, 2, 1], "VALID", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 2, 2, 1],
+                                  "VALID",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2452x2423_same_stride_2(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 2, 2, 1], "SAME", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 2, 2, 1],
+                                  "SAME",
+                                  name="result")
 
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 4], tf.float32),
       tf.TensorSpec([2, 4, 4, 1], tf.float32),
   ])
   def conv2d_2453x2441_same_stride_1(self, img, kernel):
-    return tf.nn.depthwise_conv2d(
-        img, kernel, [1, 1, 1, 1], "SAME", name="result")
+    return tf.nn.depthwise_conv2d(img,
+                                  kernel, [1, 1, 1, 1],
+                                  "SAME",
+                                  name="result")
 
 
-@tf_test_utils.compile_module(DepthConv2dModule)
 class ConvTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ConvTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(DepthConv2dModule)
+
   def test_batched_feature_unpadded(self):
 
     def batched_feature_unpadded(module):
@@ -73,7 +87,7 @@
       k = tf_utils.ndarange([2, 2, 2, 3])
       module.conv2d_2452x2423_valid(i, k)
 
-    self.compare_backends(batched_feature_unpadded)
+    self.compare_backends(batched_feature_unpadded, self._modules)
 
   def test_batched_feature_unpadded_same(self):
 
@@ -82,7 +96,7 @@
       k = tf_utils.ndarange([2, 4, 2, 3])
       module.conv2d_2452x2423_same(i, k)
 
-    self.compare_backends(batched_feature_unpadded_same)
+    self.compare_backends(batched_feature_unpadded_same, self._modules)
 
   def test_batched_feature_unpadded_same_stride_2(self):
 
@@ -91,7 +105,8 @@
       k = tf_utils.ndarange([2, 4, 2, 3])
       module.conv2d_2452x2423_valid_stride_2(i, k)
 
-    self.compare_backends(batched_feature_unpadded_same_stride_2)
+    self.compare_backends(batched_feature_unpadded_same_stride_2,
+                          self._modules)
 
   def test_batched_feature_padded_same_stride_2(self):
 
@@ -100,7 +115,7 @@
       k = tf_utils.ndarange([2, 4, 2, 3])
       module.conv2d_2452x2423_same_stride_2(i, k)
 
-    self.compare_backends(batched_feature_padded_same_stride_2)
+    self.compare_backends(batched_feature_padded_same_stride_2, self._modules)
 
   def test_batched_feature_padded_same_stride_1_output_1(self):
 
@@ -109,10 +124,16 @@
       k = tf_utils.ndarange([2, 4, 4, 1])
       module.conv2d_2453x2441_same_stride_1(i, k)
 
-    self.compare_backends(batched_feature_padded_same_stride_1_output_1)
+    self.compare_backends(batched_feature_padded_same_stride_1_output_1,
+                          self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 3742c6e..1e5ba08 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -17,6 +17,7 @@
 # This uses a relu instead, allowing it to get to the remaining issue
 # (unimplemented dynamic dot_general).
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -42,8 +43,8 @@
     self.input_dim = input_dim
     self.classes = classes
     self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
-    self.h2_weights = tf.Variable(
-        tf.random.normal([hidden_1_dim, hidden_2_dim]))
+    self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+                                                    hidden_2_dim]))
     self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
     self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
     self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -51,8 +52,7 @@
 
     # Compile with dynamic batch dim.
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec([None, self.input_dim])])(
-            self.predict)
+        input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
 
   def mlp(self, x):
     layer_1 = tf.nn.relu(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -65,19 +65,28 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_module(DynamicMlpReluModule, exported_names=["predict"])
 class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(DynamicMlpReluTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(DynamicMlpReluModule,
+                                                    exported_names=["predict"])
+
   def test_dynamic_batch(self):
 
     def dynamic_batch(module):
       x = tf_utils.uniform([3, 28 * 28]) * 1e-3
       module.predict(x)
 
-    self.compare_backends(dynamic_batch)
+    self.compare_backends(dynamic_batch, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index ff905a5..a0ecdc5 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -38,8 +39,8 @@
     self.input_dim = input_dim
     self.classes = classes
     self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
-    self.h2_weights = tf.Variable(
-        tf.random.normal([hidden_1_dim, hidden_2_dim]))
+    self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+                                                    hidden_2_dim]))
     self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
     self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
     self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -47,8 +48,7 @@
 
     # Compile with dynamic batch dim.
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec([None, self.input_dim])])(
-            self.predict)
+        input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
 
   def mlp(self, x):
     layer_1 = tf.sigmoid(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -61,19 +61,28 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_module(DynamicMlpModule, exported_names=["predict"])
 class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(DynamicMlpTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(DynamicMlpModule,
+                                                    exported_names=["predict"])
+
   def test_dynamic_batch(self):
 
     def dynamic_batch(module):
       x = tf_utils.uniform([3, 28 * 28]) * 1e-3
       module.predict(x)
 
-    self.compare_backends(dynamic_batch)
+    self.compare_backends(dynamic_batch, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 050eb47..adaeac9 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
     return tf.fill(dims, value)
 
 
-@tf_test_utils.compile_module(FillModule)
 class FillTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(FillTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(FillModule)
+
   def test_fill(self):
 
     def fill(module):
@@ -40,10 +44,15 @@
       value = np.array(9., dtype=np.float32)
       module.fill(dims, value)
 
-    self.compare_backends(fill)
+    self.compare_backends(fill, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
index ff62f3a..1f53406 100644
--- a/integrations/tensorflow/e2e/finite_test.py
+++ b/integrations/tensorflow/e2e/finite_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -24,18 +25,26 @@
     return tf.math.is_finite(x)
 
 
-@tf_test_utils.compile_module(FiniteModule)
 class FiniteTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(FiniteTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(FiniteModule)
+
   def test_finite(self):
 
     def finite(module):
       module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
 
-    self.compare_backends(finite)
+    self.compare_backends(finite, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index fc0ec13..da71d4b 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -65,9 +66,12 @@
     return tf.gather(params, indices, axis=1, batch_dims=1)
 
 
-@tf_test_utils.compile_module(GatherModule)
 class GatherTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(GatherTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(GatherModule)
+
   def test_gather_axis0_scalar(self):
 
     def gather_axis0_scalar(module):
@@ -75,7 +79,7 @@
       params = tf_utils.ndarange([4, 8])
       module.gather_axis0_scalar(params, indices)
 
-    self.compare_backends(gather_axis0_scalar)
+    self.compare_backends(gather_axis0_scalar, self._modules)
 
   def test_gather_axis0_batch0(self):
 
@@ -84,7 +88,7 @@
       params = tf_utils.ndarange([4, 8])
       module.gather_axis0_batch0(params, indices)
 
-    self.compare_backends(gather_axis0_batch0)
+    self.compare_backends(gather_axis0_batch0, self._modules)
 
   def test_gather_axis1_batch0(self):
 
@@ -93,7 +97,7 @@
       params = tf_utils.ndarange([4, 7, 8])
       module.gather_axis1_batch0(params, indices)
 
-    self.compare_backends(gather_axis1_batch0)
+    self.compare_backends(gather_axis1_batch0, self._modules)
 
   def test_gather_axis2_batch1(self):
 
@@ -102,7 +106,7 @@
       params = tf_utils.ndarange([4, 7, 8, 2])
       module.gather_axis2_batch1(params, indices)
 
-    self.compare_backends(gather_axis2_batch1)
+    self.compare_backends(gather_axis2_batch1, self._modules)
 
   def test_gather_axis1_batch1(self):
 
@@ -111,7 +115,7 @@
       params = tf_utils.ndarange([4, 7, 8, 2])
       module.gather_axis1_batch1(params, indices)
 
-    self.compare_backends(gather_axis1_batch1)
+    self.compare_backends(gather_axis1_batch1, self._modules)
 
   def test_gather_axis2_batch2(self):
 
@@ -120,11 +124,16 @@
       values = np.array([[0, 1, 2, 3], [9, 8, 7, 0]], dtype=np.int32)
       module.gather_axis2_batch2(values, indices)
 
-    self.compare_backends(gather_axis2_batch2)
+    self.compare_backends(gather_axis2_batch2, self._modules)
 
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index db8f86e..5cb1827 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -16,6 +16,7 @@
 # This test is the same as keras_lstm_test, but all shapes are static.
 # This stresses the TensorList lowering more specifically.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -42,19 +43,28 @@
             self.m.call)
 
 
-@tf_test_utils.compile_module(LstmStaticModule, exported_names=["predict"])
 class LstmStaticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LstmStaticTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LstmStaticModule,
+                                                    exported_names=["predict"])
+
   def test_lstm(self):
 
     def predict(module):
       inputs = tf_utils.ndarange(INPUT_SHAPE)
       module.predict(inputs, rtol=1e-5, atol=1e-5)
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 112e32a..747cc6f 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -31,28 +32,36 @@
     super(LstmModule, self).__init__()
     tf_utils.set_random_seed()
     inputs = tf.keras.layers.Input(batch_size=None, shape=DYNAMIC_SHAPE[1:])
-    outputs = tf.keras.layers.LSTM(
-        units=NUM_UNITS, return_sequences=True)(
-            inputs)
+    outputs = tf.keras.layers.LSTM(units=NUM_UNITS,
+                                   return_sequences=True)(inputs)
     self.m = tf.keras.Model(inputs, outputs)
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(
-            self.m.call)
+        input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(self.m.call)
 
 
-@tf_test_utils.compile_module(LstmModule, exported_names=["predict"])
 class LstmTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LstmTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LstmModule,
+                                                    exported_names=["predict"])
+
   def test_lstm(self):
 
     def predict(module):
       inputs = tf_utils.ndarange(INPUT_SHAPE)
       module.predict(inputs)
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
+
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 75286ad..49a3949 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -75,8 +75,13 @@
     return loss_value
 
 
-@tf_test_utils.compile_module(
-    ModelTrain.CreateModule, exported_names=["train_step"])
 class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ModelTrainTest, self).__init__(methodName)
+    # Keep the compiled modules on the instance: compare_backends() below is
+    # passed self._modules, so the compile result must not be discarded.
+    self._modules = tf_test_utils.compile_tf_module(
+        ModelTrain.CreateModule, exported_names=["train_step"])
+
   def generate_regression_data(self, size=8):
@@ -104,7 +109,7 @@
       # Run one iteration of training step.
       module.train_step(inputs, targets)
 
-    self.compare_backends(train_step)
+    self.compare_backends(train_step, self._modules)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 10663ec..4e1ca64 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -113,8 +113,9 @@
   # an external tf.keras URL.
   weights = 'imagenet' if FLAGS.data == 'imagenet' else None
 
-  model = APP_MODELS[FLAGS.model](
-      weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
+  model = APP_MODELS[FLAGS.model](weights=weights,
+                                  include_top=FLAGS.include_top,
+                                  input_shape=input_shape)
 
   if FLAGS.data == 'cifar10' and FLAGS.url:
     model = load_cifar10_weights(model)
@@ -131,19 +132,22 @@
     # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
     # Replace input_shape with m.input_shape to make the batch size dynamic.
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec(get_input_shape())])(
-            self.m.predict)
+        input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
 
 
-@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
 class AppTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(AppTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(VisionModule,
+                                                    exported_names=['predict'])
+
   def test_application(self):
 
     def predict(module):
       module.predict(tf_utils.uniform(get_input_shape()))
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
 def main(argv):
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index b535021..4acc170 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
 
-class LinSpaceModule(tf.Module):
+class LinspaceModule(tf.Module):
 
   def __init__(self):
     pass
@@ -33,9 +34,12 @@
     return tf.linspace(start, stop, num)
 
 
-@tf_test_utils.compile_module(LinSpaceModule)
 class LinspaceTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LinspaceTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LinspaceModule)
+
   def test_linspace(self):
 
     def linspace(module):
@@ -43,10 +47,15 @@
       stop = np.array(12., dtype=np.float32)
       module.linspace(start, stop)
 
-    self.compare_backends(linspace)
+    self.compare_backends(linspace, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
index ea3fd7a..ee83a95 100644
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ b/integrations/tensorflow/e2e/logical_ops_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for ops in the tf.math module that specifically handle logical ops."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -46,9 +47,12 @@
     return tf.math.logical_not(x)
 
 
-@tf_test_utils.compile_module(LogicalOpsModule)
 class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(LogicalOpsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
+
   def test_logical_and(self):
 
     def logical_and(module):
@@ -56,7 +60,7 @@
           np.array([1, 1, 0, 0], dtype=np.bool),
           np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_and)
+    self.compare_backends(logical_and, self._modules)
 
   def test_logical_or(self):
 
@@ -65,7 +69,7 @@
           np.array([1, 1, 0, 0], dtype=np.bool),
           np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_or)
+    self.compare_backends(logical_or, self._modules)
 
   def test_logical_xor(self):
 
@@ -74,17 +78,22 @@
           np.array([1, 1, 0, 0], dtype=np.bool),
           np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_xor)
+    self.compare_backends(logical_xor, self._modules)
 
   def test_logical_not(self):
 
     def logical_not(module):
       module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
 
-    self.compare_backends(logical_not)
+    self.compare_backends(logical_not, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index b5a3929..51c1509 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
@@ -90,9 +91,12 @@
     return tf.reshape(in_the_set, shape=[view_pixels, view_pixels])
 
 
-@tf_test_utils.compile_module(MandelbrotModule)
 class MandelbrotTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MandelbrotTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MandelbrotModule)
+
   def test_mandelbrot(self):
 
     def mandelbrot(module):
@@ -101,10 +105,15 @@
       # This is a much more detailed view, so more iterations are needed.
       module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000)
 
-    self.compare_backends(mandelbrot)
+    self.compare_backends(mandelbrot, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index 48bcbb6..a96193a 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tests for ops in the tf.math module."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -42,46 +43,54 @@
     return tf.math.mod(x, 2.0)
 
 
-@tf_test_utils.compile_module(MathModule)
 class MathTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MathTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MathModule)
+
   def test_abs(self):
 
     def abs(module):
       module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
 
-    self.compare_backends(abs)
+    self.compare_backends(abs, self._modules)
 
   def test_ceil(self):
 
     def ceil(module):
       module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
 
-    self.compare_backends(ceil)
+    self.compare_backends(ceil, self._modules)
 
   def test_cos(self):
 
     def cos(module):
       module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
 
-    self.compare_backends(cos)
+    self.compare_backends(cos, self._modules)
 
   def test_log(self):
 
     def log(module):
       module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
 
-    self.compare_backends(log)
+    self.compare_backends(log, self._modules)
 
   def test_mod(self):
 
     def mod(module):
       module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
 
-    self.compare_backends(mod)
+    self.compare_backends(mod, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
index 9c87b7b..4a3daf7 100644
--- a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test matrix ops."""
 
+from absl import app
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
@@ -43,16 +44,19 @@
     return tf.matmul(lhs, rhs)
 
 
-@tf_test_utils.compile_module(MatrixOpsDynamicModule)
 class MatrixOpsDynamicTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MatrixOpsDynamicTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MatrixOpsDynamicModule)
+
   def test_matmul_high_rank_batch(self):
 
     def matmul_high_rank_batch(module):
       module.matmul_high_rank_batch(
           tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]))
 
-    self.compare_backends(matmul_high_rank_batch)
+    self.compare_backends(matmul_high_rank_batch, self._modules)
 
   def test_matmul_dynamic_matching_batch(self):
 
@@ -60,7 +64,7 @@
       module.matmul_dynamic(
           tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]))
 
-    self.compare_backends(matmul_dynamic_matching_batch)
+    self.compare_backends(matmul_dynamic_matching_batch, self._modules)
 
   def test_matmul_dynamic_broadcast_lhs(self):
 
@@ -68,7 +72,7 @@
       module.matmul_dynamic(
           tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]))
 
-    self.compare_backends(matmul_dynamic_broadcast_lhs)
+    self.compare_backends(matmul_dynamic_broadcast_lhs, self._modules)
 
   def test_matmul_dynamic_broadcast_rhs(self):
 
@@ -76,7 +80,7 @@
       module.matmul_dynamic(
           tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]))
 
-    self.compare_backends(matmul_dynamic_broadcast_rhs)
+    self.compare_backends(matmul_dynamic_broadcast_rhs, self._modules)
 
   def test_matmul_dynamic_rank_broadcasting(self):
 
@@ -84,10 +88,15 @@
       module.matmul_dynamic_lhs_batch(
           tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]))
 
-    self.compare_backends(matmul_dynamic_rank_broadcasting)
+    self.compare_backends(matmul_dynamic_rank_broadcasting, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_static_test.py b/integrations/tensorflow/e2e/matrix_ops_static_test.py
index 82fa482..3a7fe47 100644
--- a/integrations/tensorflow/e2e/matrix_ops_static_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_static_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test matrix ops."""
 
+from absl import app
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
@@ -55,16 +56,19 @@
     return tf.matmul(lhs, rhs)
 
 
-@tf_test_utils.compile_module(MatrixOpsStaticModule)
 class MatrixOpsStaticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(MatrixOpsStaticTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MatrixOpsStaticModule)
+
   def test_basic_matmul(self):
 
     def basic_matmul(module):
       module.basic_matmul(tf_utils.uniform([LEFT_DIM, INNER_DIM]),
                           tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(basic_matmul)
+    self.compare_backends(basic_matmul, self._modules)
 
   def test_matmul_lhs_batch(self):
 
@@ -73,7 +77,7 @@
           tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]),
           tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(matmul_lhs_batch)
+    self.compare_backends(matmul_lhs_batch, self._modules)
 
   def test_matmul_rhs_batch(self):
 
@@ -82,7 +86,7 @@
           tf_utils.uniform([LEFT_DIM, INNER_DIM]),
           tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(matmul_rhs_batch)
+    self.compare_backends(matmul_rhs_batch, self._modules)
 
   def test_matmul_broadcast_singleton_dimension(self):
 
@@ -91,10 +95,15 @@
           tf_utils.uniform([1, LEFT_DIM, INNER_DIM]),
           tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
 
-    self.compare_backends(matmul_broadcast_singleton_dimension)
+    self.compare_backends(matmul_broadcast_singleton_dimension, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index e8d3997..4f8daf5 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -81,10 +81,14 @@
     return self.inference_func(**inputs)
 
 
-@tf_test_utils.compile_module(MobileBertSquad, exported_names=["predict"])
 class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
   """Tests of MobileBertSquad."""
 
+  def __init__(self, methodName="runTest"):
+    super(MobileBertSquadTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
+                                                    exported_names=["predict"])
+
   def test_predict(self):
 
     def predict(module):
@@ -94,11 +98,15 @@
 
       module.predict(input_ids, input_mask, segment_ids, atol=1e0)
 
-    self.compare_backends(predict)
+    self.compare_backends(predict, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
-
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/range_test.py b/integrations/tensorflow/e2e/range_test.py
index f1e093c..f4e3e97 100644
--- a/integrations/tensorflow/e2e/range_test.py
+++ b/integrations/tensorflow/e2e/range_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -31,9 +32,12 @@
     return tf.range(start, stop, delta)
 
 
-@tf_test_utils.compile_module(RangeModule)
 class RangeTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(RangeTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(RangeModule)
+
   def test_range(self):
 
     def range(module):
@@ -42,10 +46,15 @@
       delta = np.array(3, dtype=np.float32)
       result = module.range(start, stop, delta)
 
-    self.compare_backends(range)
+    self.compare_backends(range, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index dd5ad6d..d387e86 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -28,18 +29,26 @@
     return self.counter.assign_add(value)
 
 
-@tf_test_utils.compile_module(ResourcesOpsModule)
 class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ResourcesOpsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule)
+
   def test_add_assign(self):
 
     def add_assign(module):
       module.add_assign(np.array(9., dtype=np.float32))
 
-    self.compare_backends(add_assign)
+    self.compare_backends(add_assign, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 8e437ea..c56bf31 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -31,16 +32,20 @@
 
     # buffer has size [buffer_size, dims]
     # only the first dimension is used for updating buffer in a ring manner
-    self._buffer = tf.Variable(
-        tf.zeros((self._buffer_size,) + dims, dtype=dtype),
-        trainable=False,
-        name="RingBuffer")
+    self._buffer = tf.Variable(tf.zeros((self._buffer_size,) + dims,
+                                        dtype=dtype),
+                               trainable=False,
+                               name="RingBuffer")
     # Size of the data available for reading
-    self._data_size = tf.Variable(
-        0, trainable=False, dtype=tf.int32, name="FramerBuffer/Size")
+    self._data_size = tf.Variable(0,
+                                  trainable=False,
+                                  dtype=tf.int32,
+                                  name="FramerBuffer/Size")
     # The index pointing to the head of the data available for reading
-    self._read_head = tf.Variable(
-        0, trainable=False, dtype=tf.int32, name="FramerBuffer/Head")
+    self._read_head = tf.Variable(0,
+                                  trainable=False,
+                                  dtype=tf.int32,
+                                  name="FramerBuffer/Head")
 
   @property
   def dtype(self):
@@ -82,8 +87,8 @@
     start = tf.math.floormod(
         self._read_head.read_value() + self._data_size.read_value(),
         self._buffer_size)
-    indices = tf.math.floormod(
-        tf.range(start, limit=start + elements_size), self._buffer_size)
+    indices = tf.math.floormod(tf.range(start, limit=start + elements_size),
+                               self._buffer_size)
 
     tf.compat.v1.scatter_update(self._buffer, indices, elements)
 
@@ -118,8 +123,8 @@
       Tensor of elements with shape [length, dims...].
     """
     start = self._read_head + offset
-    indices = tf.math.floormod(
-        tf.range(start, limit=start + length), self._buffer_size)
+    indices = tf.math.floormod(tf.range(start, limit=start + length),
+                               self._buffer_size)
     result = tf.gather(self._buffer, indices)
     if consume:
       self.consume(length, offset)
@@ -148,8 +153,9 @@
   def build(self, input_shape):
     super(StatefulRingBuffer, self).build(input_shape)
     buffer_size = self.state_shape[1]
-    self.rb = RingBuffer(
-        buffer_size=buffer_size, dims=(self.state_shape[2],), dtype=tf.float32)
+    self.rb = RingBuffer(buffer_size=buffer_size,
+                         dims=(self.state_shape[2],),
+                         dtype=tf.float32)
 
   def call(self, inputs):
     self.rb.write(inputs)
@@ -177,10 +183,13 @@
     return self.rb(x)
 
 
-@tf_test_utils.compile_module(
-    StatefulRingBufferModule, exported_names=["predict"])
 class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(StatefulRingBufferTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(StatefulRingBufferModule,
+                                                    exported_names=["predict"])
+
   def test_stateful_ringbuffer(self):
 
     def stateful_ringbuffer(module):
@@ -198,10 +207,15 @@
       module.predict(input3)
       # output = np.array([[3.0, 4.0]], dtype=np.float32)
 
-    self.compare_backends(stateful_ringbuffer)
+    self.compare_backends(stateful_ringbuffer, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index 8b43e3a..8f34002 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Test scatter update behavior for tensorflow."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
-"""Test scatter update behavior for tensorflow."""
 
 
 class ScatterUpdateModule(tf.Module):
@@ -48,9 +49,12 @@
     return tf.tensor_scatter_nd_update(tensor, indices, updates)
 
 
-@tf_test_utils.compile_module(ScatterUpdateModule)
 class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(ScatterUpdateTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(ScatterUpdateModule)
+
   def test_scatter_update_1D(self):
 
     def scatter_update_1D(module):
@@ -59,7 +63,7 @@
       updates = np.array([9, 10, 11], dtype=np.int32)
       module.scatter_update_1D(tensor, indices, updates)
 
-    self.compare_backends(scatter_update_1D)
+    self.compare_backends(scatter_update_1D, self._modules)
 
   def test_scatter_update_2D(self):
 
@@ -69,7 +73,7 @@
       updates = np.array([2, 5, 8], dtype=np.int32)
       module.scatter_update_2D(tensor, indices, updates)
 
-    self.compare_backends(scatter_update_2D)
+    self.compare_backends(scatter_update_2D, self._modules)
 
   def test_scatter_update_2D_slice(self):
 
@@ -79,10 +83,15 @@
       updates = np.array([[2, 3, 4]], dtype=np.int32)
       module.scatter_update_2D_slice(tensor, indices, updates)
 
-    self.compare_backends(scatter_update_2D_slice)
+    self.compare_backends(scatter_update_2D_slice, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index aaec578..cc9ca09 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Several baseline e2e simple arithmetic tests."""
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -37,9 +38,12 @@
     return tf.matmul(a, b)
 
 
-@tf_test_utils.compile_module(SimpleArithmeticModule)
 class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(SimpleArithmeticTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
   def test_simple_mul(self):
 
     def simple_mul(module):
@@ -48,7 +52,7 @@
       c = module.simple_mul(a, b)
       module.simple_mul(a, c)
 
-    self.compare_backends(simple_mul)
+    self.compare_backends(simple_mul, self._modules)
 
   def test_simple_matmul(self):
 
@@ -58,10 +62,15 @@
       b = tf_utils.uniform((3072, 256)) * 1e-3
       module.simple_matmul(a, b)
 
-    self.compare_backends(simple_matmul)
+    self.compare_backends(simple_matmul, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 1120a4f..e4140d4 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -33,19 +34,27 @@
     return self.counter
 
 
-@tf_test_utils.compile_module(SimpleStatefulModule)
 class StatefulTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(StatefulTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule)
+
   def test_stateful(self):
 
     def get_state(module):
       module.inc_by(np.array(1., dtype=np.float32))
       module.get_state()
 
-    self.compare_backends(get_state)
+    self.compare_backends(get_state, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index 513aa97..413ffb2 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
@@ -75,9 +76,13 @@
     return self.sw(x)
 
 
-@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
 class SlidingWindowTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(SlidingWindowTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(SlidingWindowModule,
+                                                    exported_names=["predict"])
+
   def test_sliding_window(self):
 
     def sliding_window(module):
@@ -89,10 +94,15 @@
       result2 = module.predict(input2)
       # output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
 
-    self.compare_backends(sliding_window)
+    self.compare_backends(sliding_window, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index 206b33c..70f6468 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 import string
@@ -40,9 +41,12 @@
     return tf.strings.reduce_join(wps, 1)
 
 
-@tf_test_utils.compile_module(StringsModule)
 class StringsTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(StringsTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(StringsModule)
+
   def test_print_ids(self):
 
     def print_ids(module):
@@ -51,7 +55,7 @@
            [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
       module.print_ids(input_ids)
 
-    self.compare_backends(print_ids)
+    self.compare_backends(print_ids, self._modules)
 
   def test_strings_to_ids(self):
 
@@ -61,10 +65,15 @@
            [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
       module.strings_to_ids(input_ids)
 
-    self.compare_backends(strings_to_ids)
+    self.compare_backends(strings_to_ids, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 440bd43..62babc5 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from absl import app
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
@@ -50,8 +51,9 @@
   @tf.function(
       input_signature=[tf.TensorSpec([STATIC_SIZE, STATIC_SIZE], tf.float32)])
   def slice_first_element_with_from_tensor_high_rank(self, t):
-    ta = tf.TensorArray(
-        dtype=tf.float32, size=STATIC_SIZE, element_shape=[STATIC_SIZE])
+    ta = tf.TensorArray(dtype=tf.float32,
+                        size=STATIC_SIZE,
+                        element_shape=[STATIC_SIZE])
     ta = ta.unstack(t)
     return ta.read(0)
 
@@ -66,23 +68,26 @@
     return ta.stack()
 
 
-@tf_test_utils.compile_module(TensorListModule)
 class TensorListTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, methodName="runTest"):
+    super(TensorListTest, self).__init__(methodName)
+    self._modules = tf_test_utils.compile_tf_module(TensorListModule)
+
   def test_identity_through_tensorlist(self):
 
     def identity_through_tensorlist(module):
       module.identity_through_tensorlist(np.array(42., dtype=np.float32))
 
-    self.compare_backends(identity_through_tensorlist)
+    self.compare_backends(identity_through_tensorlist, self._modules)
 
   def test_add_through_tensorlist(self):
 
     def add_through_tensorlist(module):
-      module.add_through_tensorlist(
-          np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+      module.add_through_tensorlist(np.array(42., dtype=np.float32),
+                                    np.array(43., dtype=np.float32))
 
-    self.compare_backends(add_through_tensorlist)
+    self.compare_backends(add_through_tensorlist, self._modules)
 
   def test_slice_first_element_with_from_tensor(self):
 
@@ -90,7 +95,7 @@
       module.slice_first_element_with_from_tensor(
           np.arange(STATIC_SIZE, dtype=np.float32))
 
-    self.compare_backends(slice_first_element_with_from_tensor)
+    self.compare_backends(slice_first_element_with_from_tensor, self._modules)
 
   def test_slice_first_element_with_from_tensor_high_rank(self):
 
@@ -98,18 +103,24 @@
       module.slice_first_element_with_from_tensor_high_rank(
           tf_utils.ndarange([STATIC_SIZE, STATIC_SIZE]))
 
-    self.compare_backends(slice_first_element_with_from_tensor_high_rank)
+    self.compare_backends(slice_first_element_with_from_tensor_high_rank,
+                          self._modules)
 
   def test_concat_with_tensorlist_stack(self):
 
     def concat_with_tensorlist_stack(module):
-      module.concat_with_tensorlist_stack(
-          np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+      module.concat_with_tensorlist_stack(np.array(42., dtype=np.float32),
+                                          np.array(43., dtype=np.float32))
 
-    self.compare_backends(concat_with_tensorlist_stack)
+    self.compare_backends(concat_with_tensorlist_stack, self._modules)
 
 
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+  del argv  # Unused
+  if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
   tf.test.main()
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/iree/base/tracing.cc b/iree/base/tracing.cc
index 7049d39..15301ad 100644
--- a/iree/base/tracing.cc
+++ b/iree/base/tracing.cc
@@ -68,12 +68,13 @@
       TracyLfqCommitC;
     }
 #endif  // TRACY_NO_VERIFY
-    auto name_ptr =
-        reinterpret_cast<char*>(tracy::tracy_malloc(name_length + 1));
+    auto name_ptr = reinterpret_cast<char*>(tracy::tracy_malloc(name_length));
     memcpy(name_ptr, name, name_length);
-    name_ptr[name_length] = '\0';
     TracyLfqPrepareC(tracy::QueueType::ZoneName);
-    tracy::MemWrite(&item->zoneText.text, reinterpret_cast<uint64_t>(name_ptr));
+    tracy::MemWrite(&item->zoneTextFat.text,
+                    reinterpret_cast<uint64_t>(name_ptr));
+    tracy::MemWrite(&item->zoneTextFat.size,
+                    static_cast<uint64_t>(name_length));
     TracyLfqCommitC;
   }
 
@@ -84,22 +85,9 @@
     const char* file_name, size_t file_name_length, uint32_t line,
     const char* function_name, size_t function_name_length, const char* name,
     size_t name_length) {
-  // NOTE: cloned from tracy::Profiler::AllocSourceLocation so that we can use
-  // the string lengths we already have.
-  const uint32_t src_loc_length =
-      static_cast<uint32_t>(4 + 4 + 4 + function_name_length + 1 +
-                            file_name_length + 1 + name_length);
-  auto ptr = reinterpret_cast<char*>(tracy::tracy_malloc(src_loc_length));
-  memcpy(ptr, &src_loc_length, 4);
-  memset(ptr + 4, 0, 4);
-  memcpy(ptr + 8, &line, 4);
-  memcpy(ptr + 12, function_name, function_name_length + 1);
-  memcpy(ptr + 12 + function_name_length + 1, file_name, file_name_length + 1);
-  if (name_length) {
-    memcpy(ptr + 12 + function_name_length + 1 + file_name_length + 1, name,
-           name_length);
-  }
-  uint64_t src_loc = reinterpret_cast<uint64_t>(ptr);
+  uint64_t src_loc = tracy::Profiler::AllocSourceLocation(
+      line, file_name, file_name_length, function_name, function_name_length,
+      name, name_length);
 
   const iree_zone_id_t zone_id = tracy::GetProfiler().GetNextZoneId();
 
@@ -152,3 +140,19 @@
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
+
+#if defined(__cplusplus) && \
+    (IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING)
+
+void* operator new(size_t count) noexcept {
+  auto ptr = malloc(count);
+  IREE_TRACE_ALLOC(ptr, count);
+  return ptr;
+}
+
+void operator delete(void* ptr) noexcept {
+  IREE_TRACE_FREE(ptr);
+  free(ptr);
+}
+
+#endif  // __cplusplus && IREE_TRACING_FEATURE_ALLOCATION_TRACKING
diff --git a/iree/base/tracing.h b/iree/base/tracing.h
index 03de9b6..d8026a9 100644
--- a/iree/base/tracing.h
+++ b/iree/base/tracing.h
@@ -355,14 +355,14 @@
 
 #define IREE_TRACE_ALLOC(ptr, size)               \
   ___tracy_emit_memory_alloc_callstack(ptr, size, \
-                                       IREE_TRACING_MAX_CALLSTACK_DEPTH)
+                                       IREE_TRACING_MAX_CALLSTACK_DEPTH, 0)
 #define IREE_TRACE_FREE(ptr) \
-  ___tracy_emit_memory_free_callstack(ptr, IREE_TRACING_MAX_CALLSTACK_DEPTH)
+  ___tracy_emit_memory_free_callstack(ptr, IREE_TRACING_MAX_CALLSTACK_DEPTH, 0)
 
 #else
 
-#define IREE_TRACE_ALLOC(ptr, size) ___tracy_emit_memory_alloc(ptr, size)
-#define IREE_TRACE_FREE(ptr) ___tracy_emit_memory_free(ptr)
+#define IREE_TRACE_ALLOC(ptr, size) ___tracy_emit_memory_alloc(ptr, size, 0)
+#define IREE_TRACE_FREE(ptr) ___tracy_emit_memory_free(ptr, 0)
 
 #endif  // IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS
 
@@ -371,24 +371,11 @@
 #define IREE_TRACE_FREE(ptr)
 #endif  // IREE_TRACING_FEATURE_ALLOCATION_TRACKING
 
-#ifdef __cplusplus
-
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING
-
-inline void* operator new(size_t count) {
-  auto ptr = malloc(count);
-  IREE_TRACE_ALLOC(ptr, count);
-  return ptr;
-}
-
-inline void operator delete(void* ptr) noexcept {
-  IREE_TRACE_FREE(ptr);
-  free(ptr);
-}
-
-#endif  // IREE_TRACING_FEATURE_ALLOCATION_TRACKING
-
-#endif  // __cplusplus
+#if defined(__cplusplus) && \
+    (IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING)
+void* operator new(size_t count) noexcept;
+void operator delete(void* ptr) noexcept;
+#endif  // __cplusplus && IREE_TRACING_FEATURE_ALLOCATION_TRACKING
 
 //===----------------------------------------------------------------------===//
 // Instrumentation C++ RAII types, wrappers, and macros
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 8123574..e6fe664 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <numeric>
+
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -56,6 +58,13 @@
                       [](APInt v) -> bool { return !v.isNullValue(); });
 }
 
+static DenseIntElementsAttr make1DElementsAttr(PatternRewriter &rewriter,
+                                               ArrayRef<int64_t> integers) {
+  auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+                                    rewriter.getIntegerType(64));
+  return DenseIntElementsAttr::get(type, integers);
+}
+
 class DecomposeLog1PPattern : public OpRewritePattern<mhlo::Log1pOp> {
  public:
   using OpRewritePattern<mhlo::Log1pOp>::OpRewritePattern;
@@ -324,6 +333,240 @@
   }
 };
 
+// Rewrites rank-3 mhlo.dot_general so that the lhs contraction dimension is
+// inner most (2) and the rhs contraction dimension is the dim right after the
+// batch dimension (1). Inserts transposes so the dot_general always has the
+// form: {batch_dim, parallel, contraction}.{batch_dim, contraction, parallel}
+class TransposeRank3GenericDotGeneral
+    : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
+    auto rhsShapeType = op.rhs().getType().dyn_cast<RankedTensorType>();
+    auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
+
+    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+    // Both operands are indexed as rank-3 below, so check them too.
+    if (lhsShapeType.getRank() != 3 || rhsShapeType.getRank() != 3 ||
+        resultType.getRank() != 3)
+      return failure();
+
+    // Exactly one batching and one contracting dim per operand; this also
+    // guarantees that reading the first element of each attribute is valid.
+    mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+    if (dimNumbers.lhs_batching_dimensions().size() != 1 ||
+        dimNumbers.rhs_batching_dimensions().size() != 1 ||
+        dimNumbers.lhs_contracting_dimensions().size() != 1 ||
+        dimNumbers.rhs_contracting_dimensions().size() != 1)
+      return failure();
+
+    auto firstDim = [](DenseIntElementsAttr dims) -> int64_t {
+      return (*dims.int_value_begin()).getSExtValue();
+    };
+    int64_t lhsBatchDim = firstDim(dimNumbers.lhs_batching_dimensions());
+    int64_t rhsBatchDim = firstDim(dimNumbers.rhs_batching_dimensions());
+    int64_t lhsContractionDim =
+        firstDim(dimNumbers.lhs_contracting_dimensions());
+    int64_t rhsContractionDim =
+        firstDim(dimNumbers.rhs_contracting_dimensions());
+
+    // Only handle operands whose batch dim is outermost.
+    if (lhsBatchDim != 0 || rhsBatchDim != 0) return failure();
+    // Already canonical; no transposes are needed.
+    if (lhsContractionDim == 2 && rhsContractionDim == 1) return failure();
+
+    Value lhs = op.lhs(), rhs = op.rhs();
+
+    // lhs is {batch_dim, contraction, parallel}: swap the two inner dims.
+    if (lhsContractionDim == 1) {
+      Type transposedType = RankedTensorType::get(
+          {lhsShapeType.getDimSize(0), lhsShapeType.getDimSize(2),
+           lhsShapeType.getDimSize(1)},
+          lhsShapeType.getElementType());
+      lhs = rewriter.create<mhlo::TransposeOp>(
+          op.getLoc(), transposedType, lhs,
+          make1DElementsAttr(rewriter, {0, 2, 1}));
+    }
+
+    // rhs is {batch_dim, parallel, contraction}: swap the two inner dims.
+    if (rhsContractionDim == 2) {
+      Type transposedType = RankedTensorType::get(
+          {rhsShapeType.getDimSize(0), rhsShapeType.getDimSize(2),
+           rhsShapeType.getDimSize(1)},
+          rhsShapeType.getElementType());
+      rhs = rewriter.create<mhlo::TransposeOp>(
+          op.getLoc(), transposedType, rhs,
+          make1DElementsAttr(rewriter, {0, 2, 1}));
+    }
+
+    auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+        /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*lhs_contracting_dimensions=*/make1DElementsAttr(rewriter, {2}),
+        /*rhs_contracting_dimensions=*/
+        make1DElementsAttr(rewriter, {1}), rewriter.getContext());
+
+    Value result = rewriter.create<mhlo::DotGeneralOp>(
+        op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
+        op.precision_configAttr());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+// Rewrite mhlo.dot_general to operate on rank-3 tensors when reduction dims are
+// in consecutive order and not splitting the domain. This pattern inserts
+// reshapes to collapse consecutive reduction and parallel dims to always
+// generate a rank-3 dot_general op.
+class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
+    auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
+    auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
+
+    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+    if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape() ||
+        !resultType.hasStaticShape())
+      return failure();
+    if (resultType.getRank() <= 3) return failure();
+
+    mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+    auto toSortedDims = [](DenseIntElementsAttr attr) {
+      auto dims = llvm::to_vector<4>(
+          llvm::map_range(attr, [](APInt v) { return v.getSExtValue(); }));
+      llvm::sort(dims);
+      return dims;
+    };
+    auto lhsBatchingDims = toSortedDims(dimNumbers.lhs_batching_dimensions());
+    auto rhsBatchingDims = toSortedDims(dimNumbers.rhs_batching_dimensions());
+    auto lhsContractingDims =
+        toSortedDims(dimNumbers.lhs_contracting_dimensions());
+    auto rhsContractingDims =
+        toSortedDims(dimNumbers.rhs_contracting_dimensions());
+
+    // front()/back() are read on every list below, so none may be empty.
+    if (lhsBatchingDims.empty() || rhsBatchingDims.empty() ||
+        lhsContractingDims.empty() || rhsContractingDims.empty())
+      return failure();
+
+    auto isConsecutive = [](ArrayRef<int64_t> array) {
+      for (size_t i = 1; i < array.size(); ++i) {
+        if (array[i] - array[i - 1] != 1) return false;
+      }
+      return true;
+    };
+
+    auto isDomainSplit = [](ArrayRef<int64_t> shape,
+                            ArrayRef<int64_t> batchingDims,
+                            ArrayRef<int64_t> contractingDims) {
+      // Batching dims have to be outermost to collapse into dim 0.
+      if (batchingDims.front() != 0) return true;
+      // Batching and contracting are contiguous.
+      if ((contractingDims.front() - batchingDims.back()) == 1) return false;
+      // Contracting dims are inner most.
+      if (contractingDims.back() == static_cast<int64_t>(shape.size()) - 1)
+        return false;
+      return true;
+    };
+
+    if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
+        !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
+      return failure();
+
+    if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
+                      lhsContractingDims) ||
+        isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
+                      rhsContractingDims))
+      return failure();
+
+    // Collapses shape into a rank-3 tensor; returns the new collapsed shape
+    // together with its contraction and parallel dim indices.
+    auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
+                                    ArrayRef<int64_t> batchingDims,
+                                    ArrayRef<int64_t> contractingDims) {
+      auto product = [shape](ArrayRef<int64_t> dims) -> int64_t {
+        return std::accumulate(
+            dims.begin(), dims.end(), int64_t{1},
+            [shape](int64_t accum, int64_t index) -> int64_t {
+              return accum * shape[index];
+            });
+      };
+      int64_t batchingSize = product(batchingDims);
+      int64_t contractingSize = product(contractingDims);
+
+      int64_t parallelDimIndex, contractingDimIndex, parallelDimSize = 1;
+      if (contractingDims.front() - batchingDims.back() > 1) {
+        parallelDimIndex = 1;
+        contractingDimIndex = 2;
+        for (int64_t i = batchingDims.back() + 1; i < contractingDims.front();
+             ++i) {
+          parallelDimSize *= shape[i];
+        }
+      } else {
+        contractingDimIndex = 1;
+        parallelDimIndex = 2;
+        const int64_t rank = static_cast<int64_t>(shape.size());
+        for (int64_t i = contractingDims.back() + 1; i < rank; ++i) {
+          parallelDimSize *= shape[i];
+        }
+      }
+      // Always rank-3, even when the operand has zero or many parallel dims.
+      llvm::SmallVector<int64_t, 4> newShape(3);
+      newShape[0] = batchingSize;
+      newShape[contractingDimIndex] = contractingSize;
+      newShape[parallelDimIndex] = parallelDimSize;
+      return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
+    };
+
+    int64_t lhsContractingDimIndex, rhsContractingDimIndex;
+    int64_t lhsParallelDimIndex, rhsParallelDimIndex;
+    SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
+    std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
+        computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
+                              lhsContractingDims);
+
+    std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
+        computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
+                              rhsContractingDims);
+    SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
+                                              lhsNewShape[lhsParallelDimIndex],
+                                              rhsNewShape[rhsParallelDimIndex]};
+    Type dotGeneralResultType =
+        RankedTensorType::get(resultNewShape, resultType.getElementType());
+
+    auto loc = op.getLoc();
+    Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
+        op.lhs());
+    Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
+        op.rhs());
+    auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+        /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*lhs_contracting_dimensions=*/
+        make1DElementsAttr(rewriter, {lhsContractingDimIndex}),
+        /*rhs_contracting_dimensions=*/
+        make1DElementsAttr(rewriter, {rhsContractingDimIndex}),
+        rewriter.getContext());
+    Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
+        loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
+        op.precision_configAttr());
+
+    Value result =
+        rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 // clang-format off
 //
 // Reorder BroadcastInDimOp and N-ary elementwise op.
@@ -425,6 +668,11 @@
                     AdjustDepthwiseFilterShape, DecomposeLog1PPattern,
                     DecomposeExpM1Pattern>(context);
 
+    // dot_general canonicalization patterns.
+    mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
+    patterns.insert<RankReducedDotGeneral, TransposeRank3GenericDotGeneral>(
+        context);
+
     // Unary elementwise op.
     patterns.insert<
         ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>,
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir
new file mode 100644
index 0000000..67b8f46
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir
@@ -0,0 +1,104 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-hlo-to-hlo-preprocessing %s | IreeFileCheck %s
+
+func @dot_general_to_dot(%arg0: tensor<1x32x128x4xf32>, %arg1: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+      dot_dimension_numbers = {
+        lhs_batching_dimensions = dense<> : tensor<0xi64>,
+        lhs_contracting_dimensions = dense<[2, 3]> : tensor<2xi64>,
+        rhs_batching_dimensions = dense<> : tensor<0xi64>,
+        rhs_contracting_dimensions = dense<[0, 1]> : tensor<2xi64>
+      }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+    } : (tensor<1x32x128x4xf32>, tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+  return %0 : tensor<1x32x8x64xf32>
+}
+
+// CHECK: dot_general_to_dot(%[[ARG0:.+]]: tensor<1x32x128x4xf32>, %[[ARG1:.+]]: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
+// CHECK: %[[DOT:.+]] = "mhlo.dot"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT]]) : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: return %[[RESULT]] : tensor<1x32x8x64xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced(%arg0: tensor<1x8x32x64xf32>, %arg1 : tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+    }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x32x64xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_a_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+    }, name = "dot_general_to_dot_trans_a", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x64x32xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_RSSHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RSSHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_b_transposed(%arg0: tensor<1x8x32x64xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+    }, name = "dot_general_to_dot_trans_b", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x32x64xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_REHSPAED_TR:.+]] = "mhlo.transpose"(%[[ARG1_REHSPAED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_REHSPAED]], %[[ARG1_REHSPAED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_ab_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+    }, name = "dot_general_to_dot_trans_ab", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x64x32xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG0_REHSPAED_TR:.+]] = "mhlo.transpose"(%[[ARG0_REHSPAED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_REHSPAED_TR:.+]] = "mhlo.transpose"(%[[ARG1_REHSPAED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_REHSPAED_TR]], %[[ARG1_REHSPAED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 55f8428..2bbd6b4 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -52,6 +52,7 @@
         "cosine.mlir",
         "divide.mlir",
         "dot.mlir",
+        "dot_general.mlir",
         "exponential.mlir",
         "exponential_minus_one.mlir",
         "floor.mlir",
@@ -104,6 +105,7 @@
         "cosine.mlir",
         "divide.mlir",
         "dot.mlir",
+        "dot_general.mlir",
         "exponential.mlir",
         "exponential_minus_one.mlir",
         "floor.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index ae2a8c1..4bd7219 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -45,6 +45,7 @@
     "cosine.mlir"
     "divide.mlir"
     "dot.mlir"
+    "dot_general.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"
@@ -97,6 +98,7 @@
     "cosine.mlir"
     "divide.mlir"
     "dot.mlir"
+    "dot_general.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"
diff --git a/iree/testing/vulkan/BUILD b/iree/testing/vulkan/BUILD
index bf3fdce..4e235aa 100644
--- a/iree/testing/vulkan/BUILD
+++ b/iree/testing/vulkan/BUILD
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+load(
+    "//iree:build_defs.oss.bzl",
+    "PLATFORM_VULKAN_DEPS",
+)
+
 package(
     default_visibility = ["//visibility:public"],
     features = ["layering_check"],
@@ -37,3 +42,33 @@
         "@vulkan_sdk//:sdk",
     ],
 )
+
+# GUI tool that loads an IREE module and runs it with the Vulkan HAL driver.
+# Tagged "manual" so it is only built when requested explicitly.
+cc_binary(
+    name = "iree-run-module-vulkan-gui",
+    srcs = ["iree-run-module-vulkan-gui-main.cc"],
+    linkopts = select({
+        "@bazel_tools//src/conditions:windows": [
+            "-SUBSYSTEM:WINDOWS",
+        ],
+        "//conditions:default": [],
+    }),
+    tags = ["manual", "nokokoro"],
+    deps = [
+        ":vulkan_gui_util",
+        "//iree/base:file_io",
+        "//iree/base:init",
+        "//iree/base:main",
+        "//iree/base:status",
+        "//iree/base:tracing",
+        "//iree/hal/vulkan:vulkan_driver_module",
+        "//iree/modules/hal",
+        "//iree/tools/utils:vm_util",
+        "//iree/vm",
+        "//iree/vm:bytecode_module",
+        "//iree/vm:ref_cc",
+        "@com_google_absl//absl/flags:flag",
+        "@sdl2//:SDL2",
+    ] + PLATFORM_VULKAN_DEPS,
+)
diff --git a/iree/testing/vulkan/CMakeLists.txt b/iree/testing/vulkan/CMakeLists.txt
index 2c995a9..6ceb5dc 100644
--- a/iree/testing/vulkan/CMakeLists.txt
+++ b/iree/testing/vulkan/CMakeLists.txt
@@ -44,3 +44,31 @@
     SDL2-static
     Vulkan::Vulkan
 )
+
+# The GUI subsystem flag depends on the platform being targeted, not the host.
+set(_GUI_LINKOPTS "")
+if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
+  set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
+endif()
+
+iree_cc_binary(
+  NAME
+    iree-run-module-vulkan-gui
+  SRCS
+    "iree-run-module-vulkan-gui-main.cc"
+  DEPS
+    ::vulkan_gui_util
+    absl::flags
+    iree::base::file_io
+    iree::base::init
+    iree::base::main
+    iree::base::status
+    iree::base::tracing
+    iree::hal::vulkan::vulkan_driver_module
+    iree::modules::hal
+    iree::tools::utils::vm_util
+    iree::vm
+    iree::vm::bytecode_module
+  LINKOPTS
+    "${_GUI_LINKOPTS}"
+)
diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
new file mode 100644
index 0000000..ec1349b
--- /dev/null
+++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
@@ -0,0 +1,454 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Vulkan GUI utility functions.
+// Order matters here: this header must be included first to make sure Vulkan
+// API prototypes are defined so that we can statically link against them.
+#include "iree/testing/vulkan/vulkan_gui_util.h"
+
+// Other dependencies (helpers, etc.)
+#include "absl/flags/flag.h"
+#include "iree/base/file_io.h"
+#include "iree/base/init.h"
+#include "iree/base/main.h"
+#include "iree/base/status.h"
+#include "iree/modules/hal/hal_module.h"
+#include "iree/tools/utils/vm_util.h"
+#include "iree/vm/api.h"
+#include "iree/vm/bytecode_module.h"
+
+ABSL_FLAG(std::string, module_file, "-",
+          "File containing the module to load that contains the entry "
+          "function. Defaults to stdin.");
+
+ABSL_FLAG(std::string, entry_function, "",
+          "Name of a function contained in the module specified by module_file "
+          "to run.");
+
+ABSL_FLAG(std::vector<std::string>, inputs, {},
+          "A comma-separated list of input buffers of the format: "
+          "[shape]xtype=[value]\n"
+          "2x2xi32=1 2 3 4\n"
+          "Optionally, brackets may be used to separate the element values. "
+          "They are ignored by the parser.\n"
+          "2x2xi32=[[1 2][3 4]]\n"
+          "Due to the absence of repeated flags in absl, commas should not be "
+          "used to separate elements. They are reserved for separating input "
+          "values:\n"
+          "2x2xi32=[[1 2][3 4]], 1x2xf32=[[1 2]]");
+
+ABSL_FLAG(std::string, inputs_file, "",
+          "Provides a file for input shapes and optional values (see "
+          "ParseToVariantListFromFile in vm_util.h for details)");
+
+static VkAllocationCallbacks* g_Allocator = nullptr;
+static VkInstance g_Instance = VK_NULL_HANDLE;
+static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
+static VkDevice g_Device = VK_NULL_HANDLE;
+static uint32_t g_QueueFamily = static_cast<uint32_t>(-1);
+static VkQueue g_Queue = VK_NULL_HANDLE;
+static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
+static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
+// ImGui window data and pending swapchain-resize state (see the main loop).
+static ImGui_ImplVulkanH_Window g_MainWindowData;
+static uint32_t g_MinImageCount = 2;
+static bool g_SwapChainRebuild = false;
+static int g_SwapChainResizeWidth = 0;
+static int g_SwapChainResizeHeight = 0;
+
+namespace iree {
+namespace {
+
+// Logs |err| at FATAL severity unless it is VK_SUCCESS (0).
+void check_vk_result(VkResult err) {
+  if (err != VK_SUCCESS) IREE_LOG(FATAL) << "VkResult: " << err;
+}
+
+// Destroys the descriptor pool, then the logical device, then the instance.
+void CleanupVulkan() {
+  vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
+  vkDestroyDevice(g_Device, g_Allocator);
+  vkDestroyInstance(g_Instance, g_Allocator);
+}
+
+void CleanupVulkanWindow() {
+  ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
+                                  g_Allocator);
+}
+
+// Returns the module bytes from --module_file, or from stdin when it is "-".
+StatusOr<std::string> GetModuleContentsFromFlags() {
+  const std::string module_path = absl::GetFlag(FLAGS_module_file);
+  if (module_path != "-") {
+    std::string file_contents;
+    IREE_ASSIGN_OR_RETURN(file_contents, file_io::GetFileContents(module_path));
+    return file_contents;
+  }
+  using StdinIterator = std::istreambuf_iterator<char>;
+  return std::string{StdinIterator(std::cin), StdinIterator()};
+}
+
+// Invokes |function| in |context| with |inputs|, formats the outputs with
+// PrintVariantList using |output_descs|, and emits an ImGui window titled
+// |window_title| showing |function_name|, the formatted outputs and the frame
+// rate. Returns the first error hit while creating the output list, invoking
+// or printing; in that case no ImGui calls are made. |device| is unused.
+Status RunModuleAndUpdateImGuiWindow(
+    iree_hal_device_t* device, iree_vm_context_t* context,
+    iree_vm_function_t function, const std::string& function_name,
+    const vm::ref<iree_vm_list_t>& inputs,
+    const std::vector<RawSignatureParser::Description>& output_descs,
+    const std::string& window_title) {
+  vm::ref<iree_vm_list_t> outputs;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
+                                           output_descs.size(),
+                                           iree_allocator_system(), &outputs));
+
+  IREE_LOG(INFO) << "EXEC @" << function_name;
+  IREE_RETURN_IF_ERROR(iree_vm_invoke(context, function, /*policy=*/nullptr,
+                                      inputs.get(), outputs.get(),
+                                      iree_allocator_system()))
+      << "invoking function " << function_name;
+
+  std::ostringstream oss;
+  IREE_RETURN_IF_ERROR(PrintVariantList(output_descs, outputs.get(), &oss))
+      << "printing results";
+  outputs.reset();
+
+  ImGui::Begin(window_title.c_str(), /*p_open=*/nullptr,
+               ImGuiWindowFlags_AlwaysAutoResize);
+  // The name and the printed results are arbitrary text, so they must never be
+  // used as printf-style format strings: a '%' in them would be misread.
+  ImGui::Text("Entry function:");
+  ImGui::TextUnformatted(function_name.c_str());
+  ImGui::Separator();
+  ImGui::Text("Invocation result:");
+  ImGui::TextUnformatted(oss.str().c_str());
+  ImGui::Separator();
+  // Framerate counter.
+  ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
+              1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
+  ImGui::End();
+  return OkStatus();
+}
+}  // namespace
+}  // namespace iree
+
+// Opens an SDL window backed by Vulkan, loads the module given by --module_file
+// and invokes --entry_function once per frame, showing the results via ImGui.
+int iree::IreeMain(int argc, char** argv) {
+  iree::InitializeEnvironment(&argc, &argv);
+
+  // --------------------------------------------------------------------------
+  // Create a window.
+  if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
+    IREE_LOG(FATAL) << "Failed to initialize SDL";
+    return 1;
+  }
+
+  // Setup window
+  SDL_WindowFlags window_flags = (SDL_WindowFlags)(
+      SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
+  SDL_Window* window = SDL_CreateWindow(
+      "IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
+      SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
+  if (window == nullptr) {
+    IREE_LOG(FATAL) << "Failed to create SDL window: " << SDL_GetError();
+    return 1;
+  }
+
+  // Setup Vulkan
+  iree_hal_vulkan_features_t iree_vulkan_features =
+      static_cast<iree_hal_vulkan_features_t>(
+          IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS |
+          IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS |
+          IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS);
+  std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
+  std::vector<const char*> extensions =
+      GetInstanceExtensions(window, iree_vulkan_features);
+  SetupVulkan(iree_vulkan_features, layers.data(), layers.size(),
+              extensions.data(), extensions.size(), g_Allocator, &g_Instance,
+              &g_QueueFamily, &g_PhysicalDevice, &g_Queue, &g_Device,
+              &g_DescriptorPool);
+
+  // Create Window Surface
+  VkSurfaceKHR surface;
+  VkResult err;
+  if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
+    IREE_LOG(FATAL) << "Failed to create Vulkan surface";
+    return 1;
+  }
+
+  // Create Framebuffers
+  int w, h;
+  SDL_GetWindowSize(window, &w, &h);
+  ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
+  SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
+                    g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
+
+  // Setup Dear ImGui context
+  IMGUI_CHECKVERSION();
+  ImGui::CreateContext();
+
+  ImGui::StyleColorsDark();
+
+  // Setup Platform/Renderer bindings
+  ImGui_ImplSDL2_InitForVulkan(window);
+  ImGui_ImplVulkan_InitInfo init_info = {};
+  init_info.Instance = g_Instance;
+  init_info.PhysicalDevice = g_PhysicalDevice;
+  init_info.Device = g_Device;
+  init_info.QueueFamily = g_QueueFamily;
+  init_info.Queue = g_Queue;
+  init_info.PipelineCache = g_PipelineCache;
+  init_info.DescriptorPool = g_DescriptorPool;
+  init_info.Allocator = g_Allocator;
+  init_info.MinImageCount = g_MinImageCount;
+  init_info.ImageCount = wd->ImageCount;
+  init_info.CheckVkResultFn = check_vk_result;
+  ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
+
+  // Upload Fonts
+  {
+    // Use any command queue
+    VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
+    VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
+
+    err = vkResetCommandPool(g_Device, command_pool, 0);
+    check_vk_result(err);
+    VkCommandBufferBeginInfo begin_info = {};
+    begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+    begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+    err = vkBeginCommandBuffer(command_buffer, &begin_info);
+    check_vk_result(err);
+
+    ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
+
+    VkSubmitInfo end_info = {};
+    end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+    end_info.commandBufferCount = 1;
+    end_info.pCommandBuffers = &command_buffer;
+    err = vkEndCommandBuffer(command_buffer);
+    check_vk_result(err);
+    err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
+    check_vk_result(err);
+
+    err = vkDeviceWaitIdle(g_Device);
+    check_vk_result(err);
+    ImGui_ImplVulkan_DestroyFontUploadObjects();
+  }
+  // --------------------------------------------------------------------------
+
+  // --------------------------------------------------------------------------
+  // Setup IREE.
+  // This call to |iree_api_init| is not technically required, but it is
+  // included for completeness.
+  IREE_CHECK_OK(iree_api_init(&argc, &argv));
+
+  // Check API version.
+  iree_api_version_t actual_version;
+  iree_status_t status =
+      iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
+  if (iree_status_is_ok(status)) {
+    IREE_LOG(INFO) << "IREE runtime API version " << actual_version;
+  } else {
+    IREE_LOG(FATAL) << "Unsupported runtime API version " << actual_version;
+  }
+
+  // Register HAL module types.
+  IREE_CHECK_OK(iree_hal_module_register_types());
+
+  // Create a runtime Instance.
+  iree_vm_instance_t* iree_instance = nullptr;
+  IREE_CHECK_OK(
+      iree_vm_instance_create(iree_allocator_system(), &iree_instance));
+
+  // Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
+  IREE_LOG(INFO) << "Creating Vulkan driver/device";
+  // Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
+  iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
+  IREE_CHECK_OK(iree_hal_vulkan_syms_create(
+      reinterpret_cast<void*>(&vkGetInstanceProcAddr), &iree_vk_syms));
+  // Create the driver sharing our VkInstance.
+  iree_hal_driver_t* iree_vk_driver = nullptr;
+  iree_hal_vulkan_driver_options_t options;
+  options.api_version = VK_API_VERSION_1_0;
+  options.features = static_cast<iree_hal_vulkan_features_t>(
+      IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS |
+      IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS);
+  IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
+      options, iree_vk_syms, g_Instance, &iree_vk_driver));
+  // Create a device sharing our VkDevice and queue. This makes capturing with
+  // vendor tools easier because we will have sync compute residing in the
+  // rendered frame.
+  iree_hal_vulkan_queue_set_t compute_queue_set;
+  compute_queue_set.queue_family_index = g_QueueFamily;
+  compute_queue_set.queue_indices = 1 << 0;
+  // No dedicated transfer queues: zero-initialize so no field is left unset.
+  iree_hal_vulkan_queue_set_t transfer_queue_set = {};
+  iree_hal_device_t* iree_vk_device = nullptr;
+  IREE_CHECK_OK(iree_hal_vulkan_driver_wrap_device(
+      iree_vk_driver, g_PhysicalDevice, g_Device, compute_queue_set,
+      transfer_queue_set, &iree_vk_device));
+  // Create a HAL module using the HAL device.
+  iree_vm_module_t* hal_module = nullptr;
+  IREE_CHECK_OK(iree_hal_module_create(iree_vk_device, iree_allocator_system(),
+                                       &hal_module));
+
+  // Load the bytecode module from --module_file (stdin when it is "-").
+  IREE_LOG(INFO) << "Loading IREE bytecode module...";
+  auto module_file_or = iree::GetModuleContentsFromFlags();
+  if (!module_file_or.ok()) {
+    IREE_LOG(FATAL) << "Error when reading module file: "
+                    << module_file_or.status();
+  }
+  iree_vm_module_t* bytecode_module = nullptr;
+  IREE_CHECK_OK(iree_vm_bytecode_module_create(
+      iree_const_byte_span_t{
+          reinterpret_cast<const uint8_t*>(module_file_or->data()),
+          module_file_or->size()},
+      iree_allocator_null(), iree_allocator_system(), &bytecode_module));
+
+  // Allocate a context that will hold the module state across invocations.
+  iree_vm_context_t* iree_context = nullptr;
+  std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
+  IREE_CHECK_OK(iree_vm_context_create_with_modules(
+      iree_instance, modules.data(), modules.size(), iree_allocator_system(),
+      &iree_context));
+  IREE_LOG(INFO) << "Context with modules is ready for use";
+
+  // Lookup the entry point function.
+  std::string entry_function = absl::GetFlag(FLAGS_entry_function);
+  iree_vm_function_t main_function;
+  IREE_CHECK_OK(bytecode_module->lookup_function(
+      bytecode_module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+      iree_string_view_t{entry_function.data(), entry_function.size()},
+      &main_function));
+  iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
+  IREE_LOG(INFO) << "Resolved main function named '"
+                 << std::string(main_function_name.data,
+                                main_function_name.size)
+                 << "'";
+
+  IREE_CHECK_OK(ValidateFunctionAbi(main_function));
+
+  auto main_function_input_descs = ParseInputSignature(main_function);
+  if (!main_function_input_descs.ok()) {
+    IREE_LOG(FATAL) << main_function_input_descs.status().ToString();
+  }
+  StatusOr<vm::ref<iree_vm_list_t>> main_function_inputs;
+  if (!absl::GetFlag(FLAGS_inputs_file).empty()) {
+    if (!absl::GetFlag(FLAGS_inputs).empty()) {
+      IREE_LOG(FATAL)
+          << "Expected only one of inputs and inputs_file to be set";
+    }
+    main_function_inputs = ParseToVariantListFromFile(
+        *main_function_input_descs, iree_hal_device_allocator(iree_vk_device),
+        absl::GetFlag(FLAGS_inputs_file));
+  } else {
+    main_function_inputs = ParseToVariantList(
+        *main_function_input_descs, iree_hal_device_allocator(iree_vk_device),
+        absl::GetFlag(FLAGS_inputs));
+  }
+  if (!main_function_inputs.ok()) {
+    IREE_LOG(FATAL) << main_function_inputs.status().ToString();
+  }
+
+  auto main_function_output_descs = ParseOutputSignature(main_function);
+  if (!main_function_output_descs.ok()) {
+    IREE_LOG(FATAL) << main_function_output_descs.status().ToString();
+  }
+
+  const std::string& window_title = absl::GetFlag(FLAGS_module_file);
+  // --------------------------------------------------------------------------
+
+  // --------------------------------------------------------------------------
+  // Main loop.
+  bool done = false;
+  while (!done) {
+    SDL_Event event;
+
+    while (SDL_PollEvent(&event)) {
+      ImGui_ImplSDL2_ProcessEvent(&event);
+      if (event.type == SDL_QUIT) done = true;
+      if (event.type == SDL_WINDOWEVENT &&
+          event.window.event == SDL_WINDOWEVENT_RESIZED &&
+          event.window.windowID == SDL_GetWindowID(window)) {
+        g_SwapChainResizeWidth = static_cast<int>(event.window.data1);
+        g_SwapChainResizeHeight = static_cast<int>(event.window.data2);
+        g_SwapChainRebuild = true;
+      }
+    }
+
+    if (g_SwapChainRebuild) {
+      g_SwapChainRebuild = false;
+      ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
+      ImGui_ImplVulkanH_CreateWindow(g_Instance, g_PhysicalDevice, g_Device,
+                                     &g_MainWindowData, g_QueueFamily,
+                                     g_Allocator, g_SwapChainResizeWidth,
+                                     g_SwapChainResizeHeight, g_MinImageCount);
+      g_MainWindowData.FrameIndex = 0;
+    }
+
+    // Start the Dear ImGui frame
+    ImGui_ImplVulkan_NewFrame();
+    ImGui_ImplSDL2_NewFrame(window);
+    ImGui::NewFrame();
+
+    // Custom window.
+    auto status = RunModuleAndUpdateImGuiWindow(
+        iree_vk_device, iree_context, main_function, entry_function,
+        main_function_inputs.value(), main_function_output_descs.value(),
+        window_title);
+    if (!status.ok()) {
+      IREE_LOG(FATAL) << status;
+      done = true;
+      continue;
+    }
+
+    // Rendering
+    ImGui::Render();
+    RenderFrame(wd, g_Device, g_Queue);
+
+    PresentFrame(wd, g_Queue);
+  }
+  // --------------------------------------------------------------------------
+
+  // --------------------------------------------------------------------------
+  // Cleanup
+  main_function_inputs.value().reset();
+
+  iree_vm_module_release(hal_module);
+  iree_vm_module_release(bytecode_module);
+  iree_vm_context_release(iree_context);
+  iree_hal_device_release(iree_vk_device);
+  iree_hal_driver_release(iree_vk_driver);
+  iree_hal_vulkan_syms_release(iree_vk_syms);
+  iree_vm_instance_release(iree_instance);
+
+  err = vkDeviceWaitIdle(g_Device);
+  check_vk_result(err);
+  ImGui_ImplVulkan_Shutdown();
+  ImGui_ImplSDL2_Shutdown();
+  ImGui::DestroyContext();
+
+  CleanupVulkanWindow();
+  CleanupVulkan();
+
+  SDL_DestroyWindow(window);
+  SDL_Quit();
+  // --------------------------------------------------------------------------
+
+  return 0;
+}
diff --git a/third_party/tracy b/third_party/tracy
index 864d86e..a9a09ab 160000
--- a/third_party/tracy
+++ b/third_party/tracy
@@ -1 +1 @@
-Subproject commit 864d86e8b6d21449474db5e9313dbff90aa9c24f
+Subproject commit a9a09ab0940408898fccfdcfe2bb8dc19b50f13c