Add serialization for benchmarking e2e test modules with IREE (#2895)

Abseil flagfiles containing all of the data that `iree-benchmark-module` needs to run are generated for each `Trace` in our E2E tests. This allows for any module we test to be easily benchmarked on valid inputs. The process for benchmarking a vision model can thus be reduced to the following:

```shell
# Generate benchmarking artifacts for all vision models:
bazel test integrations/tensorflow/e2e/keras:vision_external_tests

# Benchmark ResNet50 with cifar10 weights on vmla:
bazel run iree/tools:iree-benchmark-module -- \
  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile 

# Benchmark ResNet50 with cifar10 weights on llvmjit:
bazel run iree/tools:iree-benchmark-module -- \
  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile 
```

Duplicate flags provided after the flagfile will take precedence. For example:

```shell
bazel run iree/tools:iree-benchmark-module -- \
  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile  \
  --input_file=/path/to/custom/compiled.vmfb
```

Currently, this only supports benchmarking the first module call in a trace. We plan to extend this to support benchmarking all of the calls in the trace, and also plan to support verifying outputs during the warm-up phase of the benchmark.
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index 7326e6b..7b0d7f1 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -494,6 +494,45 @@
                  "Error moving buffer view");
 }
 
+// Serializes each element of |vm_list| to a string: value variants become
+// "i32=<n>", buffer views are rendered by iree_hal_buffer_view_format, and
+// any other element is reported through RaiseValueError.
+std::vector<std::string> SerializeVmVariantList(VmVariantList& vm_list) {
+  size_t size = vm_list.size();
+  std::vector<std::string> results;
+  results.reserve(size);
+  for (iree_host_size_t i = 0; i < size; ++i) {
+    iree_vm_variant_t variant = iree_vm_variant_empty();
+    CheckApiStatus(iree_vm_list_get_variant(vm_list.raw_ptr(), i, &variant),
+                   "Failed to get vm variant from list");
+    if (iree_vm_variant_is_value(variant)) {
+      // NOTE(review): all value variants are printed as i32 -- confirm.
+      results.push_back("i32=" + std::to_string(variant.i32));
+    } else if (iree_vm_variant_is_ref(variant) &&
+               iree_hal_buffer_view_isa(&variant.ref)) {
+      auto buffer_view = iree_hal_buffer_view_deref(&variant.ref);
+      // Format in place in the output slot (no copy of a large string); the
+      // loop below resizes and retries when the 4096-byte guess is too small.
+      results.push_back(std::string(4096, '\0'));
+      std::string& result_str = results.back();
+      iree_status_t format_status;
+      do {
+        iree_host_size_t actual_length = 0;
+        format_status = iree_hal_buffer_view_format(
+            buffer_view, std::numeric_limits<iree_host_size_t>::max(),
+            result_str.size() + 1, &result_str[0], &actual_length);
+        result_str.resize(actual_length);
+      } while (iree_status_is_out_of_range(format_status));
+      CheckApiStatus(format_status,
+                     "Failed to create a string representation of the inputs");
+    } else {
+      RaiseValueError(
+          "Expected vm_list's elements to be scalars or buffer views.");
+    }
+  }
+  return results;
+}
+
 void SetupFunctionAbiBindings(pybind11::module m) {
   py::class_<FunctionAbi, std::unique_ptr<FunctionAbi>>(m, "FunctionAbi")
       .def(py::init(&PyCreateAbi))
@@ -506,6 +545,10 @@
                               absl::MakeConstSpan(self->raw_config().inputs),
                               py_args, false /* writable */);
            })
+      .def("serialize_vm_list",
+           [](FunctionAbi* self, VmVariantList& vm_list) {
+             return SerializeVmVariantList(vm_list);
+           })
       .def("allocate_results", &PyAllocateResults, py::arg("f_results"),
            py::arg("static_alloc") = true)
       .def("raw_unpack_results", &PyRawUnpackResults);
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
index 327883bc..17ac317 100644
--- a/bindings/python/pyiree/rt/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -58,18 +58,16 @@
   driver_exceptions = {}
   for driver_name in driver_names:
     if driver_name not in available_driver_names:
-      print(
-          "Could not create driver %s (not registered)" % driver_name,
-          file=sys.stderr)
+      print("Could not create driver %s (not registered)" % driver_name,
+            file=sys.stderr)
       continue
     try:
       driver = _binding.HalDriver.create(driver_name)
       # TODO(laurenzo): Remove these prints to stderr (for now, more information
       # is better and there is no better way to report it yet).
     except Exception as ex:  # pylint: disable=broad-except
-      print(
-          "Could not create default driver %s: %r" % (driver_name, ex),
-          file=sys.stderr)
+      print("Could not create default driver %s: %r" % (driver_name, ex),
+            file=sys.stderr)
       driver_exceptions[driver_name] = ex
       continue
 
@@ -80,9 +78,8 @@
     try:
       device = driver.create_default_device()
     except Exception as ex:
-      print(
-          "Could not create default driver device %s: %r" % (driver_name, ex),
-          file=sys.stderr)
+      print("Could not create default driver device %s: %r" % (driver_name, ex),
+            file=sys.stderr)
       driver_exceptions[driver_name] = ex
       continue
 
@@ -134,14 +131,18 @@
     self._context = context
     self._vm_function = vm_function
     self._abi = context.create_function_abi(vm_function)
+    self._serialized_inputs = None
+    self._serialized_outputs = None
 
   def __call__(self, *args):
     # NOTE: This is just doing sync dispatch right now. In the future,
     # this should default to async and potentially have some kind of policy
     # flag that can allow it to be overridden.
     inputs = self._abi.raw_pack_inputs(args)
+    self._serialized_inputs = tuple(self._abi.serialize_vm_list(inputs))
     results = self._abi.allocate_results(inputs, static_alloc=False)
     self._context._vm_context.invoke(self._vm_function, inputs, results)
+    self._serialized_outputs = tuple(self._abi.serialize_vm_list(results))
     unpacked_results = self._abi.raw_unpack_results(results)
     # TODO(laurenzo): When switching from 'raw' to structured pack/unpack,
     # the ABI should take care of this one-arg special case.
@@ -158,6 +159,12 @@
         self._vm_function,
     )
 
+  def get_serialized_values(self):
+    if self._serialized_inputs is None:
+      raise RuntimeError("Attempted to call get_serialized_values() before "
+                         "any values were passed.")
+    return self._serialized_inputs, self._serialized_outputs
+
 
 class BoundModule:
   """Wraps a VmModule with its context and provides nice python accessors.
@@ -219,8 +226,8 @@
     else:
       init_modules = None
 
-    self._vm_context = _binding.VmContext(
-        instance=self._config.vm_instance, modules=init_modules)
+    self._vm_context = _binding.VmContext(instance=self._config.vm_instance,
+                                          modules=init_modules)
 
     if self._is_dynamic:
       self._vm_context.register_modules(self._config.default_modules)
diff --git a/bindings/python/pyiree/rt/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
index ca47439..a27f3b0 100644
--- a/bindings/python/pyiree/rt/system_api_test.py
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -93,6 +93,19 @@
     results = f(arg0, arg1)
     np.testing.assert_allclose(results, [4., 10., 18., 28.])
 
+  def test_serialize_values(self):
+    ctx = rt.SystemContext()
+    self.assertTrue(ctx.is_dynamic)
+    ctx.add_module(create_simple_mul_module())
+    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+    f = ctx.modules.arithmetic["simple_mul"]
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    results = f(arg0, arg1)
+    inputs, outputs = f.get_serialized_values()
+    self.assertEqual(inputs, ("4xf32=1 2 3 4", "4xf32=4 5 6 7"))
+    self.assertEqual(outputs, ("4xf32=4 10 18 28",))
+
   def test_load_module(self):
     arithmetic = rt.load_module(create_simple_mul_module())
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/bindings/python/pyiree/rt/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
index da05a6c..e88b0f3 100644
--- a/bindings/python/pyiree/rt/vm_test.py
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -126,6 +126,7 @@
     logging.info("abi: %s", abi)
 
     inputs = abi.raw_pack_inputs((5, 6))
+    logging.info("serialize_inputs: %s", abi.serialize_vm_list(inputs))
     logging.info("inputs: %s", inputs)
 
     allocated_results = abi.allocate_results(inputs, static_alloc=False)
@@ -148,6 +149,7 @@
 
     arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
     inputs = abi.raw_pack_inputs((arg0,))
+    logging.info("Serialized inputs: %s", abi.serialize_vm_list(inputs))
     logging.info("inputs: %s", inputs)
 
     allocated_results = abi.allocate_results(inputs, static_alloc=False)
@@ -171,6 +173,7 @@
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
     arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
     inputs = abi.raw_pack_inputs((arg0, arg1))
+    logging.info("Serialized inputs: %s", abi.serialize_vm_list(inputs))
     logging.info("inputs: %s", inputs)
 
     allocated_results = abi.allocate_results(inputs, static_alloc=False)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 550324c..f59ff16 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -124,6 +124,8 @@
                method: str,
                inputs: Tuple[Any],
                outputs: Tuple[Any],
+               serialized_inputs: Tuple[str],
+               serialized_outputs: Tuple[str],
                rtol: float = 1e-6,
                atol: float = 1e-6):
     """Records the details of a call to a CompiledModule."""
@@ -142,6 +144,9 @@
       outputs = tuple()
     self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
 
+    self.serialized_inputs = serialized_inputs
+    self.serialized_outputs = serialized_outputs
+
     self.rtol = rtol
     self.atol = atol
 
@@ -187,7 +192,13 @@
     """
     os.makedirs(call_dir, exist_ok=True)
 
-    metadata = {"method": self.method, "rtol": self.rtol, "atol": self.atol}
+    metadata = {
+        "method": self.method,
+        "serialized_inputs": self.serialized_inputs,
+        "serialized_outputs": self.serialized_outputs,
+        "rtol": self.rtol,
+        "atol": self.atol
+    }
     with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
       pickle.dump(metadata, f)
 
@@ -248,7 +259,10 @@
     if _load_dict is None:
       # Extract metadata from module and function.
       self.module_name = module.module_name
-      self.backend = module.backend
+      self.compiled_path = module.compiled_path
+      self.backend_name = module.backend
+      self.supports_cxx_serialization = module.supports_cxx_serialization()
+      self.backend_driver = module.backend_driver
       self.function_name = function.__name__
       self.function_sourcefile = inspect.getsourcefile(function)
       source, start_line = inspect.getsourcelines(function)
@@ -258,7 +272,10 @@
       self.calls = []
     else:
       self.module_name = _load_dict["module_name"]
-      self.backend = _load_dict["backend"]
+      self.compiled_path = _load_dict["compiled_path"]
+      self.backend_name = _load_dict["backend_name"]
+      self.supports_cxx_serialization = _load_dict["supports_cxx_serialization"]
+      self.backend_driver = _load_dict["backend_driver"]
       self.function_name = _load_dict["function_name"]
       self.function_sourcefile = _load_dict["function_sourcefile"]
       self.function_line_numbers = _load_dict["function_line_numbers"]
@@ -266,7 +283,7 @@
       self.calls = _load_dict["calls"]
 
   def __str__(self):
-    header = (f"Trace of {self.module_name} compiled to '{self.backend}' "
+    header = (f"Trace of {self.module_name} compiled to '{self.backend_name}' "
               f"on function '{self.function_name}':")
     # Give each call a number so it's easier to compare between multiple traces.
     calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
@@ -307,9 +324,11 @@
 
       if not calls_match:
         logging.error("Comparision between '%s' and '%s' failed on method '%s'",
-                      ref_trace.backend, tar_trace.backend, ref_call.method)
-        logging.error("Reference call '%s':\n%s", ref_trace.backend, ref_call)
-        logging.error("Target call '%s':\n%s", tar_trace.backend, tar_call)
+                      ref_trace.backend_name, tar_trace.backend_name,
+                      ref_call.method)
+        logging.error("Reference call '%s':\n%s", ref_trace.backend_name,
+                      ref_call)
+        logging.error("Target call '%s':\n%s", tar_trace.backend_name, tar_call)
 
       traces_match = traces_match and calls_match
     return traces_match
@@ -411,9 +430,14 @@
     Args:
       trace_dir: str, path to the directory to serialize the trace to.
     """
+
+    # Python serialization.
     metadata = {
         "module_name": self.module_name,
-        "backend": self.backend,
+        "compiled_path": self.compiled_path,
+        "backend_name": self.backend_name,
+        "supports_cxx_serialization": self.supports_cxx_serialization,
+        "backend_driver": self.backend_driver,
         "function_name": self.function_name,
         "function_sourcefile": self.function_sourcefile,
         "function_line_numbers": self.function_line_numbers,
@@ -427,6 +451,19 @@
       call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
       call.serialize(call_dir)
 
+    # C++ Serialization: write an Abseil flagfile for iree-benchmark-module.
+    if self.supports_cxx_serialization:
+      flaglines = []
+      if self.compiled_path is not None:
+        flaglines.append(f"--input_file={self.compiled_path}")
+      flaglines.append(f"--driver={self.backend_driver}")
+      # Only the first call in the trace is captured in the flagfile.
+      inputs_str = ", ".join(self.calls[0].serialized_inputs)
+      flaglines.append(f"--inputs={inputs_str}")
+      flaglines.append(f"--entry_function={self.calls[0].method}")
+      with open(os.path.join(trace_dir, "flagfile"), "w") as f:
+        f.writelines(line + '\n' for line in flaglines)
+
   @staticmethod
   def load(trace_dir: str) -> "Trace":
     """Loads and returns a trace serialized with Trace.serialize.
@@ -446,7 +483,7 @@
 
 
 def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
-  trace_dir = os.path.join(artifacts_dir, trace.backend, "traces",
+  trace_dir = os.path.join(artifacts_dir, trace.backend_name, "traces",
                            trace.function_name)
   os.makedirs(trace_dir, exist_ok=True)
   return trace_dir
@@ -484,8 +521,10 @@
 
       # Run the method and record the details of the call.
       outputs = method(*args, **kwargs)
+      serialized_inputs, serialized_outputs = method.get_serialized_values()
       self._trace.calls.append(
-          ModuleCall(method_name, args, outputs, **tolerances))
+          ModuleCall(method_name, args, outputs, serialized_inputs,
+                     serialized_outputs, **tolerances))
       return outputs
 
     return call
@@ -621,7 +660,7 @@
     failed_backend_indices = []
     for i, tar_trace in enumerate(tar_traces):
       logging.info("Comparing the reference backend '%s' with '%s'",
-                   ref_trace.backend, tar_trace.backend)
+                   ref_trace.backend_name, tar_trace.backend_name)
       traces_match = Trace.compare_traces(ref_trace, tar_trace)
       if not traces_match:
         failed_backend_indices.append(i)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 5bffd59..dd054a5 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -182,8 +182,8 @@
           np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
       module.get_count()
 
-    module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                       tf_utils.BackendInfo('tf'))
+    module = tf_utils.IreeCompiledModule(StatefulCountingModule,
+                                         tf_utils.BackendInfo('iree_vmla'))
     trace = tf_test_utils.Trace(module, trace_function)
     trace_function(tf_test_utils.TracedModule(module, trace))
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 98ce726..faf751e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -20,7 +20,7 @@
 import random
 import re
 import tempfile
-from typing import Any, Callable, Sequence, Type
+from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
 
 from absl import flags
 from absl import logging
@@ -114,10 +114,12 @@
     return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
 
 
-def compile_tf_module(tf_module: Type[tf.Module],
-                      backend_infos: Sequence["BackendInfo"] = (),
-                      exported_names: Sequence[str] = (),
-                      artifacts_dir: str = None) -> compiler.binding.OpaqueBlob:
+def compile_tf_module(
+    tf_module: Type[tf.Module],
+    backend_infos: Sequence["BackendInfo"] = (),
+    exported_names: Sequence[str] = (),
+    artifacts_dir: str = None
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
   """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
 
   The artifact this creates is not callable. See IreeCompiledModule for an API
@@ -149,7 +151,8 @@
       should be saved.
 
   Returns:
-    A compiled IREE module blob.
+    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+    artifacts_dir is provided.
   """
 
   def _compile_from_path(sm_path: str) -> compiler.binding.OpaqueBlob:
@@ -189,6 +192,7 @@
         target_backends.extend(backend_info.compiler_targets)
       compiled_module = compiler_module.compile(target_backends=target_backends)
 
+      compiled_path = None
       if artifacts_dir is not None:
         compiled_path = _get_backends_path("compiled", backend_infos,
                                            artifacts_dir)
@@ -197,7 +201,7 @@
         with open(compiled_path, "wb") as f:
           f.write(compiled_module)
 
-      return compiled_module
+      return compiled_module, compiled_path
     except Exception:  # pylint: disable=broad-except
       if artifacts_dir is not None:
         # Disable the crash reproducer (to avoid inadvertently overwriting it).
@@ -233,12 +237,18 @@
 
     # Public attributes:
     self.backend = self._backend_info.name
+    self.backend_driver = self._backend_info.driver
     self.module_name = self._module_class.__name__
+    self.compiled_path = None
 
   def create_reinitialized(self):
     """Duplicates this module with its initial state without recompiling."""
     raise NotImplementedError()
 
+  @staticmethod
+  def supports_cxx_serialization():
+    raise NotImplementedError()
+
 
 class _IreeFunctionWrapper(object):
   """Wraps an IREE function, making it callable."""
@@ -250,6 +260,9 @@
   def __call__(self, *args):
     return self._f(*args)
 
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    return self._f.get_serialized_values()
+
 
 class IreeCompiledModule(CompiledModule):
   """Iree compiled module."""
@@ -259,7 +272,7 @@
                backend_info: "BackendInfo",
                exported_names: Sequence[str] = (),
                artifacts_dir: str = None,
-               _create_reinitialized_args: Sequence[Any] = None):
+               _create_reinitialized_dict: Dict[str, Any] = None):
     """Compile a tf.Module to the target backend in backend_info.
 
     Args:
@@ -270,13 +283,13 @@
         module_class's functions to compile. If exported_names is empty all
         functions will be compiled.
       artifacts_dir: an optional path to save compilation artifacts to.
-      _create_reinitialized_args: used internally.
+      _create_reinitialized_dict: used internally.
     """
     super().__init__(module_class, backend_info, exported_names, artifacts_dir)
 
-    if _create_reinitialized_args is None:
+    if _create_reinitialized_dict is None:
       set_random_seed()
-      self._module_blob = compile_tf_module(
+      self._module_blob, self.compiled_path = compile_tf_module(
           tf_module=module_class(),
           backend_infos=[backend_info],
           exported_names=exported_names,
@@ -285,11 +298,14 @@
       self._config = rt.Config(driver_name=backend_info.driver)
     else:
       # Called from self.create_reinitialized()
-      self._module_blob, self._module, self._config = _create_reinitialized_args
+      self._module_blob = _create_reinitialized_dict["_module_blob"]
+      self._module = _create_reinitialized_dict["_module"]
+      self._config = _create_reinitialized_dict["_config"]
+      self.compiled_path = _create_reinitialized_dict["compiled_path"]
 
     # Holds all of the module's mutable state.
-    self._context = rt.SystemContext(
-        modules=[self._module], config=self._config)
+    self._context = rt.SystemContext(modules=[self._module],
+                                     config=self._config)
 
   def create_reinitialized(self) -> "IreeCompiledModule":
     """Duplicates this module with its initial state without recompiling."""
@@ -297,8 +313,13 @@
         self._module_class, self._backend_info, self._exported_names,
         self._artifacts_dir
     ]
-    create_reinitialized_args = [self._module_blob, self._module, self._config]
-    return IreeCompiledModule(*default_args, create_reinitialized_args)
+    create_reinitialized_dict = {
+        "_module_blob": self._module_blob,
+        "_module": self._module,
+        "_config": self._config,
+        "compiled_path": self.compiled_path
+    }
+    return IreeCompiledModule(*default_args, create_reinitialized_dict)
 
   def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
     # Try to resolve it as a function.
@@ -306,6 +327,10 @@
     f = m[attr]
     return _IreeFunctionWrapper(self._context, f)
 
+  @staticmethod
+  def supports_cxx_serialization() -> bool:
+    return True
+
 
 class _TfFunctionWrapper(object):
   """Wraps a TF function, normalizing it to numpy."""
@@ -330,8 +355,13 @@
     # which is sad).
     if not isinstance(results, tuple):
       results = (results,)
-    return tf.nest.map_structure(
-        self._convert_to_numpy, *results, check_types=False)
+    return tf.nest.map_structure(self._convert_to_numpy,
+                                 *results,
+                                 check_types=False)
+
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    """Dummy function to match _IreeFunctionWrapper's API."""
+    return (), ()
 
 
 class TfCompiledModule(CompiledModule):
@@ -377,6 +407,10 @@
           f"The TensorFlow module does not have a callable attr '{attr}'")
     return _TfFunctionWrapper(f)
 
+  @staticmethod
+  def supports_cxx_serialization() -> bool:
+    return False
+
 
 class BackendInfo:
   """Contains information for compiling the specified backend."""
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 4f4084e..28ac195 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -64,12 +64,9 @@
   def test_artifact_saving(self, backend_infos):
     with tempfile.TemporaryDirectory() as artifacts_dir:
       tf_module = ConstantModule()
-      iree_compiled_module = tf_utils.compile_tf_module(
+      iree_compiled_module, compiled_path = tf_utils.compile_tf_module(
           tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
 
-      compiled_path = tf_utils._get_backends_path('compiled', backend_infos,
-                                                  artifacts_dir)
-      compiled_path = f'{compiled_path}.vmfb'
       artifacts_to_check = [
           'tf_input.mlir',
           'iree_input.mlir',
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 343b155..2a4c9eb 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -187,6 +187,39 @@
 `SimpleArithmeticModule` example above, the `trace_dir` would be
 `/tmp/iree/modules/SimpleArithmeticModule/iree_vmla/traces/simple_mul/`.
 
+## Benchmarking E2E Modules
+
+Abseil flagfiles containing all of the data that `iree-benchmark-module` needs
+to run are generated for each `Trace` in our E2E tests. This allows for any
+module we test to be easily benchmarked on valid inputs. The process for
+benchmarking a vision model can thus be reduced to the following:
+
+```shell
+# Generate benchmarking artifacts for all vision models:
+bazel test integrations/tensorflow/e2e/keras:vision_external_tests
+
+# Benchmark ResNet50 with cifar10 weights on vmla:
+bazel run iree/tools:iree-benchmark-module -- \
+  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile
+
+# Benchmark ResNet50 with cifar10 weights on llvmjit:
+bazel run iree/tools:iree-benchmark-module -- \
+  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile
+```
+
+Duplicate flags provided after the flagfile will take precedence. For example:
+
+```shell
+bazel run iree/tools:iree-benchmark-module -- \
+  --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile  \
+  --input_file=/path/to/custom/compiled.vmfb
+```
+
+Currently, this only supports benchmarking the first module call in a trace.
+We plan to extend this to support benchmarking all of the calls in the trace,
+and also plan to support verifying outputs during the warm-up phase of the
+benchmark.
+
 ## Debugging Tests
 
 If the compiler fails to compile the program, then it will create a crash