Refactor trace_utils and module_utils out of tf_test_utils/tf_utils (#3916)

These files were becoming long catch-alls so I split them up:

- `ModuleCall`, `Trace` and `TracedModule` are moved from `tf_test_utils` into `trace_utils`.
  - `compare_traces` is removed from `Trace` and in favor of accessing it from `trace_utils.compare_traces`.
  - `check_same` is moved into `tf_utils`.
- `_FunctionWrapper` and `CompiledModule` (sub)classes are moved into `module_utils` along with `BackendInfo`.
- The tests are split up to match these changes.
diff --git a/colab/edge_detection.ipynb b/colab/edge_detection.ipynb
index 18528e4..4501f53 100644
--- a/colab/edge_detection.ipynb
+++ b/colab/edge_detection.ipynb
@@ -90,7 +90,7 @@
         "import numpy as np\n",
         "import tensorflow as tf\n",
         "from pyiree.tf import compiler as ireec\n",
-        "from pyiree.tf.support import tf_utils\n",
+        "from pyiree.tf.support import module_utils\n",
         "from pyiree import rt as ireert"
       ]
     },
@@ -271,7 +271,7 @@
         "\n",
         "backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
         "backend_choice = backend_choice.split(\" \")[0]\n",
-        "backend = tf_utils.BackendInfo(backend_choice)"
+        "backend = module_utils.BackendInfo(backend_choice)"
       ]
     },
     {
@@ -435,7 +435,7 @@
         "\n",
         "backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
         "backend_choice = backend_choice.split(\" \")[0]\n",
-        "backend = tf_utils.BackendInfo(backend_choice)"
+        "backend = module_utils.BackendInfo(backend_choice)"
       ]
     },
     {
diff --git a/colab/mnist_tensorflow.ipynb b/colab/mnist_tensorflow.ipynb
index 733f65d..4a1f1aa 100644
--- a/colab/mnist_tensorflow.ipynb
+++ b/colab/mnist_tensorflow.ipynb
@@ -86,7 +86,7 @@
         "\n",
         "from pyiree import rt as ireert\n",
         "from pyiree.tf import compiler as ireec\n",
-        "from pyiree.tf.support import tf_utils\n",
+        "from pyiree.tf.support import module_utils\n",
         "\n",
         "from matplotlib import pyplot as plt\n",
         "import numpy as np\n",
@@ -335,7 +335,7 @@
         "\n",
         "backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
         "backend_choice = backend_choice.split(\" \")[0]\n",
-        "backend = tf_utils.BackendInfo(backend_choice)"
+        "backend = module_utils.BackendInfo(backend_choice)"
       ],
       "execution_count": 8,
       "outputs": []
@@ -353,7 +353,7 @@
       "source": [
         "#@title Compile the mhlo MLIR to an IREE backend and prepare a context to execute it\n",
         "\n",
-        "iree_module = tf_utils.IreeCompiledModule.create_from_instance(\n",
+        "iree_module = module_utils.IreeCompiledModule.create_from_instance(\n",
         "    inference_module, backend, exported_names, ARTIFACTS_DIR)\n",
         "\n",
         "print(\"* Module compiled! See intermediate .mlir files in\", ARTIFACTS_DIR, \"*\")"
@@ -462,4 +462,4 @@
       ]
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/colab/resnet.ipynb b/colab/resnet.ipynb
index 8595da9..af7e8e6 100644
--- a/colab/resnet.ipynb
+++ b/colab/resnet.ipynb
@@ -82,7 +82,7 @@
         "\n",
         "from pyiree import rt as ireert\n",
         "from pyiree.tf import compiler as ireec\n",
-        "from pyiree.tf.support import tf_utils\n",
+        "from pyiree.tf.support import module_utils\n",
         "\n",
         "import tensorflow as tf\n",
         "from matplotlib import pyplot as plt\n",
@@ -139,7 +139,7 @@
         "\n",
         "backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
         "backend_choice = backend_choice.split(\" \")[0]\n",
-        "backend = tf_utils.BackendInfo(backend_choice)"
+        "backend = module_utils.BackendInfo(backend_choice)"
       ],
       "execution_count": 3,
       "outputs": []
@@ -314,4 +314,4 @@
       ]
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index 9b1c869..09fe138 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -29,9 +29,11 @@
     name = "support",
     srcs = [
         "__init__.py",
+        "module_utils.py",
         "tf_test_driver.py",
         "tf_test_utils.py",
         "tf_utils.py",
+        "trace_utils.py",
     ],
     deps = INTREE_TENSORFLOW_PY_DEPS + [
         "//integrations/tensorflow/bindings/python:pathsetup",  # build_cleaner: keep
@@ -41,6 +43,22 @@
 )
 
 iree_py_test(
+    name = "module_utils_test",
+    srcs = [
+        "module_utils.py",
+        "module_utils_test.py",
+    ],
+    python_version = "PY3",
+    tags = [
+        "driver=llvm",
+        "driver=vmla",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+iree_py_test(
     name = "tf_test_utils_test",
     srcs = [
         "tf_test_utils.py",
@@ -59,6 +77,18 @@
         "tf_utils_test.py",
     ],
     python_version = "PY3",
+    deps = INTREE_TENSORFLOW_PY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+iree_py_test(
+    name = "trace_utils_test",
+    srcs = [
+        "trace_utils.py",
+        "trace_utils_test.py",
+    ],
+    python_version = "PY3",
     tags = [
         "driver=vmla",
     ],
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
new file mode 100644
index 0000000..4ec9d66
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
@@ -0,0 +1,959 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for compiling 'tf.Module's"""
+
+import collections
+import os
+import tempfile
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+
+from absl import logging
+import numpy as np
+from pyiree import rt
+from pyiree.tf import compiler
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+def _setup_mlir_crash_reproducer(
+    function: Any,  # pytype doesn't support arbitrary Callable[*args, **kwargs]
+    artifacts_dir: str,
+    backend_id: str,
+) -> Any:  # Callable[Any, Any]
+  """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+
+  Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+
+  Args:
+    function: The callable to decorate.
+    artifacts_dir: The directory to write the reproducer to.
+    backend_id: The unique backend name to use when writting the reproducer.
+
+  Returns:
+    A function with the same API as the passed function.
+  """
+
+  def decorator(*args, **kwargs):
+    # Set up a crash reproducer for debugging.
+    if artifacts_dir is not None:
+      compiler.Context.default_crash_reproducer_path = os.path.join(
+          artifacts_dir, f"reproducer__{backend_id}.mlir")
+    try:
+      results = function(*args, **kwargs)
+    except Exception:  # pylint: disable=broad-except
+      # Disable the crash reproducer (to avoid inadvertently overwriting it).
+      if artifacts_dir is not None:
+        compiler.Context.default_crash_reproducer_path = None
+      raise
+    return results
+
+  return decorator
+
+
+def _incrementally_lower_compiler_module(
+    compiler_module: compiler.Module,
+    backend_info: "BackendInfo",
+    artifacts_dir: str,
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+  """Lowers a MLIR compiler module incrementally and saves its outputs.
+
+  If artifacts_dir is provided then the following artifacts will be saved:
+    tf_input.mlir:
+      MLIR for the module in TF's input dialect.
+    iree_input.mlir:
+      The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
+    backend_id/compiled.vmfb:
+      A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
+
+  Args:
+    compiler_module: A compiler.Module to lower.
+    backend_info: BackendInfo with the details for lowering compiler_module to
+      IREE.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved. No compilation artifacts will be saved if this is not
+      provided.
+  """
+  if artifacts_dir is not None:
+    os.makedirs(artifacts_dir, exist_ok=True)
+    tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+    logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+    with open(tf_mlir_path, "w") as f:
+      f.write(compiler_module.to_asm())
+
+  # Manually run the passes that tf_module_to_compiler_module usually would.
+  compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+  if artifacts_dir is not None:
+    iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+    logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+    with open(iree_mlir_path, "w") as f:
+      f.write(compiler_module.to_asm())
+
+  compiled_module = compiler_module.compile(
+      target_backends=backend_info.compiler_targets)
+
+  compiled_path = None
+  if artifacts_dir is not None:
+    backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
+    os.makedirs(backend_dir, exist_ok=True)
+    compiled_path = os.path.join(backend_dir, "compiled.vmfb")
+    logging.info("Saving compiled IREE module to: %s", compiled_path)
+    with open(compiled_path, "wb") as f:
+      f.write(compiled_module)
+  return compiled_module, compiled_path
+
+
+def _incrementally_compile_tf_module(
+    module: Type[tf.Module],
+    backend_info: "BackendInfo",
+    exported_names: Sequence[str] = (),
+    artifacts_dir: str = None,
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+  """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
+
+  The module blob this creates is not callable. See IreeCompiledModule for an
+  API that returns a module that can be called without any further steps.
+
+  See _incrementally_lower_compiler_module's docstring for details about which
+  artifacts will be saved.
+
+  Args:
+    module: A tf.Module.
+    backend_info: BackendInfo with the details for compiling this module.
+    exported_names: Optional sequence representing the exported names to keep.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved. No compilation artifacts will be saved if this is not
+      provided.
+
+  Returns:
+    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+    artifacts_dir is provided.
+  """
+
+  def _compile_module(module, backend_info, exported_names, artifacts_dir):
+    compiler_module = compiler.tf_module_to_compiler_module(module,
+                                                            exported_names,
+                                                            pass_pipeline=())
+    return _incrementally_lower_compiler_module(compiler_module, backend_info,
+                                                artifacts_dir)
+
+  _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+                                                 backend_info.backend_id)
+  return _compile_module(module, backend_info, exported_names, artifacts_dir)
+
+
+def _incrementally_compile_tf_signature_def_saved_model(
+    saved_model_dir: str, saved_model_tags: Set[str],
+    backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
+  """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
+
+  The module blob this creates is not callable. See IreeCompiledModule for an
+  API that returns a module that can be called without any further steps.
+
+  See _incrementally_lower_compiler_module's docstring for details about which
+  artifacts will be saved.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    backend_info: BackendInfo with the details for compiling the saved model.
+    exported_name: A str representing the signature on the saved model to
+      compile.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved. No compilation artifacts will be saved if this is not
+      provided.
+
+  Returns:
+    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+    artifacts_dir is provided.
+  """
+
+  def _compile_module(saved_model_dir, saved_model_tags, backend_info,
+                      exported_name, artifacts_dir):
+    # Convert the tf_module into raw TF input MLIR.
+    compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
+        saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
+    return _incrementally_lower_compiler_module(compiler_module, backend_info,
+                                                artifacts_dir)
+
+  _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+                                                 backend_info.backend_id)
+  return _compile_module(saved_model_dir, saved_model_tags, backend_info,
+                         exported_name, artifacts_dir)
+
+
+class _FunctionWrapper(object):
+
+  def __call__(self, *args, **kwargs):
+    raise NotImplementedError()
+
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    """Dummy function to match _IreeFunctionWrapper's API."""
+    return ("",), ("",)
+
+
+class CompiledModule(object):
+  """Base class for the TF and IREE compiled modules."""
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Union[Dict[str, str], None],
+  ):
+    """Shared base constructor – not useful on its own.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+    """
+    self.module_name = module_name
+    self.backend_info = backend_info
+    self.compiled_paths = compiled_paths
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling this module.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    This is only implemented for IreeCompiledModule.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    raise NotImplementedError()
+
+  def __getattr__(self, attr: str) -> _FunctionWrapper:
+    raise NotImplementedError()
+
+  def iree_serializable(self):
+    return False
+
+  def tflite_serializable(self):
+    return False
+
+
+class _IreeFunctionWrapper(_FunctionWrapper):
+  """Wraps an IREE function, making it callable."""
+
+  def __init__(self, context: rt.SystemContext, f: rt.system_api.BoundFunction):
+    self._context = context
+    self._f = f
+
+  def __call__(self, *args, **kwargs):
+    return self._f(*args, **kwargs)
+
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    """Get cxx serialized inputs and outputs for this function."""
+    return self._f.get_serialized_values()
+
+
+class IreeCompiledModule(CompiledModule):
+  """Iree compiled module."""
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      vm_module: rt.VmModule,
+      config: rt.Config,
+  ):
+    """Base constructor – Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      vm_module: A rt.VmModule containing compilation info to wrap.
+      config: A rt.Config containing compilation info to wrap.
+    """
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._vm_module = vm_module
+    self._config = config
+    self.reinitialize()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    tf_utils.set_random_seed()
+    module_instance = module_class()
+    return cls.create_from_instance(module_instance, backend_info,
+                                    exported_names, artifacts_dir)
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_blob, compiled_path = _incrementally_compile_tf_module(
+        module=module_instance,
+        backend_info=backend_info,
+        exported_names=exported_names,
+        artifacts_dir=artifacts_dir)
+    vm_module = rt.VmModule.from_flatbuffer(module_blob)
+    config = rt.Config(driver_name=backend_info.driver)
+
+    compiled_paths = None
+    if compiled_path is not None:
+      # IREE bundles every compiled method into the same compiled module.
+      compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+    module_name = type(module_instance).__name__
+
+    return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    del input_names  # Unused.
+    del output_names  # Unused.
+    module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
+        saved_model_dir, saved_model_tags, backend_info, exported_name,
+        artifacts_dir)
+    vm_module = rt.VmModule.from_flatbuffer(module_blob)
+    config = rt.Config(driver_name=backend_info.driver)
+
+    compiled_paths = None
+    if compiled_path is not None:
+      # IREE bundles every compiled method into the same compiled module :)
+      compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+    return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    # set_random_seed is not needed here because the model_class.__init__ is not
+    # called.
+    self._context = rt.SystemContext(modules=[self._vm_module],
+                                     config=self._config)
+
+  def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
+    # Try to resolve it as a function.
+    m = self._context.modules[self._vm_module.name]
+    f = m[attr]
+    return _IreeFunctionWrapper(self._context, f)
+
+  def iree_serializable(self) -> bool:
+    return self.compiled_paths is not None
+
+
+class _TfFunctionWrapper(_FunctionWrapper):
+  """Wraps a TF function, normalizing it to numpy."""
+
+  def __init__(self, f: Callable[..., Any]):
+    self._f = f
+
+  def __call__(self, *args, **kwargs):
+    # TensorFlow will auto-convert all inbound args.
+    results = self._f(*args, **kwargs)
+    return tf_utils.convert_to_numpy(results)
+
+
+def _convert_inputs_to_tensors(function):
+
+  def decorator(*args, **kwargs):
+    args = [tf.convert_to_tensor(arg) for arg in args]
+    kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
+    return function(*args, **kwargs)
+
+  return decorator
+
+
+class SignatureDefSavedModelWrapper(object):
+  """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
+
+  def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
+               exported_name: str):
+    self.saved_model = tf.saved_model.load(saved_model_dir,
+                                           tags=saved_model_tags)
+    inference_func = self.saved_model.signatures[exported_name]
+    inference_func = _convert_inputs_to_tensors(inference_func)
+    self.__setattr__(exported_name, inference_func)
+
+
+class TfCompiledModule(CompiledModule):
+  """TensorFlow 'compiled' module.
+
+  This facade exists to provide a complimentary API to IreeCompiledModule and
+  normalize TensorFlow's output to Numpy.
+  """
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      constructor: Callable[[], tf.Module],
+      exported_names: Sequence[str],
+  ):
+    """Base constructor – Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      constructor: A callable (class or function) which returns the tf.Module
+        subclass instance to wrap.
+      exported_names: an optional iterable of strings representing which of the
+        tf.Module subclass instance's functions should be callable. If
+        exported_names is empty then all functions will be callable.
+    """
+    super().__init__(module_name, backend_info, compiled_paths=None)
+    self._constructor = constructor
+    self._exported_names = exported_names
+    self.reinitialize()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling this module.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_name = module_class.__name__
+    constructor = module_class
+    return cls(module_name, backend_info, constructor, exported_names)
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    constructor = lambda: SignatureDefSavedModelWrapper(
+        saved_model_dir, saved_model_tags, exported_name)
+    return cls(module_name, backend_info, constructor, [exported_name])
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    tf_utils.set_random_seed()
+    self._tf_module = self._constructor()
+
+  def __getattr__(self, attr: str) -> _TfFunctionWrapper:
+    # Try to resolve it as a function.
+    exported = not self._exported_names or attr in self._exported_names
+    if not hasattr(self._tf_module, attr) or not exported:
+      raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
+    f = getattr(self._tf_module, attr)
+    if not f or not hasattr(f, "__call__"):
+      raise AttributeError(
+          f"The TensorFlow module does not have a callable attr '{attr}'")
+    return _TfFunctionWrapper(f)
+
+
+def _get_non_inhereted_function_names(cls):
+  """Gets all methods that cls has that its parents don't have."""
+  names = set(dir(cls))
+  for parent in cls.__bases__:
+    names -= set(dir(parent))
+  return list(names)
+
+
+def _get_concrete_functions(module_class: Type[tf.Module],
+                            exported_names: Sequence[str] = ()):
+  """Get concrete functions from non-inherited methods or exported_names."""
+  if not len(exported_names):
+    # Get all method names on 'module_class' that aren't on 'tf.Module'.
+    exported_names = _get_non_inhereted_function_names(module_class)
+  instance = module_class()
+  functions = []
+  for name in exported_names:
+    functions.append(getattr(instance, name).get_concrete_function())
+  return functions, exported_names, instance
+
+
+def tf_module_to_tflite_module_bytes(
+    module_class: Type[tf.Module], exported_names: Sequence[str] = ()
+) -> Dict[str, bytes]:
+  """Compiles a tf.Module's methods with TFLite.
+
+  Args:
+    module_class: A tf.Module subclass to compile with TFLite.
+    exported_names: an optional iterable of strings representing which of the
+      module_class's functions should be compiled. If exported_names is empty
+      then all functions will be compiled.
+
+  Returns:
+    A dict mapping method names to compiled TFLite module bytes.
+  """
+  tflite_modules = []
+  methods, method_names, instance = _get_concrete_functions(
+      module_class, exported_names)
+  failed_methods = []
+  for method, method_name in zip(methods, method_names):
+    logging.info("Attempting to convert '%s' to tflite...", method_name)
+    try:
+      converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+      logging.info("...converted '%s' to tflite.", method_name)
+      tflite_modules.append(converter.convert())
+    except Exception as e:
+      logging.error("Failed to convert '%s' to tflite.", method_name)
+      logging.error("TFLite excpetion: %s", e)
+      failed_methods.append(method_name)
+
+  if failed_methods:
+    raise RuntimeError(
+        f"Failed to convert the following methods to tflite: {failed_methods}")
+
+  # Keep variables alive until TFLite has done the conversion; ConcreteFunctions
+  # themselves only keep weak references to variables.
+  del instance
+  return dict(zip(method_names, tflite_modules))
+
+
+def tf_signature_def_saved_model_to_tflite_module_bytes(
+    saved_model_dir: str,
+    saved_model_tags: Set[str],
+    exported_name: str,
+    input_names: Sequence[str],
+    output_names: Sequence[str],
+) -> Dict[str, bytes]:
+  """Compiles a SignatureDef SavedModel signature with TFLite.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    exported_name: A str representing the signature on the saved model to
+      compile.
+    input_names: A sequence of kwargs to feed to the saved model.
+    output_names: A sequence of named outputs to extract from the saved model.
+
+  Returns:
+    A dict mapping the signature name to the compiled TFLite module bytes.
+  """
+  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
+      saved_model_dir,
+      tag_set=saved_model_tags,
+      signature_key=exported_name,
+      input_arrays=input_names,
+      output_arrays=output_names)
+  tflite_module = converter.convert()
+  return dict([[exported_name, tflite_module]])
+
+
+def tflite_module_bytes_to_tflite_interpreters(
+    tflite_module_bytes: Dict[str, bytes],
+    artifacts_dir: str = None
+) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
+  """Compile a dict of TFLite compiled bytes to  TFLite interpreters.
+
+  Args:
+    tflite_module_bytes: A dict mapping method names to compiled TFLite byte
+      strings.
+    artifacts_dir: an optional path to save compilation artifacts to.
+
+  Returns:
+    A dictionary mapping method names to TFLite interpreters and a dictionary
+    mapping method names to compiled tflite graph paths (or None if
+    artifacts_dir is None).
+  """
+  interpreters = dict()
+  compiled_paths = None
+  if artifacts_dir is not None:
+    compiled_paths = dict()
+
+  def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
+    """Save compiled TFLite module bytes and convert into an interpreter."""
+    tflite_dir = os.path.join(base_dir, "tflite")
+    os.makedirs(tflite_dir, exist_ok=True)
+    tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
+    with open(tflite_path, "wb") as f:
+      f.write(tflite_module)
+
+    interpreters[method_name] = tf.lite.Interpreter(tflite_path)
+    if artifacts_dir is not None:
+      compiled_paths[method_name] = tflite_path
+
+  # Load each of the converted methods above into tf.lite.Interpreters.
+  for method_name, tflite_module in tflite_module_bytes.items():
+    if artifacts_dir is None:
+      with tempfile.TemporaryDirectory() as base_dir:
+        _interpret_bytes(method_name, tflite_module, base_dir)
+    else:
+      _interpret_bytes(method_name, tflite_module, artifacts_dir)
+
+  return interpreters, compiled_paths
+
+
+class _TfLiteFunctionWrapper(_FunctionWrapper):
+  """Wraps a TFLite interpreter and makes it behave like a python function."""
+
+  def __init__(self, interpreter: tf.lite.Interpreter,
+               output_names: Sequence[str]):
+    self._interpreter = interpreter
+    self._output_names = output_names
+
+  def __call__(self, *args,
+               **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
+    if len(args) and len(kwargs):
+      raise ValueError("Passing both args and kwargs is not supported by "
+                       "_TfLiteFunctionWrapper")
+
+    # Set up and run the function.
+    self._interpreter.allocate_tensors()
+
+    if len(args):
+      # Specifically to get TFLite to work with keras models that take a list of
+      # inputs instead of a sequence of args as their inputs, because it decides
+      # to change the input signature but it still technically works if you
+      # ignore that it does that.
+      if len(args) == 1 and isinstance(args[0], list):
+        args = args[0]
+
+      for arg, detail in zip(args, self._interpreter.get_input_details()):
+        self._interpreter.set_tensor(detail["index"], arg)
+    else:
+      for detail in self._interpreter.get_input_details():
+        self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
+
+    self._interpreter.invoke()
+
+    # Extract the outputs from the TFLite interpreter.
+    outputs = []
+    for detail in self._interpreter.get_output_details():
+      value = tf_utils.normalize_numpy(
+          self._interpreter.get_tensor(detail["index"]))
+      if self._output_names is not None:
+        name = detail["name"]
+        if name not in self._output_names:
+          raise ValueError(f"Expected '{name}' to be in {self._output_names}")
+        outputs.append([detail["name"], value])
+      else:
+        outputs.append(value)
+
+    # Process them to match the output of the tf.Module.
+    if self._output_names is not None:
+      return dict(outputs)
+    else:
+      if len(outputs) == 1:
+        return outputs[0]
+      return tuple(outputs)
+
+
+class TfLiteCompiledModule(CompiledModule):
+  """Compiles a tf.Module with TFLite and allows it to be called."""
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      interpreters: Dict[str, tf.lite.Interpreter],
+      output_names: Sequence[str] = None,
+  ):
+    """Base constructor – Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      interpreters: A dict of tf.lite.Interpreters to make callable.
+    """
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._interpreters = interpreters
+    self._output_names = output_names
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling this module.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    tf_utils.set_random_seed()
+    tflite_module_bytes = tf_module_to_tflite_module_bytes(
+        module_class, exported_names)
+    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+        tflite_module_bytes, artifacts_dir)
+    module_name = module_class.__name__
+    return cls(module_name, backend_info, compiled_paths, interpreters)
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
+        saved_model_dir, saved_model_tags, exported_name, input_names,
+        output_names)
+    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+        tflite_module_bytes, artifacts_dir)
+    return cls(module_name, backend_info, compiled_paths, interpreters,
+               output_names)
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    # This is a noop because TFLite (mostly) doesn't support stateful modules.
+    pass
+
+  def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper:
+    # Try to resolve it as an interpreter.
+    if not attr in self._interpreters:
+      raise AttributeError(
+          f"The TFLite module does not have an interpreter for '{attr}'")
+    return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names)
+
+  def tflite_serializable(self) -> bool:
+    return self.compiled_paths is not None
+
+
+class BackendInfo:
+  """Contains information for compiling the specified backend."""
+
+  _name_to_info = {
+      "tf": {
+          "compiled_module_class": TfCompiledModule,
+          "driver": None,
+          "compiler_targets": None,
+      },
+      "tflite": {
+          "compiled_module_class": TfLiteCompiledModule,
+          "driver": None,
+          "compiler_targets": None,
+      },
+      "iree_vmla": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "vmla",
+          "compiler_targets": ["vmla"]
+      },
+      "iree_vulkan": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "vulkan",
+          "compiler_targets": ["vulkan-*"]
+      },
+  }
+
+  def __init__(self, backend_name: str, backend_id: str = None):
+    """Creates a BackendInfo with the compilation details for backend_name.
+
+    Args:
+      backend_name: a str specifying which backend to use. Should be one of
+        'tf', 'tflite', 'iree_vmla', 'iree_vulkan'.
+      backend_id: an optional str specifying what name to use when saving
+        compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
+
+    Raises:
+      KeyError: if backend_name is not one of ['tf', 'tflite', 'iree_vmla',
+        'iree_vulkan'].
+      ValueError: if backend_id doesn't start with backend_name.
+    """
+    if backend_name not in self._name_to_info:
+      raise KeyError(
+          "Expected backend_name to be one of "
+          f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+    if backend_id is not None and not backend_id.startswith(backend_name):
+      raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+                       f"but got '{backend_id}'.")
+
+    self.backend_name = backend_name
+    self.backend_id = backend_name if backend_id is None else backend_id
+
+    info = self._name_to_info[backend_name]
+    self._compiled_module_class = info["compiled_module_class"]
+    self.driver = info["driver"]
+    self.compiler_targets = info["compiler_targets"]
+
+  def compile_from_class(self,
+                         module_class: Type[tf.Module],
+                         exported_names: Sequence[str] = (),
+                         artifacts_dir: str = None) -> CompiledModule:
+    """Creates a 'CompiledModule' for this backend."""
+    return self._compiled_module_class.create_from_class(
+        module_class, self, exported_names, artifacts_dir)
+
+  def compile_signature_def_saved_model(
+      self,
+      saved_model_dir: str,
+      saved_model_tags: Set[str],
+      module_name: str,
+      exported_name: str,
+      input_names: Sequence[str],
+      output_names: Sequence[str],
+      artifacts_dir: str = None) -> CompiledModule:
+    return self._compiled_module_class.create_from_signature_def_saved_model(
+        saved_model_dir, saved_model_tags, module_name, self, exported_name,
+        input_names, output_names, artifacts_dir)
+
+  @classmethod
+  def get_all_backends(cls) -> Sequence["BackendInfo"]:
+    """Returns a list of all BackendInfo configurations."""
+    return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py
new file mode 100644
index 0000000..8baa6bc
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py
@@ -0,0 +1,114 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.module_utils."""
+
+import os
+import tempfile
+
+from absl import logging
+from absl.testing import parameterized
+from pyiree.tf.support import module_utils
+import tensorflow as tf
+
+
+class ConstantModule(tf.Module):
+
+  @tf.function(input_signature=[])
+  def meaning(self):
+    return tf.constant([42.])
+
+
+class StatefulCountingModule(tf.Module):
+
+  def __init__(self):
+    self.count = tf.Variable([0.])
+
+  @tf.function(input_signature=[])
+  def increment(self):
+    self.count.assign_add(tf.constant([1.]))
+
+  @tf.function(input_signature=[])
+  def get_count(self):
+    return self.count
+
+
+class RandomInitModule(tf.Module):
+
+  def __init__(self):
+    self.value = tf.Variable(tf.random.uniform([1]))
+
+  @tf.function(input_signature=[])
+  def get(self):
+    return self.value
+
+
+class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+  def test_artifact_saving(self):
+    backend_info = module_utils.BackendInfo('iree_vmla')
+    with tempfile.TemporaryDirectory() as artifacts_dir:
+      tf_module = ConstantModule()
+      iree_module_utils, compiled_path = (
+          module_utils._incrementally_compile_tf_module(
+              tf_module, backend_info=backend_info,
+              artifacts_dir=artifacts_dir))
+
+      artifacts_to_check = [
+          'tf_input.mlir',
+          'iree_input.mlir',
+          compiled_path,
+      ]
+      for artifact in artifacts_to_check:
+        artifact_path = os.path.join(artifacts_dir, artifact)
+        logging.info('Checking path: %s', artifact_path)
+        self.assertTrue(os.path.exists(artifact_path))
+
+  @parameterized.named_parameters([
+      ('tensorflow', 'tf'),
+      ('vmla', 'iree_vmla'),
+  ])
+  def test_unaltered_state(self, backend_name):
+    backend_info = module_utils.BackendInfo(backend_name)
+    module = backend_info.compile_from_class(StatefulCountingModule)
+
+    # Test that incrementing works properly.
+    self.assertEqual([0.], module.get_count())
+    module.increment()
+    self.assertEqual([1.], module.get_count())
+
+    module.reinitialize()
+    # Test reinitialization.
+    self.assertEqual([0.], module.get_count())
+
+  @parameterized.named_parameters([
+      ('tensorflow', 'tf'),
+      ('vmla', 'iree_vmla'),
+  ])
+  def test_random_initialization(self, backend_name):
+    backend_info = module_utils.BackendInfo(backend_name)
+
+    # Test compilation is the same.
+    module_1 = backend_info.compile_from_class(RandomInitModule)
+    module_2 = backend_info.compile_from_class(RandomInitModule)
+    self.assertAllEqual(module_1.get(), module_2.get())
+
+    # Test reinitialization is the same.
+    old_value = module_1.get()
+    module_1.reinitialize()
+    self.assertAllEqual(old_value, module_1.get())
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index e18b5ce..e2a67b9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -24,20 +24,17 @@
 
 import collections
 import copy
-import glob
-import inspect
 import itertools
 import os
-import pickle
 import re
-import sys
 import tempfile
 from typing import Any, Callable, Dict, List, Sequence, Set, Tuple, Type, Union
 
 from absl import flags
 from absl import logging
-import numpy as np
+from pyiree.tf.support import module_utils
 from pyiree.tf.support import tf_utils
+from pyiree.tf.support import trace_utils
 import tensorflow.compat.v2 as tf
 
 flags.DEFINE_string("reference_backend", "tf",
@@ -56,8 +53,8 @@
 flags.DEFINE_bool(
     "get_saved_model", False,
     "Creates and stores a SavedModel for the tf.Module class to be tested.")
+
 FLAGS = flags.FLAGS
-NUMPY_LINEWIDTH = 120
 DEFAULT_INPUT_GENERATOR = tf_utils.uniform
 
 
@@ -96,7 +93,7 @@
   return backend_names, backend_ids
 
 
-def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
+def get_target_backends() -> Sequence[module_utils.BackendInfo]:
   """Gets the BackendInfo instances to compare with the reference backend.
 
   By default all backends in BackendInfo will be used. Specific backends to
@@ -109,496 +106,15 @@
     logging.info("Using backends from command line: %s", FLAGS.target_backends)
     backend_names, backend_ids = _parse_target_backends()
     backends = [
-        tf_utils.BackendInfo(backend_name, backend_id)
+        module_utils.BackendInfo(backend_name, backend_id)
         for backend_name, backend_id in zip(backend_names, backend_ids)
     ]
   else:
     # If no backends are specified, use them all.
-    backends = tf_utils.BackendInfo.get_all_backends()
+    backends = module_utils.BackendInfo.get_all_backends()
   return backends
 
 
-def _indent(input_str: str, indentation: int = 2) -> str:
-  """Indents a string by the specified number of spaces, defaulting to 2."""
-  spaces = " " * indentation
-  lines = input_str.split("\n")
-  # Prepend spaces to each non-empty line.
-  lines = [f"{spaces}{line}" if len(line) else line for line in lines]
-  return "\n".join(lines)
-
-
-def _zfill_width(length: int) -> Union[int, None]:
-  return int(np.ceil(np.log10(length))) if length else None
-
-
-class ModuleCall:
-
-  def __init__(self,
-               method: str,
-               inputs: Tuple[Any],
-               outputs: Tuple[Any],
-               serialized_inputs: Tuple[str],
-               serialized_outputs: Tuple[str],
-               rtol: float = 1e-6,
-               atol: float = 1e-6):
-    """Records the details of a call to a CompiledModule."""
-    self.method = method
-
-    # Deepcopy to safegard against mutation.
-    self.inputs = copy.deepcopy(inputs)
-    if outputs is not None:
-      outputs = copy.deepcopy(outputs)
-    else:
-      outputs = tuple()
-    self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
-
-    self.serialized_inputs = serialized_inputs
-    self.serialized_outputs = serialized_outputs
-
-    self.rtol = rtol
-    self.atol = atol
-
-  def get_tolerances(self) -> Tuple[float, float]:
-    """Gets the floating point tolerances associated with this call."""
-    return self.rtol, self.atol
-
-  def _get_shape_and_dtype(self, value: Any) -> str:
-    if isinstance(value, np.ndarray):
-      return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True)
-    else:
-      return str(type(value))
-
-  def __str__(self):
-    prior_printoptions = np.get_printoptions()
-    np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
-
-    header = f"Method: {self.method}"
-    inputs = "\n".join(_indent(str(value)) for value in self.inputs)
-    input_shapes = ", ".join(
-        self._get_shape_and_dtype(value) for value in self.inputs)
-
-    outputs = "\n".join(_indent(str(value)) for value in self.outputs)
-    output_shapes = ", ".join(
-        self._get_shape_and_dtype(value) for value in self.outputs)
-
-    tolerances = _indent(f"rtol={self.rtol}, atol={self.atol}")
-    body = (f"Inputs: {input_shapes}\n{inputs}\n"
-            f"Outputs: {output_shapes}\n{outputs}"
-            f"\nTolerances:\n{tolerances}")
-    result = f"{header}\n{_indent(body)}"
-
-    np.set_printoptions(**prior_printoptions)
-    return result
-
-  def serialize(self, call_dir: str) -> None:
-    """Stores a serialized copy of this call.
-
-    Can be loaded via ModuleCall.load(call_dir)
-
-    Args:
-      call_dir: str, the path to the directory to serialize this call to.
-    """
-    os.makedirs(call_dir, exist_ok=True)
-
-    metadata = {
-        "method": self.method,
-        "serialized_inputs": self.serialized_inputs,
-        "serialized_outputs": self.serialized_outputs,
-        "rtol": self.rtol,
-        "atol": self.atol
-    }
-    with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
-      pickle.dump(metadata, f)
-
-    width = _zfill_width(len(self.inputs))
-    for i, value in enumerate(self.inputs):
-      path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl")
-      with open(path, "wb") as f:
-        pickle.dump(value, f)
-
-    width = _zfill_width(len(self.outputs))
-    for i, value in enumerate(self.outputs):
-      path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl")
-      with open(path, "wb") as f:
-        pickle.dump(value, f)
-
-  @staticmethod
-  def load(call_dir: str) -> "ModuleCall":
-    """Loads and returns a trace serialized with ModuleCall.serialize."""
-    with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f:
-      kwargs = pickle.load(f)
-
-    for result_type in ["input", "output"]:
-      key = f"{result_type}s"  # inputs or outputs
-      kwargs[key] = []
-
-      files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl"))
-      for filename in sorted(files):
-        with open(filename, "rb") as f:
-          kwargs[key].append(pickle.load(f))
-
-      # Convert to tuple to match python's return type for multiple results.
-      kwargs[key] = tuple(kwargs[key])
-
-    return ModuleCall(**kwargs)
-
-
-class Trace:
-  """Stores the inputs and outputs of a series of calls to a module."""
-
-  def __init__(self,
-               module: Union[tf_utils.CompiledModule, None],
-               function: Union[Callable[["TracedModule"], None], None],
-               _load_dict: Dict[str, Any] = None):
-    """Extracts metadata from module and function and initializes.
-
-    Example usage:
-      def forward_pass(...):
-        ...
-      module = IreeCompiledModule(...)
-      trace = Trace(module, forward_pass)
-      forward_pass(TracedModule(module, trace))
-
-    Args:
-      module: the module who's outputs this trace will record.
-      function: the function that module will be traced on.
-      _load_dict: used internally
-    """
-    if _load_dict is None:
-      # Extract metadata from module and function.
-      self.module_name = module.module_name
-      self.compiled_paths = module.compiled_paths
-      self.backend_name = module.backend_info.backend_name
-      self.backend_id = module.backend_info.backend_id
-      self.backend_driver = module.backend_info.driver
-      self.iree_serializable = module.iree_serializable()
-      self.tflite_serializable = module.tflite_serializable()
-      self.function_name = function.__name__
-      self.function_sourcefile = inspect.getsourcefile(function)
-      source, start_line = inspect.getsourcelines(function)
-      self.function_line_numbers = (start_line, start_line + len(source))
-      self.function_source = "".join(source)
-
-      self.calls = []
-    else:
-      self.module_name = _load_dict["module_name"]
-      self.compiled_paths = _load_dict["compiled_paths"]
-      self.backend_name = _load_dict["backend_name"]
-      self.backend_id = _load_dict["backend_id"]
-      self.backend_driver = _load_dict["backend_driver"]
-      self.iree_serializable = _load_dict["iree_serializable"]
-      self.tflite_serializable = _load_dict["tflite_serializable"]
-      self.function_name = _load_dict["function_name"]
-      self.function_sourcefile = _load_dict["function_sourcefile"]
-      self.function_line_numbers = _load_dict["function_line_numbers"]
-      self.function_source = _load_dict["function_source"]
-      self.calls = _load_dict["calls"]
-
-  def __str__(self):
-    header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
-              f"on function '{self.function_name}':")
-    # Give each call a number so it's easier to compare between multiple traces.
-    calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
-    calls = _indent("\n".join(calls))
-    return f"{header}\n{calls}"
-
-  def __iter__(self):
-    for call in self.calls:
-      yield call
-
-  @staticmethod
-  def compare_traces(ref_trace: "Trace",
-                     tar_trace: "Trace") -> Tuple[bool, Sequence[str]]:
-    traces_match = True
-    error_messages = []
-
-    # Check that all method invocations match.
-    ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
-    tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
-    if ref_methods != tar_methods:
-      # Raise a ValueError instead of returning False since this is an
-      # unexpected error.
-      raise ValueError(
-          "The reference and target traces have different call structures:\n"
-          f"Reference: {ref_methods}\nTarget:    {tar_methods}")
-
-    for ref_call, tar_call in zip(ref_trace, tar_trace):
-      logging.info("Comparing calls to '%s'", ref_call.method)
-      rtol, atol = ref_call.get_tolerances()
-
-      inputs_match, error_message = Trace._check_same(ref_call.inputs,
-                                                      tar_call.inputs, rtol,
-                                                      atol)
-      if not inputs_match:
-        error_messages.append(error_message)
-        logging.error("Inputs did not match.")
-      outputs_match, error_message = Trace._check_same(ref_call.outputs,
-                                                       tar_call.outputs, rtol,
-                                                       atol)
-      if not outputs_match:
-        error_messages.append(error_message)
-        logging.error("Outputs did not match.")
-      calls_match = inputs_match and outputs_match
-
-      if not calls_match:
-        logging.error("Comparision between '%s' and '%s' failed on method '%s'",
-                      ref_trace.backend_id, tar_trace.backend_id,
-                      ref_call.method)
-        logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
-                      ref_call)
-        logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
-
-      traces_match = traces_match and calls_match
-    return traces_match, error_messages
-
-  @staticmethod
-  def _check_same(ref: Any, tar: Any, rtol: float,
-                  atol: float) -> Tuple[bool, Union[str, None]]:
-    """Checks that ref and tar have identical datastructures and values."""
-    # Check for matching types.
-    if not isinstance(tar, type(ref)):
-      error = ("Expected ref and tar to have the same type but got "
-               f"'{type(ref)}' and '{type(tar)}'")
-      logging.error(error)
-      return False, error
-
-    if ref is None:
-      # Nothing to compare (e.g. the called method had no outputs).
-      return True, None
-
-    # Recursive check for dicts.
-    if isinstance(ref, dict):
-      if ref.keys() != tar.keys():
-        error = ("Expected ref and tar to have the same keys, but got "
-                 f"'{ref.keys()}' and '{tar.keys()}'")
-        logging.error(error)
-        return False, error
-      # Check that all of the dictionaries' values are the same.
-      for key in ref:
-        same, error = Trace._check_same(ref[key], tar[key], rtol, atol)
-        if not same:
-          return same, error
-
-    # Recursive check for iterables.
-    elif isinstance(ref, list) or isinstance(ref, tuple):
-      if len(ref) != len(tar):
-        error = ("Expected ref and tar to have the same length, but got "
-                 f"{len(ref)} and {len(tar)}")
-        logging.error(error)
-        return False, error
-      # Check that all of the iterables' values are the same.
-      for i in range(len(ref)):
-        same, error = Trace._check_same(ref[i], tar[i], rtol, atol)
-        if not same:
-          return same, error
-
-    # Base check for numpy arrays.
-    elif isinstance(ref, np.ndarray):
-      if ref.dtype != tar.dtype:
-        error = ("Expected ref and tar to have the same dtype, but got "
-                 f"'{ref.dtype}' and '{tar.dtype}'")
-        logging.error(error)
-        return False, error
-      if ref.size == tar.size == 0:
-        return True, None
-
-      if np.issubdtype(ref.dtype, np.floating):
-        same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True)
-        abs_diff = np.max(np.abs(ref - tar))
-        rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
-        diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
-                       f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
-        if not same:
-          error = ("Floating point difference between ref and tar was too "
-                   f"large. {diff_string}")
-          logging.error(error)
-        else:
-          error = None
-          logging.info(
-              "Floating point difference between ref and tar was within "
-              "tolerance. %s", diff_string)
-        return same, error
-      elif np.issubdtype(ref.dtype, np.integer):
-        same = np.array_equal(ref, tar)
-        if not same:
-          abs_diff = np.max(np.abs(ref - tar))
-          error = ("Expected array equality between ref and tar, but got "
-                   f"a max elementwise difference of {abs_diff}")
-          logging.error(error)
-        else:
-          error = None
-        return same, error
-      else:
-        return np.array_equal(ref, tar), None
-
-    # Base check for native number types.
-    elif isinstance(ref, (int, float)):
-      return ref == tar, None
-
-    # If outputs end up here then an extra branch for that type should be added.
-    else:
-      raise TypeError(f"Encountered results with unexpected type {type(ref)}")
-    return True, None
-
-  def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
-    """Saves a human-readable string representation of this trace to disk.
-
-    Args:
-      trace_dir: str, path to the directory to save the trace in.
-      summarize: a bool controlling whether numpy should summarize the inputs
-        and outputs if they're large. Setting this to False is very slow for
-        large outputs.
-    """
-    prior_printoptions = np.get_printoptions()
-    np.set_printoptions(
-        linewidth=NUMPY_LINEWIDTH,
-        threshold=None if summarize else sys.maxsize,
-        edgeitems=10)  # Can show more items since they won't clutter the logs.
-
-    path = os.path.join(trace_dir, "log.txt")
-    with open(path, "w") as f:
-      f.write(str(self))
-      f.write("\n")
-
-    np.set_printoptions(**prior_printoptions)
-
-  def serialize(self, trace_dir: str) -> None:
-    """Stores a serialized copy of this trace in trace_dir.
-
-    It can be loaded via `Trace.load(trace_dir)`.
-
-    Args:
-      trace_dir: str, path to the directory to serialize the trace to.
-    """
-
-    compiled_paths = None
-    if self.compiled_paths is not None:
-      # Convert to a dict to avoid the issues with serializing defaultdicts.
-      compiled_paths = dict(self.compiled_paths)
-
-    # Python serialization.
-    metadata = {
-        "module_name": self.module_name,
-        "compiled_paths": compiled_paths,
-        "backend_name": self.backend_name,
-        "backend_id": self.backend_id,
-        "backend_driver": self.backend_driver,
-        "iree_serializable": self.iree_serializable,
-        "tflite_serializable": self.tflite_serializable,
-        "function_name": self.function_name,
-        "function_sourcefile": self.function_sourcefile,
-        "function_line_numbers": self.function_line_numbers,
-        "function_source": self.function_source
-    }
-    with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f:
-      pickle.dump(metadata, f)
-
-    width = _zfill_width(len(self.calls))
-    for i, call in enumerate(self.calls):
-      call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
-      call.serialize(call_dir)
-
-    # C++ benchmark serialization.
-    if self.iree_serializable or self.tflite_serializable:
-      entry_function = self.calls[0].method
-      compiled_path = self.compiled_paths[entry_function]
-
-      if self.iree_serializable:
-        serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
-        flagfile = [
-            f"--module_file={compiled_path}",
-            f"--driver={self.backend_driver}",
-            f"--function_inputs={serialized_inputs}",
-            f"--entry_function={entry_function}",
-        ]
-        with open(os.path.join(trace_dir, "flagfile"), "w") as f:
-          f.writelines(line + "\n" for line in flagfile)
-      else:
-        with open(os.path.join(trace_dir, "graph_path"), "w") as f:
-          f.writelines(compiled_path + "\n")
-
-  @staticmethod
-  def load(trace_dir: str) -> "Trace":
-    """Loads and returns a trace serialized with Trace.serialize.
-
-    Args:
-      trace_dir: str, path to the directory of the serialized trace.
-
-    Returns:
-      A Trace deserialized from trace_dir.
-    """
-    with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f:
-      load_dict = pickle.load(f)
-    call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*")))
-    calls = [ModuleCall.load(call_dir) for call_dir in call_dirs]
-    load_dict["calls"] = calls
-    return Trace(module=None, function=None, _load_dict=load_dict)
-
-
-def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
-  trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
-                           trace.function_name)
-  os.makedirs(trace_dir, exist_ok=True)
-  return trace_dir
-
-
-class TracedModule:
-
-  def __init__(self, module: tf_utils.CompiledModule, trace: Trace):
-    """Wraps a CompiledModule so that all inputs and outputs are traced.
-
-    The TracedModule returned will have an API almost identical to that of the
-    passed CompiledModule. The only changes is that if the keywords `rtol` or
-    `atol` are passed to one of the CompiledModule's methods, then they will be
-    used to set the tolerance for comparing that call to the same call in
-    another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)`
-    would be the same as calling `module.add(a, b)`.
-
-    Args:
-      module: the CompiledModule to trace.
-      trace: the Trace to record calls to this module with.
-    """
-    self._module = module
-    self._trace = trace
-
-  def _trace_call(self, method: tf_utils._FunctionWrapper, method_name: str):
-    """Decorates a CompiledModule method to capture its inputs and outputs."""
-
-    def call(*args, **kwargs):
-      # Pop manually specified tolerances from the kwargs (if any).
-      tolerances = {}
-      tolerances["rtol"] = kwargs.pop("rtol", None)
-      tolerances["atol"] = kwargs.pop("atol", None)
-      # Only pass these to ModuleCall if they were specified by the user.
-      tolerances = {k: v for k, v in tolerances.items() if v is not None}
-
-      # Ensure the inputs are numpy inputs.
-      args = tf_utils.convert_to_numpy(args)
-      kwargs = tf_utils.convert_to_numpy(kwargs)
-
-      # Run the method and record the details of the call.
-      outputs = method(*args, **kwargs)
-      serialized_inputs, serialized_outputs = method.get_serialized_values()
-      self._trace.calls.append(
-          ModuleCall(method_name, args, outputs, serialized_inputs,
-                     serialized_outputs, **tolerances))
-      return outputs
-
-    return call
-
-  def __getattr__(self, attr):
-    # Try to resolve it as an attr on self._module.
-    if not hasattr(self._module, attr):
-      raise AttributeError(f"The compiled module does not have attr '{attr}'")
-    module_attr = getattr(self._module, attr)
-    if not hasattr(module_attr, "__call__"):
-      # e.g. traced_module.backend
-      return module_attr
-    else:
-      # e.g. traced_module.simple_mul(a, b)
-      return self._trace_call(module_attr, method_name=attr)
-
-
 Modules = collections.namedtuple("Modules",
                                  ["ref_module", "tar_modules", "artifacts_dir"])
 
@@ -635,8 +151,8 @@
   artifacts_dir = _setup_artifacts_dir(module_class.__name__)
 
   # Get the backend information for this test.
-  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
-                                          f"{FLAGS.reference_backend}_ref")
+  ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
+                                              f"{FLAGS.reference_backend}_ref")
   tar_backend_infos = get_target_backends()
 
   compile_backend = lambda backend_info: backend_info.compile_from_class(
@@ -678,8 +194,8 @@
   artifacts_dir = _setup_artifacts_dir(module_name)
 
   # Get the backend information for this test.
-  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
-                                          f"{FLAGS.reference_backend}_ref")
+  ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
+                                              f"{FLAGS.reference_backend}_ref")
   tar_backend_infos = get_target_backends()
 
   compile_backend = (
@@ -1049,7 +565,9 @@
                          f"unit_test '{unit_test.__name__}'.")
       setattr(cls, unit_test.__name__, unit_test)
 
-  def compare_backends(self, trace_function: Callable[[TracedModule], None],
+  def compare_backends(self,
+                       trace_function: Callable[[trace_utils.TracedModule],
+                                                None],
                        modules: Modules) -> None:
     """Run the reference and target backends on trace_function and compare them.
 
@@ -1060,19 +578,20 @@
       trace_function: a function accepting a TracedModule as its argument.
     """
     # Create Traces for each backend.
-    ref_trace = Trace(modules.ref_module, trace_function)
+    ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
     tar_traces = [
-        Trace(module, trace_function) for module in modules.tar_modules
+        trace_utils.Trace(module, trace_function)
+        for module in modules.tar_modules
     ]
 
     # Run the traces through trace_function with their associated modules.
     tf_utils.set_random_seed()
-    trace_function(TracedModule(modules.ref_module, ref_trace))
+    trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace))
     if FLAGS.log_all_traces:
       logging.info(ref_trace)
     for module, trace in zip(modules.tar_modules, tar_traces):
       tf_utils.set_random_seed()
-      trace_function(TracedModule(module, trace))
+      trace_function(trace_utils.TracedModule(module, trace))
       if FLAGS.log_all_traces:
         logging.info(trace)
 
@@ -1082,17 +601,18 @@
     for i, tar_trace in enumerate(tar_traces):
       logging.info("Comparing the reference backend '%s' with '%s'",
                    ref_trace.backend_id, tar_trace.backend_id)
-      traces_match, errors = Trace.compare_traces(ref_trace, tar_trace)
+      traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
       if not traces_match:
         failed_backend_indices.append(i)
         error_messages.extend(errors)
 
     # Save the results to disk before validating.
-    ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
+    ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace)
     ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
     ref_trace.serialize(ref_trace_dir)
     for tar_trace in tar_traces:
-      tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
+      tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir,
+                                                tar_trace)
       tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
       tar_trace.serialize(tar_trace_dir)
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 6069688..84c1c36 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -14,45 +14,13 @@
 # limitations under the License.
 """Tests for pyiree.tf.support.tf_test_utils."""
 
-import os
-import tempfile
-
-from absl.testing import parameterized
 import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
 import tensorflow as tf
 
 
-class StatefulCountingModule(tf.Module):
-
-  def __init__(self):
-    self.count = tf.Variable([0.])
-
-  @tf.function(input_signature=[])
-  def increment(self):
-    self.count.assign_add(tf.constant([1.]))
-
-  @tf.function(input_signature=[])
-  def get_count(self):
-    return self.count
-
-  @tf.function(input_signature=[tf.TensorSpec([1])])
-  def increment_by(self, value):
-    self.count.assign_add(value)
-
-  @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])])
-  def increment_by_max(self, a, b):
-    result = tf.maximum(a, b)
-    self.count.assign_add(result)
-    return result
-
-  @tf.function(input_signature=[])
-  def decrement(self):
-    self.count.assign_sub(tf.constant([1.]))
-
-
-class TfFunctionUnittestModule(tf_test_utils.TestModule):
+class TfFunctionUnitTestModule(tf_test_utils.TestModule):
 
   @tf_test_utils.tf_function_unit_test(input_signature=[])
   def no_args(self):
@@ -100,158 +68,7 @@
     return tf.matmul(a, b)
 
 
-class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
-
-  @parameterized.named_parameters([
-      {
-          'testcase_name': 'all the same',
-          'array_c': np.array([0, 1, 2]),
-          'array_d': np.array(['0', '1', '2']),
-          'array_e': np.array([0.0, 0.1, 0.2]),
-          'tar_same': True,
-      },
-      {
-          'testcase_name': 'wrong int',
-          'array_c': np.array([1, 1, 2]),
-          'array_d': np.array(['0', '1', '2']),
-          'array_e': np.array([0.0, 0.1, 0.2]),
-          'tar_same': False,
-      },
-      {
-          'testcase_name': 'wrong string',
-          'array_c': np.array([0, 1, 2]),
-          'array_d': np.array(['a', '1', '2']),
-          'array_e': np.array([0.0, 0.1, 0.2]),
-          'tar_same': False,
-      },
-      {
-          'testcase_name': 'wrong float',
-          'array_c': np.array([0, 1, 2]),
-          'array_d': np.array(['0', '1', '2']),
-          'array_e': np.array([1.0, 0.1, 0.2]),
-          'tar_same': False,
-      },
-  ])
-  def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
-
-    # yapf: disable
-    ref = {
-        'a': 1,
-        'b': [
-            {'c': np.array([0, 1, 2])},
-            {'d': np.array(['0', '1', '2'])},
-            {'e': np.array([0.0, 0.1, 0.2])}
-        ],
-    }
-    tar = {
-        'a': 1,
-        'b': [
-            {'c': array_c},
-            {'d': array_d},
-            {'e': array_e}
-        ],
-    }
-    # yapf: enable
-    same, _ = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
-    self.assertEqual(tar_same, same)
-
-  def test_trace_inputs_and_outputs(self):
-
-    def trace_function(module):
-      # No inputs or outputs
-      module.increment()
-      # Only inputs
-      module.increment_by(np.array([81.], dtype=np.float32))
-      # Only outputs
-      module.get_count()
-
-    module = tf_utils.TfCompiledModule.create_from_class(
-        StatefulCountingModule, tf_utils.BackendInfo('tf'))
-    trace = tf_test_utils.Trace(module, trace_function)
-    trace_function(tf_test_utils.TracedModule(module, trace))
-
-    self.assertIsInstance(trace.calls[0].inputs, tuple)
-    self.assertEmpty(trace.calls[0].inputs)
-    self.assertIsInstance(trace.calls[0].outputs, tuple)
-    self.assertEmpty(trace.calls[0].outputs)
-
-    self.assertAllClose(trace.calls[1].inputs[0], [81.])
-    self.assertAllClose(trace.calls[2].outputs[0], [82.])
-
-  def test_nonmatching_methods(self):
-
-    def tf_function(module):
-      module.increment()
-      module.increment()
-
-    def vmla_function(module):
-      module.increment()
-      module.decrement()
-
-    tf_module = tf_utils.TfCompiledModule.create_from_class(
-        StatefulCountingModule, tf_utils.BackendInfo('tf'))
-    tf_trace = tf_test_utils.Trace(tf_module, tf_function)
-    tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
-
-    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
-        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
-    vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
-    vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
-
-    with self.assertRaises(ValueError):
-      tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
-
-  def test_nonmatching_inputs(self):
-
-    def tf_function(module):
-      module.increment_by(np.array([42.], dtype=np.float32))
-
-    def vmla_function(module):
-      module.increment_by(np.array([22.], dtype=np.float32))
-
-    tf_module = tf_utils.TfCompiledModule.create_from_class(
-        StatefulCountingModule, tf_utils.BackendInfo('tf'))
-    tf_trace = tf_test_utils.Trace(tf_module, tf_function)
-    tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
-
-    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
-        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
-    vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
-    vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
-
-    same, error_messages = tf_test_utils.Trace.compare_traces(
-        tf_trace, vmla_trace)
-    self.assertFalse(same)
-
-  def test_trace_serialize_and_load(self):
-
-    def trace_function(module):
-      module.increment()
-      module.increment_by(np.array([81.], dtype=np.float32))
-      module.increment_by_max(np.array([81], dtype=np.float32),
-                              np.array([92], dtype=np.float32))
-      module.get_count()
-
-    module = tf_utils.IreeCompiledModule.create_from_class(
-        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
-    trace = tf_test_utils.Trace(module, trace_function)
-    trace_function(tf_test_utils.TracedModule(module, trace))
-
-    with tempfile.TemporaryDirectory() as artifacts_dir:
-      trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
-      trace.serialize(trace_function_dir)
-      self.assertTrue(
-          os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
-      loaded_trace = tf_test_utils.Trace.load(trace_function_dir)
-
-      # Check all calls match.
-      self.assertTrue(tf_test_utils.Trace.compare_traces(trace, loaded_trace))
-
-      # Check all other metadata match.
-      self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys())
-      for key in trace.__dict__.keys():
-        if key != 'calls':
-          self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
+class TestUtilsTests(tf.test.TestCase):
 
   def test_tf_function_unittet(self):
 
@@ -260,9 +77,9 @@
       def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._modules = tf_test_utils.compile_tf_module(
-            TfFunctionUnittestModule)
+            TfFunctionUnitTestModule)
 
-    TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnittestModule)
+    TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnitTestModule)
     test_case = TfFunctionUnittestTest()
     self.assertTrue(hasattr(test_case, 'test_no_args'))
     self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs'))
@@ -270,7 +87,7 @@
     self.assertTrue(hasattr(test_case, 'test_custom_input_args'))
     self.assertTrue(hasattr(test_case, 'test_high_tolerance'))
 
-    # Will throw an error if 'atol' and 'rtol' are not set.
+    # Will throw an error if 'atol' is not set.
     test_case = TfFunctionUnittestTest()
     test_case.test_high_tolerance()
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index aa94744..1af3cca 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -14,21 +14,18 @@
 # limitations under the License.
 """Utilities interop with TensorFlow."""
 
-# pylint: disable=protected-access
-
-import collections
 import os
 import random
 import re
-import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Sequence, Set, Tuple, Union
 
 from absl import logging
 import numpy as np
-from pyiree import rt
-from pyiree.tf import compiler
 import tensorflow.compat.v2 as tf
 
+InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
+                              np.ndarray]
+
 
 def set_random_seed(seed: int = 0) -> None:
   """Set random seed for tf, np and random."""
@@ -37,10 +34,6 @@
   np.random.seed(seed)
 
 
-InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
-                              np.ndarray]
-
-
 def uniform(shape: Sequence[int],
             dtype: Union[tf.DType, np.dtype] = np.float32,
             low: float = -1.0,
@@ -94,8 +87,32 @@
   return apply_function(spec, generate)
 
 
+def normalize_numpy(result: np.ndarray):
+  """Normalizes TF and TFLite's outputs to match IREE's"""
+  if np.isscalar(result):
+    result = np.array(result)
+  if result.dtype == np.bool:
+    # IREE interprets bools as int8s, so we modify this for comparison.
+    result = result.astype(dtype=np.int8)
+  return result
+
+
+def convert_to_numpy(values: Any) -> Any:
+  """Converts any tf.Tensor in values to numpy."""
+
+  def _convert_to_numpy(tensor: Any) -> Any:
+    if not isinstance(tensor, tf.Tensor):
+      return tensor
+    return normalize_numpy(tensor.numpy())
+
+  return apply_function(values, _convert_to_numpy)
+
+
 def to_mlir_type(dtype: np.dtype) -> str:
   """Returns a string that denotes the type 'dtype' in MLIR style."""
+  if not isinstance(dtype, np.dtype):
+    # Handle np.int8 _not_ being a dtype.
+    dtype = np.dtype(dtype)
   bits = dtype.itemsize * 8
   if np.issubdtype(dtype, np.integer):
     return f"i{bits}"
@@ -192,953 +209,90 @@
   return tf.TensorSpec([None] * len(spec.shape), spec.dtype)
 
 
-def _setup_mlir_crash_reproducer(
-    function: Any,  # pytype doesn't support arbitrary Callable[*args, **kwargs]
-    artifacts_dir: str,
-    backend_id: str,
-) -> Any:  # Callable[Any, Any]
-  """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+def check_same(ref: Any, tar: Any, rtol: float,
+               atol: float) -> Tuple[bool, Union[str, None]]:
+  """Checks that ref and tar have identical datastructures and values."""
+  # Check for matching types.
+  if not isinstance(tar, type(ref)):
+    error = ("Expected ref and tar to have the same type but got "
+             f"'{type(ref)}' and '{type(tar)}'")
+    logging.error(error)
+    return False, error
 
-  Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+  if ref is None:
+    # Nothing to compare (e.g. the called method had no outputs).
+    return True, None
 
-  Args:
-    function: The callable to decorate.
-    artifacts_dir: The directory to write the reproducer to.
-    backend_id: The unique backend name to use when writting the reproducer.
+  # Recursive check for dicts.
+  if isinstance(ref, dict):
+    if ref.keys() != tar.keys():
+      error = ("Expected ref and tar to have the same keys, but got "
+               f"'{ref.keys()}' and '{tar.keys()}'")
+      logging.error(error)
+      return False, error
+    # Check that all of the dictionaries' values are the same.
+    for key in ref:
+      same, error = check_same(ref[key], tar[key], rtol, atol)
+      if not same:
+        return same, error
 
-  Returns:
-    A function with the same API as the passed function.
-  """
+  # Recursive check for iterables.
+  elif isinstance(ref, list) or isinstance(ref, tuple):
+    if len(ref) != len(tar):
+      error = ("Expected ref and tar to have the same length, but got "
+               f"{len(ref)} and {len(tar)}")
+      logging.error(error)
+      return False, error
+    # Check that all of the iterables' values are the same.
+    for i in range(len(ref)):
+      same, error = check_same(ref[i], tar[i], rtol, atol)
+      if not same:
+        return same, error
 
-  def decorator(*args, **kwargs):
-    # Set up a crash reproducer for debugging.
-    if artifacts_dir is not None:
-      compiler.Context.default_crash_reproducer_path = os.path.join(
-          artifacts_dir, f"reproducer__{backend_id}.mlir")
-    try:
-      results = function(*args, **kwargs)
-    except Exception:  # pylint: disable=broad-except
-      # Disable the crash reproducer (to avoid inadvertently overwriting it).
-      if artifacts_dir is not None:
-        compiler.Context.default_crash_reproducer_path = None
-      raise
-    return results
+  # Base check for numpy arrays.
+  elif isinstance(ref, np.ndarray):
+    if ref.dtype != tar.dtype:
+      error = ("Expected ref and tar to have the same dtype, but got "
+               f"'{ref.dtype}' and '{tar.dtype}'")
+      logging.error(error)
+      return False, error
+    if ref.size == tar.size == 0:
+      return True, None
 
-  return decorator
-
-
-def _incrementally_lower_compiler_module(
-    compiler_module: compiler.Module,
-    backend_info: "BackendInfo",
-    artifacts_dir: str,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
-  """Lowers a MLIR compiler module incrementally and saves its outputs.
-
-  If artifacts_dir is provided then the following artifacts will be saved:
-    tf_input.mlir:
-      MLIR for the module in TF's input dialect.
-    iree_input.mlir:
-      The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
-    backend_id/compiled.vmfb:
-      A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
-
-  Args:
-    compiler_module: A compiler.Module to lower.
-    backend_info: BackendInfo with the details for lowering compiler_module to
-      IREE.
-    artifacts_dir: An optional string pointing to where compilation artifacts
-      should be saved. No compilation artifacts will be saved if this is not
-      provided.
-  """
-  if artifacts_dir is not None:
-    os.makedirs(artifacts_dir, exist_ok=True)
-    tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
-    logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
-    with open(tf_mlir_path, "w") as f:
-      f.write(compiler_module.to_asm())
-
-  # Manually run the passes that tf_module_to_compiler_module usually would.
-  compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
-  if artifacts_dir is not None:
-    iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
-    logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
-    with open(iree_mlir_path, "w") as f:
-      f.write(compiler_module.to_asm())
-
-  compiled_module = compiler_module.compile(
-      target_backends=backend_info.compiler_targets)
-
-  compiled_path = None
-  if artifacts_dir is not None:
-    backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
-    os.makedirs(backend_dir, exist_ok=True)
-    compiled_path = os.path.join(backend_dir, "compiled.vmfb")
-    logging.info("Saving compiled IREE module to: %s", compiled_path)
-    with open(compiled_path, "wb") as f:
-      f.write(compiled_module)
-  return compiled_module, compiled_path
-
-
-def _incrementally_compile_tf_module(
-    module: Type[tf.Module],
-    backend_info: "BackendInfo",
-    exported_names: Sequence[str] = (),
-    artifacts_dir: str = None,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
-  """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
-
-  The module blob this creates is not callable. See IreeCompiledModule for an
-  API that returns a module that can be called without any further steps.
-
-  See _incrementally_lower_compiler_module's docstring for details about which
-  artifacts will be saved.
-
-  Args:
-    module: A tf.Module.
-    backend_info: BackendInfo with the details for compiling this module.
-    exported_names: Optional sequence representing the exported names to keep.
-    artifacts_dir: An optional string pointing to where compilation artifacts
-      should be saved. No compilation artifacts will be saved if this is not
-      provided.
-
-  Returns:
-    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
-    artifacts_dir is provided.
-  """
-
-  def _compile_module(module, backend_info, exported_names, artifacts_dir):
-    compiler_module = compiler.tf_module_to_compiler_module(module,
-                                                            exported_names,
-                                                            pass_pipeline=())
-    return _incrementally_lower_compiler_module(compiler_module, backend_info,
-                                                artifacts_dir)
-
-  _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
-                                                 backend_info.backend_id)
-  return _compile_module(module, backend_info, exported_names, artifacts_dir)
-
-
-def _incrementally_compile_tf_signature_def_saved_model(
-    saved_model_dir: str, saved_model_tags: Set[str],
-    backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
-  """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
-
-  The module blob this creates is not callable. See IreeCompiledModule for an
-  API that returns a module that can be called without any further steps.
-
-  See _incrementally_lower_compiler_module's docstring for details about which
-  artifacts will be saved.
-
-  Args:
-    saved_model_dir: Directory of the saved model.
-    saved_model_tags: Optional set of tags to use when loading the model.
-    backend_info: BackendInfo with the details for compiling the saved model.
-    exported_name: A str representing the signature on the saved model to
-      compile.
-    artifacts_dir: An optional string pointing to where compilation artifacts
-      should be saved. No compilation artifacts will be saved if this is not
-      provided.
-
-  Returns:
-    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
-    artifacts_dir is provided.
-  """
-
-  def _compile_module(saved_model_dir, saved_model_tags, backend_info,
-                      exported_name, artifacts_dir):
-    # Convert the tf_module into raw TF input MLIR.
-    compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
-        saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
-    return _incrementally_lower_compiler_module(compiler_module, backend_info,
-                                                artifacts_dir)
-
-  _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
-                                                 backend_info.backend_id)
-  return _compile_module(saved_model_dir, saved_model_tags, backend_info,
-                         exported_name, artifacts_dir)
-
-
-class _FunctionWrapper(object):
-
-  def __call__(self, *args, **kwargs):
-    raise NotImplementedError()
-
-  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
-    """Dummy function to match _IreeFunctionWrapper's API."""
-    return ("",), ("",)
-
-
-class CompiledModule(object):
-  """Base class for the TF and IREE compiled modules."""
-
-  def __init__(
-      self,
-      module_name: str,
-      backend_info: "BackendInfo",
-      compiled_paths: Union[Dict[str, str], None],
-  ):
-    """Shared base constructor – not useful on its own.
-
-    Args:
-      module_name: A name for this compiled module. In most cases this will be
-        the name of the tf.Module subclass or instance that is compiled.
-      backend_info: BackendInfo with the details about compiling this module.
-      compiled_paths: A dictionary mapping compiled method names to file paths
-        corresponding to their serialized representations.
-    """
-    self.module_name = module_name
-    self.backend_info = backend_info
-    self.compiled_paths = compiled_paths
-
-  def reinitialize(self):
-    """Reinitializes all stateful variables."""
-    raise NotImplementedError()
-
-  @classmethod
-  def create_from_class(cls,
-                        module_class: Type[tf.Module],
-                        backend_info: "BackendInfo",
-                        exported_names: Sequence[str] = (),
-                        artifacts_dir: str = None):
-    """Compile a tf.Module subclass to the target backend in backend_info.
-
-    Args:
-      module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling this module.
-      exported_names: Optional sequence representing the exported names to keep.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    raise NotImplementedError()
-
-  @classmethod
-  def create_from_instance(cls,
-                           module_instance: tf.Module,
-                           backend_info: "BackendInfo",
-                           exported_names: Sequence[str] = (),
-                           artifacts_dir: str = None):
-    """Compile a tf.Module instance to the target backend in backend_info.
-
-    This is only implemented for IreeCompiledModule.
-
-    Args:
-      module_instance: The tf.Module instance to compile.
-      backend_info: BackendInfo with the details for compiling module to IREE.
-      exported_names: Optional sequence representing the exported names to keep.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    raise NotImplementedError()
-
-  @classmethod
-  def create_from_signature_def_saved_model(cls,
-                                            saved_model_dir: str,
-                                            saved_model_tags: Set[str],
-                                            module_name: str,
-                                            backend_info: "BackendInfo",
-                                            exported_name: str,
-                                            input_names: Sequence[str],
-                                            output_names: Sequence[str],
-                                            artifacts_dir: str = None):
-    """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
-    Args:
-      saved_model_dir: Directory of the saved model.
-      saved_model_tags: Optional set of tags to use when loading the model.
-      module_name: A name for this compiled module.
-      backend_info: BackendInfo with the details for compiling the saved model.
-      exported_name: A str representing the signature on the saved model to
-        compile.
-      input_names: A sequence of kwargs to feed to the saved model.
-      output_names: A sequence of named outputs to extract from the saved model.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    raise NotImplementedError()
-
-  def __getattr__(self, attr: str) -> _FunctionWrapper:
-    raise NotImplementedError()
-
-  def iree_serializable(self):
-    return False
-
-  def tflite_serializable(self):
-    return False
-
-
-class _IreeFunctionWrapper(_FunctionWrapper):
-  """Wraps an IREE function, making it callable."""
-
-  def __init__(self, context: rt.SystemContext, f: rt.system_api.BoundFunction):
-    self._context = context
-    self._f = f
-
-  def __call__(self, *args, **kwargs):
-    return self._f(*args, **kwargs)
-
-  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
-    """Get cxx serialized inputs and outputs for this function."""
-    return self._f.get_serialized_values()
-
-
-class IreeCompiledModule(CompiledModule):
-  """Iree compiled module."""
-
-  def __init__(
-      self,
-      module_name: str,
-      backend_info: "BackendInfo",
-      compiled_paths: Dict[str, str],
-      vm_module: rt.VmModule,
-      config: rt.Config,
-  ):
-    """Base constructor – Use one of the named constructors instead.
-
-    Args:
-      module_name: A name for this compiled module. In most cases this will be
-        the name of the tf.Module subclass or instance that is compiled.
-      backend_info: BackendInfo with the details about compiling this module.
-      compiled_paths: A dictionary mapping compiled method names to file paths
-        corresponding to their serialized representations.
-      vm_module: A rt.VmModule containing compilation info to wrap.
-      config: A rt.Config containing compilation info to wrap.
-    """
-    super().__init__(module_name, backend_info, compiled_paths)
-    self._vm_module = vm_module
-    self._config = config
-    self.reinitialize()
-
-  @classmethod
-  def create_from_class(cls,
-                        module_class: Type[tf.Module],
-                        backend_info: "BackendInfo",
-                        exported_names: Sequence[str] = (),
-                        artifacts_dir: str = None):
-    """Compile a tf.Module subclass to the target backend in backend_info.
-
-    Args:
-      module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling module to IREE.
-      exported_names: Optional sequence representing the exported names to keep.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    set_random_seed()
-    module_instance = module_class()
-    return cls.create_from_instance(module_instance, backend_info,
-                                    exported_names, artifacts_dir)
-
-  @classmethod
-  def create_from_instance(cls,
-                           module_instance: tf.Module,
-                           backend_info: "BackendInfo",
-                           exported_names: Sequence[str] = (),
-                           artifacts_dir: str = None):
-    """Compile a tf.Module instance to the target backend in backend_info.
-
-    Args:
-      module_instance: The tf.Module instance to compile.
-      backend_info: BackendInfo with the details for compiling module to IREE.
-      exported_names: Optional sequence representing the exported names to keep.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    module_blob, compiled_path = _incrementally_compile_tf_module(
-        module=module_instance,
-        backend_info=backend_info,
-        exported_names=exported_names,
-        artifacts_dir=artifacts_dir)
-    vm_module = rt.VmModule.from_flatbuffer(module_blob)
-    config = rt.Config(driver_name=backend_info.driver)
-
-    compiled_paths = None
-    if compiled_path is not None:
-      # IREE bundles every compiled method into the same compiled module.
-      compiled_paths = collections.defaultdict(lambda: compiled_path)
-
-    module_name = type(module_instance).__name__
-
-    return cls(module_name, backend_info, compiled_paths, vm_module, config)
-
-  @classmethod
-  def create_from_signature_def_saved_model(cls,
-                                            saved_model_dir: str,
-                                            saved_model_tags: Set[str],
-                                            module_name: str,
-                                            backend_info: "BackendInfo",
-                                            exported_name: str,
-                                            input_names: Sequence[str],
-                                            output_names: Sequence[str],
-                                            artifacts_dir: str = None):
-    """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
-    Args:
-      saved_model_dir: Directory of the saved model.
-      saved_model_tags: Optional set of tags to use when loading the model.
-      module_name: A name for this compiled module.
-      backend_info: BackendInfo with the details for compiling the saved model.
-      exported_name: A str representing the signature on the saved model to
-        compile.
-      input_names: A sequence of kwargs to feed to the saved model.
-      output_names: A sequence of named outputs to extract from the saved model.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    del input_names  # Unused.
-    del output_names  # Unused.
-    module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
-        saved_model_dir, saved_model_tags, backend_info, exported_name,
-        artifacts_dir)
-    vm_module = rt.VmModule.from_flatbuffer(module_blob)
-    config = rt.Config(driver_name=backend_info.driver)
-
-    compiled_paths = None
-    if compiled_path is not None:
-      # IREE bundles every compiled method into the same compiled module :)
-      compiled_paths = collections.defaultdict(lambda: compiled_path)
-
-    return cls(module_name, backend_info, compiled_paths, vm_module, config)
-
-  def reinitialize(self):
-    """Reinitializes all stateful variables."""
-    # set_random_seed is not needed here because the model_class.__init__ is not
-    # called.
-    self._context = rt.SystemContext(modules=[self._vm_module],
-                                     config=self._config)
-
-  def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
-    # Try to resolve it as a function.
-    m = self._context.modules[self._vm_module.name]
-    f = m[attr]
-    return _IreeFunctionWrapper(self._context, f)
-
-  def iree_serializable(self) -> bool:
-    return self.compiled_paths is not None
-
-
-def _normalize_numpy(result: np.ndarray):
-  """Normalizes TF and TFLite's outputs to match IREE's"""
-  if np.isscalar(result):
-    result = np.array(result)
-  if result.dtype == np.bool:
-    # IREE interprets bools as int8s, so we modify this for comparison.
-    result = result.astype(dtype=np.int8)
-  return result
-
-
-def _convert_to_numpy(tensor: Any) -> Any:
-  if not isinstance(tensor, tf.Tensor):
-    return tensor
-  return _normalize_numpy(tensor.numpy())
-
-
-def convert_to_numpy(values: Any) -> Any:
-  """Converts any tf.Tensor in values to numpy."""
-  return apply_function(values, _convert_to_numpy)
-
-
-class _TfFunctionWrapper(_FunctionWrapper):
-  """Wraps a TF function, normalizing it to numpy."""
-
-  def __init__(self, f: Callable[..., Any]):
-    self._f = f
-
-  def __call__(self, *args, **kwargs):
-    # TensorFlow will auto-convert all inbound args.
-    results = self._f(*args, **kwargs)
-    return convert_to_numpy(results)
-
-
-def _convert_inputs_to_tensors(function):
-
-  def decorator(*args, **kwargs):
-    args = [tf.convert_to_tensor(arg) for arg in args]
-    kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
-    return function(*args, **kwargs)
-
-  return decorator
-
-
-class SignatureDefSavedModelWrapper(object):
-  """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
-
-  def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
-               exported_name: str):
-    self.saved_model = tf.saved_model.load(saved_model_dir,
-                                           tags=saved_model_tags)
-    inference_func = self.saved_model.signatures[exported_name]
-    inference_func = _convert_inputs_to_tensors(inference_func)
-    self.__setattr__(exported_name, inference_func)
-
-
-class TfCompiledModule(CompiledModule):
-  """TensorFlow 'compiled' module.
-
-  This facade exists to provide a complimentary API to IreeCompiledModule and
-  normalize TensorFlow's output to Numpy.
-  """
-
-  def __init__(
-      self,
-      module_name: str,
-      backend_info: "BackendInfo",
-      constructor: Callable[[], tf.Module],
-      exported_names: Sequence[str],
-  ):
-    """Base constructor – Use one of the named constructors instead.
-
-    Args:
-      module_name: A name for this compiled module. In most cases this will be
-        the name of the tf.Module subclass or instance that is compiled.
-      backend_info: BackendInfo with the details about compiling this module.
-      constructor: A callable (class or function) which returns the tf.Module
-        subclass instance to wrap.
-      exported_names: an optional iterable of strings representing which of the
-        tf.Module subclass instance's functions should be callable. If
-        exported_names is empty then all functions will be callable.
-    """
-    super().__init__(module_name, backend_info, compiled_paths=None)
-    self._constructor = constructor
-    self._exported_names = exported_names
-    self.reinitialize()
-
-  @classmethod
-  def create_from_class(cls,
-                        module_class: Type[tf.Module],
-                        backend_info: "BackendInfo",
-                        exported_names: Sequence[str] = (),
-                        artifacts_dir: str = None):
-    """Compile a tf.Module subclass to the target backend in backend_info.
-
-    Args:
-      module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling this module.
-      exported_names: Optional sequence representing the exported names to keep.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    module_name = module_class.__name__
-    constructor = module_class
-    return cls(module_name, backend_info, constructor, exported_names)
-
-  @classmethod
-  def create_from_signature_def_saved_model(cls,
-                                            saved_model_dir: str,
-                                            saved_model_tags: Set[str],
-                                            module_name: str,
-                                            backend_info: "BackendInfo",
-                                            exported_name: str,
-                                            input_names: Sequence[str],
-                                            output_names: Sequence[str],
-                                            artifacts_dir: str = None):
-    """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
-    Args:
-      saved_model_dir: Directory of the saved model.
-      saved_model_tags: Optional set of tags to use when loading the model.
-      module_name: A name for this compiled module.
-      backend_info: BackendInfo with the details for compiling the saved model.
-      exported_name: A str representing the signature on the saved model to
-        compile.
-      input_names: A sequence of kwargs to feed to the saved model.
-      output_names: A sequence of named outputs to extract from the saved model.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    constructor = lambda: SignatureDefSavedModelWrapper(
-        saved_model_dir, saved_model_tags, exported_name)
-    return cls(module_name, backend_info, constructor, [exported_name])
-
-  def reinitialize(self):
-    """Reinitializes all stateful variables."""
-    set_random_seed()
-    self._tf_module = self._constructor()
-
-  def __getattr__(self, attr: str) -> _TfFunctionWrapper:
-    # Try to resolve it as a function.
-    exported = not self._exported_names or attr in self._exported_names
-    if not hasattr(self._tf_module, attr) or not exported:
-      raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
-    f = getattr(self._tf_module, attr)
-    if not f or not hasattr(f, "__call__"):
-      raise AttributeError(
-          f"The TensorFlow module does not have a callable attr '{attr}'")
-    return _TfFunctionWrapper(f)
-
-
-def _get_non_inhereted_function_names(cls):
-  """Gets all methods that cls has that its parents don't have."""
-  names = set(dir(cls))
-  for parent in cls.__bases__:
-    names -= set(dir(parent))
-  return list(names)
-
-
-def _get_concrete_functions(module_class: Type[tf.Module],
-                            exported_names: Sequence[str] = ()):
-  """Get concrete functions from non-inherited methods or exported_names."""
-  if not len(exported_names):
-    # Get all method names on 'module_class' that aren't on 'tf.Module'.
-    exported_names = _get_non_inhereted_function_names(module_class)
-  instance = module_class()
-  functions = []
-  for name in exported_names:
-    functions.append(getattr(instance, name).get_concrete_function())
-  return functions, exported_names, instance
-
-
-def tf_module_to_tflite_module_bytes(
-    module_class: Type[tf.Module], exported_names: Sequence[str] = ()
-) -> Dict[str, bytes]:
-  """Compiles a tf.Module's methods with TFLite.
-
-  Args:
-    module_class: A tf.Module subclass to compile with TFLite.
-    exported_names: an optional iterable of strings representing which of the
-      module_class's functions should be compiled. If exported_names is empty
-      then all functions will be compiled.
-
-  Returns:
-    A dict mapping method names to compiled TFLite module bytes.
-  """
-  tflite_modules = []
-  methods, method_names, instance = _get_concrete_functions(
-      module_class, exported_names)
-  failed_methods = []
-  for method, method_name in zip(methods, method_names):
-    logging.info("Attempting to convert '%s' to tflite...", method_name)
-    try:
-      converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
-      logging.info("...converted '%s' to tflite.", method_name)
-      tflite_modules.append(converter.convert())
-    except Exception as e:
-      logging.error("Failed to convert '%s' to tflite.", method_name)
-      logging.error("TFLite excpetion: %s", e)
-      failed_methods.append(method_name)
-
-  if failed_methods:
-    raise RuntimeError(
-        f"Failed to convert the following methods to tflite: {failed_methods}")
-
-  # Keep variables alive until TFLite has done the conversion; ConcreteFunctions
-  # themselves only keep weak references to variables.
-  del instance
-  return dict(zip(method_names, tflite_modules))
-
-
-def tf_signature_def_saved_model_to_tflite_module_bytes(
-    saved_model_dir: str,
-    saved_model_tags: Set[str],
-    exported_name: str,
-    input_names: Sequence[str],
-    output_names: Sequence[str],
-) -> Dict[str, bytes]:
-  """Compiles a SignatureDef SavedModel signature with TFLite.
-
-  Args:
-    saved_model_dir: Directory of the saved model.
-    saved_model_tags: Optional set of tags to use when loading the model.
-    exported_name: A str representing the signature on the saved model to
-      compile.
-    input_names: A sequence of kwargs to feed to the saved model.
-    output_names: A sequence of named outputs to extract from the saved model.
-
-  Returns:
-    A dict mapping the signature name to the compiled TFLite module bytes.
-  """
-  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
-      saved_model_dir,
-      tag_set=saved_model_tags,
-      signature_key=exported_name,
-      input_arrays=input_names,
-      output_arrays=output_names)
-  tflite_module = converter.convert()
-  return dict([[exported_name, tflite_module]])
-
-
-def tflite_module_bytes_to_tflite_interpreters(
-    tflite_module_bytes: Dict[str, bytes],
-    artifacts_dir: str = None
-) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
-  """Compile a dict of TFLite compiled bytes to  TFLite interpreters.
-
-  Args:
-    tflite_module_bytes: A dict mapping method names to compiled TFLite byte
-      strings.
-    artifacts_dir: an optional path to save compilation artifacts to.
-
-  Returns:
-    A dictionary mapping method names to TFLite interpreters and a dictionary
-    mapping method names to compiled tflite graph paths (or None if
-    artifacts_dir is None).
-  """
-  interpreters = dict()
-  compiled_paths = None
-  if artifacts_dir is not None:
-    compiled_paths = dict()
-
-  def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
-    """Save compiled TFLite module bytes and convert into an interpreter."""
-    tflite_dir = os.path.join(base_dir, "tflite")
-    os.makedirs(tflite_dir, exist_ok=True)
-    tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
-    with open(tflite_path, "wb") as f:
-      f.write(tflite_module)
-
-    interpreters[method_name] = tf.lite.Interpreter(tflite_path)
-    if artifacts_dir is not None:
-      compiled_paths[method_name] = tflite_path
-
-  # Load each of the converted methods above into tf.lite.Interpreters.
-  for method_name, tflite_module in tflite_module_bytes.items():
-    if artifacts_dir is None:
-      with tempfile.TemporaryDirectory() as base_dir:
-        _interpret_bytes(method_name, tflite_module, base_dir)
-    else:
-      _interpret_bytes(method_name, tflite_module, artifacts_dir)
-
-  return interpreters, compiled_paths
-
-
-class _TfLiteFunctionWrapper(_FunctionWrapper):
-  """Wraps a TFLite interpreter and makes it behave like a python function."""
-
-  def __init__(self, interpreter: tf.lite.Interpreter,
-               output_names: Sequence[str]):
-    self._interpreter = interpreter
-    self._output_names = output_names
-
-  def __call__(self, *args,
-               **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
-    if len(args) and len(kwargs):
-      raise ValueError("Passing both args and kwargs is not supported by "
-                       "_TfLiteFunctionWrapper")
-
-    # Set up and run the function.
-    self._interpreter.allocate_tensors()
-
-    if len(args):
-      # Specifically to get TFLite to work with keras models that take a list of
-      # inputs instead of a sequence of args as their inputs, because it decides
-      # to change the input signature but it still technically works if you
-      # ignore that it does that.
-      if len(args) == 1 and isinstance(args[0], list):
-        args = args[0]
-
-      for arg, detail in zip(args, self._interpreter.get_input_details()):
-        self._interpreter.set_tensor(detail["index"], arg)
-    else:
-      for detail in self._interpreter.get_input_details():
-        self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
-
-    self._interpreter.invoke()
-
-    # Extract the outputs from the TFLite interpreter.
-    outputs = []
-    for detail in self._interpreter.get_output_details():
-      value = _normalize_numpy(self._interpreter.get_tensor(detail["index"]))
-      if self._output_names is not None:
-        name = detail["name"]
-        if name not in self._output_names:
-          raise ValueError(f"Expected '{name}' to be in {self._output_names}")
-        outputs.append([detail["name"], value])
+    if np.issubdtype(ref.dtype, np.floating):
+      same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True)
+      abs_diff = np.max(np.abs(ref - tar))
+      rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
+      diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
+                     f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
+      if not same:
+        error = ("Floating point difference between ref and tar was too "
+                 f"large. {diff_string}")
+        logging.error(error)
       else:
-        outputs.append(value)
-
-    # Process them to match the output of the tf.Module.
-    if self._output_names is not None:
-      return dict(outputs)
+        error = None
+        logging.info(
+            "Floating point difference between ref and tar was within "
+            "tolerance. %s", diff_string)
+      return same, error
+    elif np.issubdtype(ref.dtype, np.integer):
+      same = np.array_equal(ref, tar)
+      if not same:
+        abs_diff = np.max(np.abs(ref - tar))
+        error = ("Expected array equality between ref and tar, but got "
+                 f"a max elementwise difference of {abs_diff}")
+        logging.error(error)
+      else:
+        error = None
+      return same, error
     else:
-      if len(outputs) == 1:
-        return outputs[0]
-      return tuple(outputs)
+      return np.array_equal(ref, tar), None
 
+  # Base check for native number types.
+  elif isinstance(ref, (int, float)):
+    return ref == tar, None
 
-class TfLiteCompiledModule(CompiledModule):
-  """Compiles a tf.Module with TFLite and allows it to be called."""
-
-  def __init__(
-      self,
-      module_name: str,
-      backend_info: "BackendInfo",
-      compiled_paths: Dict[str, str],
-      interpreters: Dict[str, tf.lite.Interpreter],
-      output_names: Sequence[str] = None,
-  ):
-    """Base constructor – Use one of the named constructors instead.
-
-    Args:
-      module_name: A name for this compiled module. In most cases this will be
-        the name of the tf.Module subclass or instance that is compiled.
-      backend_info: BackendInfo with the details about compiling this module.
-      compiled_paths: A dictionary mapping compiled method names to file paths
-        corresponding to their serialized representations.
-      interpreters: A dict of tf.lite.Interpreters to make callable.
-    """
-    super().__init__(module_name, backend_info, compiled_paths)
-    self._interpreters = interpreters
-    self._output_names = output_names
-
-  @classmethod
-  def create_from_class(cls,
-                        module_class: Type[tf.Module],
-                        backend_info: "BackendInfo",
-                        exported_names: Sequence[str] = (),
-                        artifacts_dir: str = None):
-    """Compile a tf.Module subclass to the target backend in backend_info.
-
-    Args:
-      module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling this module.
-      exported_names: Optional sequence representing the exported names to keep.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    set_random_seed()
-    tflite_module_bytes = tf_module_to_tflite_module_bytes(
-        module_class, exported_names)
-    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
-        tflite_module_bytes, artifacts_dir)
-    module_name = module_class.__name__
-    return cls(module_name, backend_info, compiled_paths, interpreters)
-
-  @classmethod
-  def create_from_signature_def_saved_model(cls,
-                                            saved_model_dir: str,
-                                            saved_model_tags: Set[str],
-                                            module_name: str,
-                                            backend_info: "BackendInfo",
-                                            exported_name: str,
-                                            input_names: Sequence[str],
-                                            output_names: Sequence[str],
-                                            artifacts_dir: str = None):
-    """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
-    Args:
-      saved_model_dir: Directory of the saved model.
-      saved_model_tags: Optional set of tags to use when loading the model.
-      module_name: A name for this compiled module.
-      backend_info: BackendInfo with the details for compiling the saved model.
-      exported_name: A str representing the signature on the saved model to
-        compile.
-      input_names: A sequence of kwargs to feed to the saved model.
-      output_names: A sequence of named outputs to extract from the saved model.
-      artifacts_dir: An optional string pointing to where compilation artifacts
-        should be saved. No compilation artifacts will be saved if this is not
-        provided.
-    """
-    tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
-        saved_model_dir, saved_model_tags, exported_name, input_names,
-        output_names)
-    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
-        tflite_module_bytes, artifacts_dir)
-    return cls(module_name, backend_info, compiled_paths, interpreters,
-               output_names)
-
-  def reinitialize(self):
-    """Reinitializes all stateful variables."""
-    # This is a noop because TFLite (mostly) doesn't support stateful modules.
-    pass
-
-  def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper:
-    # Try to resolve it as an interpreter.
-    if not attr in self._interpreters:
-      raise AttributeError(
-          f"The TFLite module does not have an interpreter for '{attr}'")
-    return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names)
-
-  def tflite_serializable(self) -> bool:
-    return self.compiled_paths is not None
-
-
-class BackendInfo:
-  """Contains information for compiling the specified backend."""
-
-  _name_to_info = {
-      "tf": {
-          "compiled_module_class": TfCompiledModule,
-          "driver": None,
-          "compiler_targets": None,
-      },
-      "tflite": {
-          "compiled_module_class": TfLiteCompiledModule,
-          "driver": None,
-          "compiler_targets": None,
-      },
-      "iree_vmla": {
-          "compiled_module_class": IreeCompiledModule,
-          "driver": "vmla",
-          "compiler_targets": ["vmla"]
-      },
-      "iree_vulkan": {
-          "compiled_module_class": IreeCompiledModule,
-          "driver": "vulkan",
-          "compiler_targets": ["vulkan-*"]
-      },
-  }
-
-  def __init__(self, backend_name: str, backend_id: str = None):
-    """Creates a BackendInfo with the compilation details for backend_name.
-
-    Args:
-      backend_name: a str specifying which backend to use. Should be one of
-        'tf', 'iree_vmla', 'iree_vulkan'.
-      backend_id: an optional str specifying what name to use when saving
-        compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
-
-    Raises:
-      KeyError: if backend_name is not one of
-        ['tf', 'iree_vmla', 'iree_vulkan'].
-      ValueError: if backend_id doesn't start with backend_name.
-    """
-    if backend_name not in self._name_to_info:
-      raise KeyError(
-          "Expected backend_name to be one of "
-          f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
-    if backend_id is not None and not backend_id.startswith(backend_name):
-      raise ValueError(f"Expected backend_id to start with '{backend_name}' "
-                       f"but got '{backend_id}'.")
-
-    self.backend_name = backend_name
-    self.backend_id = backend_name if backend_id is None else backend_id
-
-    info = self._name_to_info[backend_name]
-    self._compiled_module_class = info["compiled_module_class"]
-    self.driver = info["driver"]
-    self.compiler_targets = info["compiler_targets"]
-
-  def compile_from_class(self,
-                         module_class: Type[tf.Module],
-                         exported_names: Sequence[str] = (),
-                         artifacts_dir: str = None) -> CompiledModule:
-    """Creates a 'CompiledModule' for this backend."""
-    return self._compiled_module_class.create_from_class(
-        module_class, self, exported_names, artifacts_dir)
-
-  def compile_signature_def_saved_model(
-      self,
-      saved_model_dir: str,
-      saved_model_tags: Set[str],
-      module_name: str,
-      exported_name: str,
-      input_names: Sequence[str],
-      output_names: Sequence[str],
-      artifacts_dir: str = None) -> CompiledModule:
-    return self._compiled_module_class.create_from_signature_def_saved_model(
-        saved_model_dir, saved_model_tags, module_name, self, exported_name,
-        input_names, output_names, artifacts_dir)
-
-  @classmethod
-  def get_all_backends(cls) -> Sequence["BackendInfo"]:
-    """Returns a list of all BackendInfo configurations."""
-    return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
+  # If outputs end up here then an extra branch for that type should be added.
+  else:
+    raise TypeError(f"Encountered results with unexpected type {type(ref)}")
+  return True, None
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 9934460..deef8d9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -14,125 +14,87 @@
 # limitations under the License.
 """Tests for pyiree.tf.support.tf_utils."""
 
-import os
-import tempfile
-
-from absl import logging
 from absl.testing import parameterized
 import numpy as np
 from pyiree.tf.support import tf_utils
 import tensorflow as tf
 
 
-class ConstantModule(tf.Module):
-
-  @tf.function(input_signature=[])
-  def meaning(self):
-    return tf.constant([42.])
-
-
-class StatefulCountingModule(tf.Module):
-
-  def __init__(self):
-    self.count = tf.Variable([0.])
-
-  @tf.function(input_signature=[])
-  def increment(self):
-    self.count.assign_add(tf.constant([1.]))
-
-  @tf.function(input_signature=[])
-  def get_count(self):
-    return self.count
-
-
-class RandomInitModule(tf.Module):
-
-  def __init__(self):
-    self.value = tf.Variable(tf.random.uniform([1]))
-
-  @tf.function(input_signature=[])
-  def get(self):
-    return self.value
-
-
 class UtilsTests(tf.test.TestCase, parameterized.TestCase):
 
-  def test_artifact_saving(self):
-    backend_info = tf_utils.BackendInfo('iree_vmla')
-    with tempfile.TemporaryDirectory() as artifacts_dir:
-      tf_module = ConstantModule()
-      iree_compiled_module, compiled_path = (
-          tf_utils._incrementally_compile_tf_module(
-              tf_module, backend_info=backend_info,
-              artifacts_dir=artifacts_dir))
+  @parameterized.named_parameters([('int8_to_i8', np.int8, 'i8'),
+                                   ('int32_to_i32', np.int32, 'i32'),
+                                   ('float32_to_f32', np.float32, 'f32'),
+                                   ('float64_to_f64', np.float64, 'f64')])
+  def test_to_mlir_type(self, numpy_type, mlir_type):
+    self.assertEqual(tf_utils.to_mlir_type(numpy_type), mlir_type)
 
-      artifacts_to_check = [
-          'tf_input.mlir',
-          'iree_input.mlir',
-          compiled_path,
-      ]
-      for artifact in artifacts_to_check:
-        artifact_path = os.path.join(artifacts_dir, artifact)
-        logging.info('Checking path: %s', artifact_path)
-        self.assertTrue(os.path.exists(artifact_path))
+  @parameterized.named_parameters([
+      ('single_i32', [np.array([1, 2], dtype=np.int32)], '2xi32=1 2'),
+      ('single_f32', [np.array([1, 2], dtype=np.float32)], '2xf32=1.0 2.0'),
+  ])
+  def test_save_input_values(self, inputs, inputs_str):
+    self.assertEqual(tf_utils.save_input_values(inputs), inputs_str)
+
+  def test_apply_function(self):
+    inputs = [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}]
+    expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}]
+    result = tf_utils.apply_function(inputs, lambda x: x - 1)
+    self.assertEqual(result, expected)
+    self.assertNotEqual(inputs, expected)
 
   @parameterized.named_parameters([
       {
-          'testcase_name': 'tensorflow',
-          'backend_name': 'tf',
+          'testcase_name': 'all the same',
+          'array_c': np.array([0, 1, 2]),
+          'array_d': np.array(['0', '1', '2']),
+          'array_e': np.array([0.0, 0.1, 0.2]),
+          'tar_same': True,
       },
       {
-          'testcase_name': 'vmla',
-          'backend_name': 'iree_vmla',
+          'testcase_name': 'wrong int',
+          'array_c': np.array([1, 1, 2]),
+          'array_d': np.array(['0', '1', '2']),
+          'array_e': np.array([0.0, 0.1, 0.2]),
+          'tar_same': False,
+      },
+      {
+          'testcase_name': 'wrong string',
+          'array_c': np.array([0, 1, 2]),
+          'array_d': np.array(['a', '1', '2']),
+          'array_e': np.array([0.0, 0.1, 0.2]),
+          'tar_same': False,
+      },
+      {
+          'testcase_name': 'wrong float',
+          'array_c': np.array([0, 1, 2]),
+          'array_d': np.array(['0', '1', '2']),
+          'array_e': np.array([1.0, 0.1, 0.2]),
+          'tar_same': False,
       },
   ])
-  def test_unaltered_state(self, backend_name):
-    backend_info = tf_utils.BackendInfo(backend_name)
-    module = backend_info.compile_from_class(StatefulCountingModule)
+  def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
 
-    # Test that incrementing works properly.
-    self.assertEqual([0.], module.get_count())
-    module.increment()
-    self.assertEqual([1.], module.get_count())
-
-    module.reinitialize()
-    # Test reinitialization.
-    self.assertEqual([0.], module.get_count())
-
-  def test_to_mlir_type(self):
-    self.assertEqual('i8', tf_utils.to_mlir_type(np.dtype('int8')))
-    self.assertEqual('i32', tf_utils.to_mlir_type(np.dtype('int32')))
-    self.assertEqual('f32', tf_utils.to_mlir_type(np.dtype('float32')))
-    self.assertEqual('f64', tf_utils.to_mlir_type(np.dtype('float64')))
-
-  def test_save_input_values(self):
-    inputs = [np.array([1, 2], dtype=np.int32)]
-    self.assertEqual('2xi32=1 2', tf_utils.save_input_values(inputs))
-    inputs = [np.array([1, 2], dtype=np.float32)]
-    self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
-
-  @parameterized.named_parameters([
-      {
-          'testcase_name': 'tensorflow',
-          'backend_name': 'tf',
-      },
-      {
-          'testcase_name': 'vmla',
-          'backend_name': 'iree_vmla',
-      },
-  ])
-  def test_random_initialization(self, backend_name):
-    backend_info = tf_utils.BackendInfo(backend_name)
-
-    # Test compilation is the same.
-    module_1 = backend_info.compile_from_class(RandomInitModule)
-    module_2 = backend_info.compile_from_class(RandomInitModule)
-    self.assertAllEqual(module_1.get(), module_2.get())
-
-    # Test reinitialization is the same.
-    old_value = module_1.get()
-    module_1.reinitialize()
-    self.assertAllEqual(old_value, module_1.get())
+    # yapf: disable
+    ref = {
+        'a': 1,
+        'b': [
+            {'c': np.array([0, 1, 2])},
+            {'d': np.array(['0', '1', '2'])},
+            {'e': np.array([0.0, 0.1, 0.2])}
+        ],
+    }
+    tar = {
+        'a': 1,
+        'b': [
+            {'c': array_c},
+            {'d': array_d},
+            {'e': array_e}
+        ],
+    }
+    # yapf: enable
+    same, _ = tf_utils.check_same(ref, tar, rtol=1e-6, atol=1e-6)
+    self.assertEqual(tar_same, same)
 
 
 if __name__ == '__main__':
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py
new file mode 100644
index 0000000..1a0c789
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py
@@ -0,0 +1,421 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for tracing tf.function inputs and outputs."""
+
+# This file uses the following abbreviations:
+#   ref: reference – for the reference CompiledModule
+#   tar: target - for one of the target CompiledModules
+
+import copy
+import glob
+import inspect
+import os
+import pickle
+import sys
+import textwrap
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
+
+from absl import logging
+import numpy as np
+from pyiree.tf.support import module_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+NUMPY_LINEWIDTH = 120
+INDENT = " " * 2
+
+
+def _zfill_width(length: int) -> Union[int, None]:
+  return int(np.ceil(np.log10(length))) if length else None
+
+
+def get_trace_dir(artifacts_dir: str, trace: "Trace") -> str:
+  trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
+                           trace.function_name)
+  os.makedirs(trace_dir, exist_ok=True)
+  return trace_dir
+
+
+class ModuleCall:
+
+  def __init__(self,
+               method: str,
+               inputs: Tuple[Any],
+               outputs: Tuple[Any],
+               serialized_inputs: Tuple[str],
+               serialized_outputs: Tuple[str],
+               rtol: float = 1e-6,
+               atol: float = 1e-6):
+    """Records the details of a call to a CompiledModule."""
+    self.method = method
+
+    # Deepcopy to safegard against mutation.
+    self.inputs = copy.deepcopy(inputs)
+    if outputs is not None:
+      outputs = copy.deepcopy(outputs)
+    else:
+      outputs = tuple()
+    self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
+
+    self.serialized_inputs = serialized_inputs
+    self.serialized_outputs = serialized_outputs
+
+    self.rtol = rtol
+    self.atol = atol
+
+  def get_tolerances(self) -> Tuple[float, float]:
+    """Gets the floating point tolerances associated with this call."""
+    return self.rtol, self.atol
+
+  def _get_shape_and_dtype(self, value: Any) -> str:
+    if isinstance(value, np.ndarray):
+      return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True)
+    else:
+      return str(type(value))
+
+  def __str__(self):
+    prior_printoptions = np.get_printoptions()
+    np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
+
+    header = f"Method: {self.method}"
+    inputs = "\n".join(
+        [textwrap.indent(str(value), INDENT) for value in self.inputs])
+    input_shapes = ", ".join(
+        self._get_shape_and_dtype(value) for value in self.inputs)
+
+    outputs = "\n".join(
+        [textwrap.indent(str(value), INDENT) for value in self.outputs])
+    output_shapes = ", ".join(
+        self._get_shape_and_dtype(value) for value in self.outputs)
+
+    tolerances = textwrap.indent(f"rtol={self.rtol}, atol={self.atol}", INDENT)
+    body = (f"Inputs: {input_shapes}\n{inputs}\n"
+            f"Outputs: {output_shapes}\n{outputs}"
+            f"\nTolerances:\n{tolerances}")
+    result = f"{header}\n{textwrap.indent(body, INDENT)}"
+
+    np.set_printoptions(**prior_printoptions)
+    return result
+
+  def serialize(self, call_dir: str) -> None:
+    """Stores a serialized copy of this call.
+
+    Can be loaded via ModuleCall.load(call_dir)
+
+    Args:
+      call_dir: str, the path to the directory to serialize this call to.
+    """
+    os.makedirs(call_dir, exist_ok=True)
+
+    metadata = {
+        "method": self.method,
+        "serialized_inputs": self.serialized_inputs,
+        "serialized_outputs": self.serialized_outputs,
+        "rtol": self.rtol,
+        "atol": self.atol
+    }
+    with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
+      pickle.dump(metadata, f)
+
+    width = _zfill_width(len(self.inputs))
+    for i, value in enumerate(self.inputs):
+      path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl")
+      with open(path, "wb") as f:
+        pickle.dump(value, f)
+
+    width = _zfill_width(len(self.outputs))
+    for i, value in enumerate(self.outputs):
+      path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl")
+      with open(path, "wb") as f:
+        pickle.dump(value, f)
+
+  @staticmethod
+  def load(call_dir: str) -> "ModuleCall":
+    """Loads and returns a trace serialized with ModuleCall.serialize."""
+    with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f:
+      kwargs = pickle.load(f)
+
+    for result_type in ["input", "output"]:
+      key = f"{result_type}s"  # inputs or outputs
+      kwargs[key] = []
+
+      files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl"))
+      for filename in sorted(files):
+        with open(filename, "rb") as f:
+          kwargs[key].append(pickle.load(f))
+
+      # Convert to tuple to match python's return type for multiple results.
+      kwargs[key] = tuple(kwargs[key])
+
+    return ModuleCall(**kwargs)
+
+
+class Trace:
+  """Stores the inputs and outputs of a series of calls to a module."""
+
+  def __init__(self,
+               module: Union[module_utils.CompiledModule, None],
+               function: Union[Callable[["TracedModule"], None], None],
+               _load_dict: Dict[str, Any] = None):
+    """Extracts metadata from module and function and initializes.
+
+    Example usage:
+      def forward_pass(...):
+        ...
+      module = IreeCompiledModule(...)
+      trace = Trace(module, forward_pass)
+      forward_pass(TracedModule(module, trace))
+
+    Args:
+      module: the module who's outputs this trace will record.
+      function: the function that module will be traced on.
+      _load_dict: used internally
+    """
+    if _load_dict is None:
+      # Extract metadata from module and function.
+      self.module_name = module.module_name
+      self.compiled_paths = module.compiled_paths
+      self.backend_name = module.backend_info.backend_name
+      self.backend_id = module.backend_info.backend_id
+      self.backend_driver = module.backend_info.driver
+      self.iree_serializable = module.iree_serializable()
+      self.tflite_serializable = module.tflite_serializable()
+      self.function_name = function.__name__
+      self.function_sourcefile = inspect.getsourcefile(function)
+      source, start_line = inspect.getsourcelines(function)
+      self.function_line_numbers = (start_line, start_line + len(source))
+      self.function_source = "".join(source)
+
+      self.calls = []
+    else:
+      self.module_name = _load_dict["module_name"]
+      self.compiled_paths = _load_dict["compiled_paths"]
+      self.backend_name = _load_dict["backend_name"]
+      self.backend_id = _load_dict["backend_id"]
+      self.backend_driver = _load_dict["backend_driver"]
+      self.iree_serializable = _load_dict["iree_serializable"]
+      self.tflite_serializable = _load_dict["tflite_serializable"]
+      self.function_name = _load_dict["function_name"]
+      self.function_sourcefile = _load_dict["function_sourcefile"]
+      self.function_line_numbers = _load_dict["function_line_numbers"]
+      self.function_source = _load_dict["function_source"]
+      self.calls = _load_dict["calls"]
+
+  def __str__(self):
+    header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
+              f"on function '{self.function_name}':")
+    # Give each call a number so it's easier to compare between multiple traces.
+    calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
+    calls = textwrap.indent("\n".join(calls), prefix=INDENT)
+    return f"{header}\n{calls}"
+
+  def __iter__(self):
+    for call in self.calls:
+      yield call
+
+  def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
+    """Saves a human-readable string representation of this trace to disk.
+
+    Args:
+      trace_dir: str, path to the directory to save the trace in.
+      summarize: a bool controlling whether numpy should summarize the inputs
+        and outputs if they're large. Setting this to False is very slow for
+        large outputs.
+    """
+    prior_printoptions = np.get_printoptions()
+    np.set_printoptions(
+        linewidth=NUMPY_LINEWIDTH,
+        threshold=None if summarize else sys.maxsize,
+        edgeitems=10)  # Can show more items since they won't clutter the logs.
+
+    path = os.path.join(trace_dir, "log.txt")
+    with open(path, "w") as f:
+      f.write(str(self))
+      f.write("\n")
+
+    np.set_printoptions(**prior_printoptions)
+
+  def serialize(self, trace_dir: str) -> None:
+    """Stores a serialized copy of this trace in trace_dir.
+
+    It can be loaded via `Trace.load(trace_dir)`.
+
+    Args:
+      trace_dir: str, path to the directory to serialize the trace to.
+    """
+
+    compiled_paths = None
+    if self.compiled_paths is not None:
+      # Convert to a dict to avoid the issues with serializing defaultdicts.
+      compiled_paths = dict(self.compiled_paths)
+
+    # Python serialization.
+    metadata = {
+        "module_name": self.module_name,
+        "compiled_paths": compiled_paths,
+        "backend_name": self.backend_name,
+        "backend_id": self.backend_id,
+        "backend_driver": self.backend_driver,
+        "iree_serializable": self.iree_serializable,
+        "tflite_serializable": self.tflite_serializable,
+        "function_name": self.function_name,
+        "function_sourcefile": self.function_sourcefile,
+        "function_line_numbers": self.function_line_numbers,
+        "function_source": self.function_source
+    }
+    with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f:
+      pickle.dump(metadata, f)
+
+    width = _zfill_width(len(self.calls))
+    for i, call in enumerate(self.calls):
+      call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
+      call.serialize(call_dir)
+
+    # C++ benchmark serialization.
+    if self.iree_serializable or self.tflite_serializable:
+      entry_function = self.calls[0].method
+      compiled_path = self.compiled_paths[entry_function]
+
+      if self.iree_serializable:
+        serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
+        flagfile = [
+            f"--module_file={compiled_path}",
+            f"--driver={self.backend_driver}",
+            f"--function_inputs={serialized_inputs}",
+            f"--entry_function={entry_function}",
+        ]
+        with open(os.path.join(trace_dir, "flagfile"), "w") as f:
+          f.writelines(line + "\n" for line in flagfile)
+      else:
+        with open(os.path.join(trace_dir, "graph_path"), "w") as f:
+          f.writelines(compiled_path + "\n")
+
+  @staticmethod
+  def load(trace_dir: str) -> "Trace":
+    """Loads and returns a trace serialized with Trace.serialize.
+
+    Args:
+      trace_dir: str, path to the directory of the serialized trace.
+
+    Returns:
+      A Trace deserialized from trace_dir.
+    """
+    with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f:
+      load_dict = pickle.load(f)
+    call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*")))
+    calls = [ModuleCall.load(call_dir) for call_dir in call_dirs]
+    load_dict["calls"] = calls
+    return Trace(module=None, function=None, _load_dict=load_dict)
+
+
+class TracedModule:
+
+  def __init__(self, module: module_utils.CompiledModule, trace: Trace):
+    """Wraps a CompiledModule so that all inputs and outputs are traced.
+
+    The TracedModule returned will have an API almost identical to that of the
+    passed CompiledModule. The only changes is that if the keywords `rtol` or
+    `atol` are passed to one of the CompiledModule's methods, then they will be
+    used to set the tolerance for comparing that call to the same call in
+    another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)`
+    would be the same as calling `module.add(a, b)`.
+
+    Args:
+      module: the CompiledModule to trace.
+      trace: the Trace to record calls to this module with.
+    """
+    self._module = module
+    self._trace = trace
+
+  def _trace_call(self, method: module_utils._FunctionWrapper,
+                  method_name: str):
+    """Decorates a CompiledModule method to capture its inputs and outputs."""
+
+    def call(*args, **kwargs):
+      # Pop manually specified tolerances from the kwargs (if any).
+      tolerances = {}
+      tolerances["rtol"] = kwargs.pop("rtol", None)
+      tolerances["atol"] = kwargs.pop("atol", None)
+      # Only pass these to ModuleCall if they were specified by the user.
+      tolerances = {k: v for k, v in tolerances.items() if v is not None}
+
+      # Ensure the inputs are numpy inputs.
+      args = tf_utils.convert_to_numpy(args)
+      kwargs = tf_utils.convert_to_numpy(kwargs)
+
+      # Run the method and record the details of the call.
+      outputs = method(*args, **kwargs)
+      serialized_inputs, serialized_outputs = method.get_serialized_values()
+      self._trace.calls.append(
+          ModuleCall(method_name, args, outputs, serialized_inputs,
+                     serialized_outputs, **tolerances))
+      return outputs
+
+    return call
+
+  def __getattr__(self, attr):
+    # Try to resolve it as an attr on self._module.
+    if not hasattr(self._module, attr):
+      raise AttributeError(f"The compiled module does not have attr '{attr}'")
+    module_attr = getattr(self._module, attr)
+    if not hasattr(module_attr, "__call__"):
+      # e.g. traced_module.backend
+      return module_attr
+    else:
+      # e.g. traced_module.simple_mul(a, b)
+      return self._trace_call(module_attr, method_name=attr)
+
+
+def compare_traces(ref_trace: Trace,
+                   tar_trace: Trace) -> Tuple[bool, Sequence[str]]:
+  traces_match = True
+  error_messages = []
+
+  # Check that all method invocations match.
+  ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
+  tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
+  if ref_methods != tar_methods:
+    # Raise a ValueError instead of returning False since this is an
+    # unexpected error.
+    raise ValueError(
+        "The reference and target traces have different call structures:\n"
+        f"Reference: {ref_methods}\nTarget:    {tar_methods}")
+
+  for ref_call, tar_call in zip(ref_trace, tar_trace):
+    logging.info("Comparing calls to '%s'", ref_call.method)
+    rtol, atol = ref_call.get_tolerances()
+
+    inputs_match, error_message = tf_utils.check_same(ref_call.inputs,
+                                                      tar_call.inputs, rtol,
+                                                      atol)
+    if not inputs_match:
+      error_messages.append(error_message)
+      logging.error("Inputs did not match.")
+    outputs_match, error_message = tf_utils.check_same(ref_call.outputs,
+                                                       tar_call.outputs, rtol,
+                                                       atol)
+    if not outputs_match:
+      error_messages.append(error_message)
+      logging.error("Outputs did not match.")
+    calls_match = inputs_match and outputs_match
+
+    if not calls_match:
+      logging.error("Comparision between '%s' and '%s' failed on method '%s'",
+                    ref_trace.backend_id, tar_trace.backend_id, ref_call.method)
+      logging.error("Reference call '%s':\n%s", ref_trace.backend_id, ref_call)
+      logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
+
+    traces_match = traces_match and calls_match
+  return traces_match, error_messages
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py
new file mode 100644
index 0000000..58315c8
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py
@@ -0,0 +1,156 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.trace_utils."""
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+import numpy as np
+from pyiree.tf.support import module_utils
+from pyiree.tf.support import trace_utils
+import tensorflow as tf
+
+
+class StatefulCountingModule(tf.Module):
+
+  def __init__(self):
+    self.count = tf.Variable([0.])
+
+  @tf.function(input_signature=[])
+  def increment(self):
+    self.count.assign_add(tf.constant([1.]))
+
+  @tf.function(input_signature=[])
+  def get_count(self):
+    return self.count
+
+  @tf.function(input_signature=[tf.TensorSpec([1])])
+  def increment_by(self, value):
+    self.count.assign_add(value)
+
+  @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])])
+  def increment_by_max(self, a, b):
+    result = tf.maximum(a, b)
+    self.count.assign_add(result)
+    return result
+
+  @tf.function(input_signature=[])
+  def decrement(self):
+    self.count.assign_sub(tf.constant([1.]))
+
+
+class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+  def test_trace_inputs_and_outputs(self):
+
+    def trace_function(module):
+      # No inputs or outputs
+      module.increment()
+      # Only inputs
+      module.increment_by(np.array([81.], dtype=np.float32))
+      # Only outputs
+      module.get_count()
+
+    module = module_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, module_utils.BackendInfo('tf'))
+    trace = trace_utils.Trace(module, trace_function)
+    trace_function(trace_utils.TracedModule(module, trace))
+
+    self.assertIsInstance(trace.calls[0].inputs, tuple)
+    self.assertEmpty(trace.calls[0].inputs)
+    self.assertIsInstance(trace.calls[0].outputs, tuple)
+    self.assertEmpty(trace.calls[0].outputs)
+
+    self.assertAllClose(trace.calls[1].inputs[0], [81.])
+    self.assertAllClose(trace.calls[2].outputs[0], [82.])
+
+  def test_nonmatching_methods(self):
+
+    def tf_function(module):
+      module.increment()
+      module.increment()
+
+    def vmla_function(module):
+      module.increment()
+      module.decrement()
+
+    tf_module = module_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, module_utils.BackendInfo('tf'))
+    tf_trace = trace_utils.Trace(tf_module, tf_function)
+    tf_function(trace_utils.TracedModule(tf_module, tf_trace))
+
+    vmla_module = module_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+    vmla_trace = trace_utils.Trace(vmla_module, vmla_function)
+    vmla_function(trace_utils.TracedModule(vmla_module, vmla_trace))
+
+    with self.assertRaises(ValueError):
+      trace_utils.compare_traces(tf_trace, vmla_trace)
+
+  def test_nonmatching_inputs(self):
+
+    def tf_function(module):
+      module.increment_by(np.array([42.], dtype=np.float32))
+
+    def vmla_function(module):
+      module.increment_by(np.array([22.], dtype=np.float32))
+
+    tf_module = module_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, module_utils.BackendInfo('tf'))
+    tf_trace = trace_utils.Trace(tf_module, tf_function)
+    tf_function(trace_utils.TracedModule(tf_module, tf_trace))
+
+    vmla_module = module_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+    vmla_trace = trace_utils.Trace(vmla_module, vmla_function)
+    vmla_function(trace_utils.TracedModule(vmla_module, vmla_trace))
+
+    same, error_messages = trace_utils.compare_traces(tf_trace, vmla_trace)
+    self.assertFalse(same)
+
+  def test_trace_serialize_and_load(self):
+
+    def trace_function(module):
+      module.increment()
+      module.increment_by(np.array([81.], dtype=np.float32))
+      module.increment_by_max(np.array([81], dtype=np.float32),
+                              np.array([92], dtype=np.float32))
+      module.get_count()
+
+    module = module_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+    trace = trace_utils.Trace(module, trace_function)
+    trace_function(trace_utils.TracedModule(module, trace))
+
+    with tempfile.TemporaryDirectory() as artifacts_dir:
+      trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace)
+      trace.serialize(trace_function_dir)
+      self.assertTrue(
+          os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
+      loaded_trace = trace_utils.Trace.load(trace_function_dir)
+
+      # Check all calls match.
+      self.assertTrue(trace_utils.compare_traces(trace, loaded_trace))
+
+      # Check all other metadata match.
+      self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys())
+      for key in trace.__dict__.keys():
+        if key != 'calls':
+          self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index e806496..ccc2f63 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -32,16 +32,16 @@
 specified directory. These artifacts include MLIR across various lowerings and
 the compiled VM FlatBuffer. A basic example of creating and calling an
 `IreeCompiledModule` can be found in
-[`tf_utils_test.py`](https://github.com/google/iree/blob/main/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py)
+[`module_utils_test.py`](https://github.com/google/iree/blob/main/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py)
 
 When using Keras models or tf.Modules with functions that IREE can't compile,
 `exported_names` should be specified. For example:
 
 ```python
-from pyiree.tf.support import tf_utils
-vmla_module = tf_utils.IreeCompiledModule(
+from pyiree.tf.support import module_utils
+vmla_module = module_utils.IreeCompiledModule(
     module_class=KerasTFModuleClass,
-    backend_info=tf_utils.BackendInfo('iree_vmla'),
+    backend_info=module_utils.BackendInfo('iree_vmla'),
     exported_names=['predict'])
 vmla_module.predict(...)
 ```