Add serialization for benchmarking e2e test modules with IREE (#2895)
Abseil flagfiles containing all of the data that `iree-benchmark-module` needs to run are generated for each `Trace` in our E2E tests. This allows for any module we test to be easily benchmarked on valid inputs. The process for benchmarking a vision model can thus be reduced to the following:
```shell
# Generate benchmarking artifacts for all vision models:
bazel test integrations/tensorflow/e2e/keras:vision_external_tests
# Benchmark ResNet50 with cifar10 weights on vmla:
bazel run iree/tools:iree-benchmark-module -- \
--flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile
# Benchmark ResNet50 with cifar10 weights on llvmjit:
bazel run iree/tools:iree-benchmark-module -- \
--flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile
```
Flags passed on the command line after `--flagfile` take precedence over the values in the flagfile. For example:
```shell
bazel run iree/tools:iree-benchmark-module -- \
--flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile \
--input_file=/path/to/custom/compiled.vmfb
```
Currently, this only supports benchmarking the first module call in a trace. We plan to extend this to support benchmarking all of the calls in the trace, and also plan to support verifying outputs during the warm-up phase of the benchmark.
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index 7326e6b..7b0d7f1 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -494,6 +494,45 @@
"Error moving buffer view");
}
+// Serializes each element of |vm_list| into a string, in list order.
+//
+// Value variants are written as "i32=<value>" and buffer views are written
+// with iree_hal_buffer_view_format. Any other element is reported through
+// RaiseValueError, and failed API calls are reported through CheckApiStatus.
+// NOTE(review): every value variant is read through its i32 field; confirm
+// that no other value types can reach this function.
+std::vector<std::string> SerializeVmVariantList(VmVariantList& vm_list) {
+ size_t size = vm_list.size();
+ std::vector<std::string> results;
+ results.reserve(size);
+ for (iree_host_size_t i = 0; i < size; ++i) {
+ iree_vm_variant_t variant = iree_vm_variant_empty();
+ iree_status_t status =
+ iree_vm_list_get_variant(vm_list.raw_ptr(), i, &variant);
+ CheckApiStatus(status, "Failed to get vm variant from list");
+
+ if (iree_vm_variant_is_value(variant)) {
+ results.push_back("i32=" + std::to_string(variant.i32));
+ } else if (iree_vm_variant_is_ref(variant) &&
+ iree_hal_buffer_view_isa(&variant.ref)) {
+ auto buffer_view = iree_hal_buffer_view_deref(&variant.ref);
+
+ // Start with a 4096-character buffer and retry for as long as the
+ // formatter reports out-of-range, resizing the string to the length it
+ // wrote back into actual_length.
+ std::string result_str(4096, '\0');
+ iree_status_t status;
+ do {
+ iree_host_size_t actual_length = 0;
+ // Pass the largest representable element count as the element limit.
+ iree_host_size_t max_element_count =
+ std::numeric_limits<iree_host_size_t>::max();
+ // size() + 1: presumably counts the terminating NUL that std::string
+ // keeps after its contents -- TODO confirm against the formatter API.
+ status = iree_hal_buffer_view_format(buffer_view, max_element_count,
+ result_str.size() + 1,
+ &result_str[0], &actual_length);
+ result_str.resize(actual_length);
+ } while (iree_status_is_out_of_range(status));
+ CheckApiStatus(status,
+ "Failed to create a string representation of the inputs");
+
+ results.push_back(result_str);
+ } else {
+ RaiseValueError(
+ "Expected vm_list's elements to be scalars or buffer views.");
+ }
+ }
+ return results;
+}
+
void SetupFunctionAbiBindings(pybind11::module m) {
py::class_<FunctionAbi, std::unique_ptr<FunctionAbi>>(m, "FunctionAbi")
.def(py::init(&PyCreateAbi))
@@ -506,6 +545,10 @@
absl::MakeConstSpan(self->raw_config().inputs),
py_args, false /* writable */);
})
+ .def("serialize_vm_list",
+ [](FunctionAbi* self, VmVariantList& vm_list) {
+ return SerializeVmVariantList(vm_list);
+ })
.def("allocate_results", &PyAllocateResults, py::arg("f_results"),
py::arg("static_alloc") = true)
.def("raw_unpack_results", &PyRawUnpackResults);
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
index 327883bc..17ac317 100644
--- a/bindings/python/pyiree/rt/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -58,18 +58,16 @@
driver_exceptions = {}
for driver_name in driver_names:
if driver_name not in available_driver_names:
- print(
- "Could not create driver %s (not registered)" % driver_name,
- file=sys.stderr)
+ print("Could not create driver %s (not registered)" % driver_name,
+ file=sys.stderr)
continue
try:
driver = _binding.HalDriver.create(driver_name)
# TODO(laurenzo): Remove these prints to stderr (for now, more information
# is better and there is no better way to report it yet).
except Exception as ex: # pylint: disable=broad-except
- print(
- "Could not create default driver %s: %r" % (driver_name, ex),
- file=sys.stderr)
+ print("Could not create default driver %s: %r" % (driver_name, ex),
+ file=sys.stderr)
driver_exceptions[driver_name] = ex
continue
@@ -80,9 +78,8 @@
try:
device = driver.create_default_device()
except Exception as ex:
- print(
- "Could not create default driver device %s: %r" % (driver_name, ex),
- file=sys.stderr)
+ print("Could not create default driver device %s: %r" % (driver_name, ex),
+ file=sys.stderr)
driver_exceptions[driver_name] = ex
continue
@@ -134,14 +131,18 @@
self._context = context
self._vm_function = vm_function
self._abi = context.create_function_abi(vm_function)
+ self._serialized_inputs = None
+ self._serialized_outputs = None
def __call__(self, *args):
# NOTE: This is just doing sync dispatch right now. In the future,
# this should default to async and potentially have some kind of policy
# flag that can allow it to be overridden.
inputs = self._abi.raw_pack_inputs(args)
+ self._serialized_inputs = tuple(self._abi.serialize_vm_list(inputs))
results = self._abi.allocate_results(inputs, static_alloc=False)
self._context._vm_context.invoke(self._vm_function, inputs, results)
+ self._serialized_outputs = tuple(self._abi.serialize_vm_list(results))
unpacked_results = self._abi.raw_unpack_results(results)
# TODO(laurenzo): When switching from 'raw' to structured pack/unpack,
# the ABI should take care of this one-arg special case.
@@ -158,6 +159,12 @@
self._vm_function,
)
+ # Returns the (serialized_inputs, serialized_outputs) tuples stored by the
+ # most recent __call__. Raises RuntimeError if no inputs have been stored.
+ def get_serialized_values(self):
+ if self._serialized_inputs is None:
+ raise RuntimeError("Attempted to call get_serialized_values() before "
+ "any values were passed.")
+ return self._serialized_inputs, self._serialized_outputs
+
class BoundModule:
"""Wraps a VmModule with its context and provides nice python accessors.
@@ -219,8 +226,8 @@
else:
init_modules = None
- self._vm_context = _binding.VmContext(
- instance=self._config.vm_instance, modules=init_modules)
+ self._vm_context = _binding.VmContext(instance=self._config.vm_instance,
+ modules=init_modules)
if self._is_dynamic:
self._vm_context.register_modules(self._config.default_modules)
diff --git a/bindings/python/pyiree/rt/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
index ca47439..a27f3b0 100644
--- a/bindings/python/pyiree/rt/system_api_test.py
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -93,6 +93,19 @@
results = f(arg0, arg1)
np.testing.assert_allclose(results, [4., 10., 18., 28.])
+ # Calls simple_mul once and checks that the serialized inputs and outputs
+ # are tuples of strings of the form "<shape>x<type>=<values>".
+ def test_serialize_values(self):
+ ctx = rt.SystemContext()
+ self.assertTrue(ctx.is_dynamic)
+ ctx.add_module(create_simple_mul_module())
+ self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+ f = ctx.modules.arithmetic["simple_mul"]
+ arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+ arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+ results = f(arg0, arg1)
+ inputs, outputs = f.get_serialized_values()
+ self.assertEqual(inputs, ("4xf32=1 2 3 4", "4xf32=4 5 6 7"))
+ self.assertEqual(outputs, ("4xf32=4 10 18 28",))
+
def test_load_module(self):
arithmetic = rt.load_module(create_simple_mul_module())
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
diff --git a/bindings/python/pyiree/rt/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
index da05a6c..e88b0f3 100644
--- a/bindings/python/pyiree/rt/vm_test.py
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -126,6 +126,7 @@
logging.info("abi: %s", abi)
inputs = abi.raw_pack_inputs((5, 6))
+ logging.info("serialize_inputs: %s", abi.serialize_vm_list(inputs))
logging.info("inputs: %s", inputs)
allocated_results = abi.allocate_results(inputs, static_alloc=False)
@@ -148,6 +149,7 @@
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
inputs = abi.raw_pack_inputs((arg0,))
+ logging.info("Serialized inputs: %s", abi.serialize_vm_list(inputs))
logging.info("inputs: %s", inputs)
allocated_results = abi.allocate_results(inputs, static_alloc=False)
@@ -171,6 +173,7 @@
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
inputs = abi.raw_pack_inputs((arg0, arg1))
+ logging.info("Serialized inputs: %s", abi.serialize_vm_list(inputs))
logging.info("inputs: %s", inputs)
allocated_results = abi.allocate_results(inputs, static_alloc=False)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 550324c..f59ff16 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -124,6 +124,8 @@
method: str,
inputs: Tuple[Any],
outputs: Tuple[Any],
+ serialized_inputs: Tuple[str],
+ serialized_outputs: Tuple[str],
rtol: float = 1e-6,
atol: float = 1e-6):
"""Records the details of a call to a CompiledModule."""
@@ -142,6 +144,9 @@
outputs = tuple()
self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
+ self.serialized_inputs = serialized_inputs
+ self.serialized_outputs = serialized_outputs
+
self.rtol = rtol
self.atol = atol
@@ -187,7 +192,13 @@
"""
os.makedirs(call_dir, exist_ok=True)
- metadata = {"method": self.method, "rtol": self.rtol, "atol": self.atol}
+ metadata = {
+ "method": self.method,
+ "serialized_inputs": self.serialized_inputs,
+ "serialized_outputs": self.serialized_outputs,
+ "rtol": self.rtol,
+ "atol": self.atol
+ }
with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
pickle.dump(metadata, f)
@@ -248,7 +259,10 @@
if _load_dict is None:
# Extract metadata from module and function.
self.module_name = module.module_name
- self.backend = module.backend
+ self.compiled_path = module.compiled_path
+ self.backend_name = module.backend
+ self.supports_cxx_serialization = module.supports_cxx_serialization()
+ self.backend_driver = module.backend_driver
self.function_name = function.__name__
self.function_sourcefile = inspect.getsourcefile(function)
source, start_line = inspect.getsourcelines(function)
@@ -258,7 +272,10 @@
self.calls = []
else:
self.module_name = _load_dict["module_name"]
- self.backend = _load_dict["backend"]
+ self.compiled_path = _load_dict["compiled_path"]
+ self.backend_name = _load_dict["backend_name"]
+ self.supports_cxx_serialization = _load_dict["supports_cxx_serialization"]
+ self.backend_driver = _load_dict["backend_driver"]
self.function_name = _load_dict["function_name"]
self.function_sourcefile = _load_dict["function_sourcefile"]
self.function_line_numbers = _load_dict["function_line_numbers"]
@@ -266,7 +283,7 @@
self.calls = _load_dict["calls"]
def __str__(self):
- header = (f"Trace of {self.module_name} compiled to '{self.backend}' "
+ header = (f"Trace of {self.module_name} compiled to '{self.backend_name}' "
f"on function '{self.function_name}':")
# Give each call a number so it's easier to compare between multiple traces.
calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
@@ -307,9 +324,11 @@
if not calls_match:
logging.error("Comparision between '%s' and '%s' failed on method '%s'",
- ref_trace.backend, tar_trace.backend, ref_call.method)
- logging.error("Reference call '%s':\n%s", ref_trace.backend, ref_call)
- logging.error("Target call '%s':\n%s", tar_trace.backend, tar_call)
+ ref_trace.backend_name, tar_trace.backend_name,
+ ref_call.method)
+ logging.error("Reference call '%s':\n%s", ref_trace.backend_name,
+ ref_call)
+ logging.error("Target call '%s':\n%s", tar_trace.backend_name, tar_call)
traces_match = traces_match and calls_match
return traces_match
@@ -411,9 +430,14 @@
Args:
trace_dir: str, path to the directory to serialize the trace to.
"""
+
+ # Python serialization.
metadata = {
"module_name": self.module_name,
- "backend": self.backend,
+ "compiled_path": self.compiled_path,
+ "backend_name": self.backend_name,
+ "supports_cxx_serialization": self.supports_cxx_serialization,
+ "backend_driver": self.backend_driver,
"function_name": self.function_name,
"function_sourcefile": self.function_sourcefile,
"function_line_numbers": self.function_line_numbers,
@@ -427,6 +451,19 @@
call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
call.serialize(call_dir)
+    # C++ Serialization: write an Abseil flagfile for iree-benchmark-module.
+    # Only modules that support C++ serialization get one; for other modules
+    # the calls carry no serialized inputs to pass via `--inputs`.
+    if self.supports_cxx_serialization:
+      flaglines = []
+      # compiled_path is None when the module was not saved to disk.
+      if self.compiled_path is not None:
+        flaglines.append(f"--input_file={self.compiled_path}")
+      flaglines.append(f"--driver={self.backend_driver}")
+      # Only the first call in the trace is serialized for benchmarking.
+      inputs_str = ", ".join(self.calls[0].serialized_inputs)
+      flaglines.append(f"--inputs={inputs_str}")
+      flaglines.append(f"--entry_function={self.calls[0].method}")
+
+      with open(os.path.join(trace_dir, "flagfile"), "w") as f:
+        f.writelines(line + '\n' for line in flaglines)
+
@staticmethod
def load(trace_dir: str) -> "Trace":
"""Loads and returns a trace serialized with Trace.serialize.
@@ -446,7 +483,7 @@
def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
- trace_dir = os.path.join(artifacts_dir, trace.backend, "traces",
+ trace_dir = os.path.join(artifacts_dir, trace.backend_name, "traces",
trace.function_name)
os.makedirs(trace_dir, exist_ok=True)
return trace_dir
@@ -484,8 +521,10 @@
# Run the method and record the details of the call.
outputs = method(*args, **kwargs)
+ serialized_inputs, serialized_outputs = method.get_serialized_values()
self._trace.calls.append(
- ModuleCall(method_name, args, outputs, **tolerances))
+ ModuleCall(method_name, args, outputs, serialized_inputs,
+ serialized_outputs, **tolerances))
return outputs
return call
@@ -621,7 +660,7 @@
failed_backend_indices = []
for i, tar_trace in enumerate(tar_traces):
logging.info("Comparing the reference backend '%s' with '%s'",
- ref_trace.backend, tar_trace.backend)
+ ref_trace.backend_name, tar_trace.backend_name)
traces_match = Trace.compare_traces(ref_trace, tar_trace)
if not traces_match:
failed_backend_indices.append(i)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 5bffd59..dd054a5 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -182,8 +182,8 @@
np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
module.get_count()
- module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ module = tf_utils.IreeCompiledModule(StatefulCountingModule,
+ tf_utils.BackendInfo('iree_vmla'))
trace = tf_test_utils.Trace(module, trace_function)
trace_function(tf_test_utils.TracedModule(module, trace))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 98ce726..faf751e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -20,7 +20,7 @@
import random
import re
import tempfile
-from typing import Any, Callable, Sequence, Type
+from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
from absl import flags
from absl import logging
@@ -114,10 +114,12 @@
return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
-def compile_tf_module(tf_module: Type[tf.Module],
- backend_infos: Sequence["BackendInfo"] = (),
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None) -> compiler.binding.OpaqueBlob:
+def compile_tf_module(
+ tf_module: Type[tf.Module],
+ backend_infos: Sequence["BackendInfo"] = (),
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
"""Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
The artifact this creates is not callable. See IreeCompiledModule for an API
@@ -149,7 +151,8 @@
should be saved.
Returns:
- A compiled IREE module blob.
+ A tuple of the compiled IREE module blob and the path the compiled VM
+ FlatBuffer was saved to (None if artifacts_dir is not provided).
"""
def _compile_from_path(sm_path: str) -> compiler.binding.OpaqueBlob:
@@ -189,6 +192,7 @@
target_backends.extend(backend_info.compiler_targets)
compiled_module = compiler_module.compile(target_backends=target_backends)
+ compiled_path = None
if artifacts_dir is not None:
compiled_path = _get_backends_path("compiled", backend_infos,
artifacts_dir)
@@ -197,7 +201,7 @@
with open(compiled_path, "wb") as f:
f.write(compiled_module)
- return compiled_module
+ return compiled_module, compiled_path
except Exception: # pylint: disable=broad-except
if artifacts_dir is not None:
# Disable the crash reproducer (to avoid inadvertently overwriting it).
@@ -233,12 +237,18 @@
# Public attributes:
self.backend = self._backend_info.name
+ self.backend_driver = self._backend_info.driver
self.module_name = self._module_class.__name__
+ self.compiled_path = None
def create_reinitialized(self):
"""Duplicates this module with its initial state without recompiling."""
raise NotImplementedError()
+ # Base-class placeholder: always raises NotImplementedError. Trace calls
+ # this on the module it records, so concrete subclasses return a bool.
+ @staticmethod
+ def supports_cxx_serialization():
+ raise NotImplementedError()
+
class _IreeFunctionWrapper(object):
"""Wraps an IREE function, making it callable."""
@@ -250,6 +260,9 @@
def __call__(self, *args):
return self._f(*args)
+ # Forwards to the wrapped function's get_serialized_values().
+ def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+ return self._f.get_serialized_values()
+
class IreeCompiledModule(CompiledModule):
"""Iree compiled module."""
@@ -259,7 +272,7 @@
backend_info: "BackendInfo",
exported_names: Sequence[str] = (),
artifacts_dir: str = None,
- _create_reinitialized_args: Sequence[Any] = None):
+ _create_reinitialized_dict: Dict[str, Any] = None):
"""Compile a tf.Module to the target backend in backend_info.
Args:
@@ -270,13 +283,13 @@
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
artifacts_dir: an optional path to save compilation artifacts to.
- _create_reinitialized_args: used internally.
+ _create_reinitialized_dict: used internally.
"""
super().__init__(module_class, backend_info, exported_names, artifacts_dir)
- if _create_reinitialized_args is None:
+ if _create_reinitialized_dict is None:
set_random_seed()
- self._module_blob = compile_tf_module(
+ self._module_blob, self.compiled_path = compile_tf_module(
tf_module=module_class(),
backend_infos=[backend_info],
exported_names=exported_names,
@@ -285,11 +298,14 @@
self._config = rt.Config(driver_name=backend_info.driver)
else:
# Called from self.create_reinitialized()
- self._module_blob, self._module, self._config = _create_reinitialized_args
+ self._module_blob = _create_reinitialized_dict["_module_blob"]
+ self._module = _create_reinitialized_dict["_module"]
+ self._config = _create_reinitialized_dict["_config"]
+ self.compiled_path = _create_reinitialized_dict["compiled_path"]
# Holds all of the module's mutable state.
- self._context = rt.SystemContext(
- modules=[self._module], config=self._config)
+ self._context = rt.SystemContext(modules=[self._module],
+ config=self._config)
def create_reinitialized(self) -> "IreeCompiledModule":
"""Duplicates this module with its initial state without recompiling."""
@@ -297,8 +313,13 @@
self._module_class, self._backend_info, self._exported_names,
self._artifacts_dir
]
- create_reinitialized_args = [self._module_blob, self._module, self._config]
- return IreeCompiledModule(*default_args, create_reinitialized_args)
+ create_reinitialized_dict = {
+ "_module_blob": self._module_blob,
+ "_module": self._module,
+ "_config": self._config,
+ "compiled_path": self.compiled_path
+ }
+ return IreeCompiledModule(*default_args, create_reinitialized_dict)
def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
# Try to resolve it as a function.
@@ -306,6 +327,10 @@
f = m[attr]
return _IreeFunctionWrapper(self._context, f)
+ # IREE-compiled modules report True; Trace stores this value alongside the
+ # rest of the module metadata.
+ @staticmethod
+ def supports_cxx_serialization() -> bool:
+ return True
+
class _TfFunctionWrapper(object):
"""Wraps a TF function, normalizing it to numpy."""
@@ -330,8 +355,13 @@
# which is sad).
if not isinstance(results, tuple):
results = (results,)
- return tf.nest.map_structure(
- self._convert_to_numpy, *results, check_types=False)
+ return tf.nest.map_structure(self._convert_to_numpy,
+ *results,
+ check_types=False)
+
+ def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+ """Dummy function to match _IreeFunctionWrapper's API."""
+ # Always returns two empty tuples (serialized inputs, serialized outputs).
+ return (), ()
class TfCompiledModule(CompiledModule):
@@ -377,6 +407,10 @@
f"The TensorFlow module does not have a callable attr '{attr}'")
return _TfFunctionWrapper(f)
+ # TensorFlow modules report False; their function wrappers return empty
+ # serialized inputs and outputs.
+ @staticmethod
+ def supports_cxx_serialization() -> bool:
+ return False
+
class BackendInfo:
"""Contains information for compiling the specified backend."""
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 4f4084e..28ac195 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -64,12 +64,9 @@
def test_artifact_saving(self, backend_infos):
with tempfile.TemporaryDirectory() as artifacts_dir:
tf_module = ConstantModule()
- iree_compiled_module = tf_utils.compile_tf_module(
+ iree_compiled_module, compiled_path = tf_utils.compile_tf_module(
tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
- compiled_path = tf_utils._get_backends_path('compiled', backend_infos,
- artifacts_dir)
- compiled_path = f'{compiled_path}.vmfb'
artifacts_to_check = [
'tf_input.mlir',
'iree_input.mlir',
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 343b155..2a4c9eb 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -187,6 +187,39 @@
`SimpleArithmeticModule` example above, the `trace_dir` would be
`/tmp/iree/modules/SimpleArithmeticModule/iree_vmla/traces/simple_mul/`.
+## Benchmarking E2E Modules
+
+Abseil flagfiles containing all of the data that `iree-benchmark-module` needs
+to run are generated for each `Trace` in our E2E tests. This allows for any
+module we test to be easily benchmarked on valid inputs. The process for
+benchmarking a vision model can thus be reduced to the following:
+
+```shell
+# Generate benchmarking artifacts for all vision models:
+bazel test integrations/tensorflow/e2e/keras:vision_external_tests
+
+# Benchmark ResNet50 with cifar10 weights on vmla:
+bazel run iree/tools:iree-benchmark-module -- \
+ --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_vmla/traces/predict/flagfile
+
+# Benchmark ResNet50 with cifar10 weights on llvmjit:
+bazel run iree/tools:iree-benchmark-module -- \
+ --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile
+```
+
+Flags passed after `--flagfile` take precedence over the flagfile. For example:
+
+```shell
+bazel run iree/tools:iree-benchmark-module -- \
+ --flagfile=/tmp/iree/modules/ResNet50/cifar10/iree_llvmjit/traces/predict/flagfile \
+ --input_file=/path/to/custom/compiled.vmfb
+```
+
+Currently, this only supports benchmarking the first module call in a trace.
+We plan to extend this to support benchmarking all of the calls in the trace,
+and also plan to support verifying outputs during the warm-up phase of the
+benchmark.
+
## Debugging Tests
If the compiler fails to compile the program, then it will create a crash