Adds for logging all traces to stdout (#2840)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index c8114a1..1249254 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -46,6 +46,9 @@
flags.DEFINE_bool(
"summarize", True,
"Summarize the inputs and outputs of each module trace logged to disk.")
+flags.DEFINE_bool(
+ "log_all_traces", False,
+ "Log all traces to logging.info, even if comparison passes.")
FLAGS = flags.FLAGS
NUMPY_LINEWIDTH = 120
@@ -483,9 +486,13 @@
# Run the traces through trace_function with their associated modules.
tf_utils.set_random_seed()
trace_function(TracedModule(self._ref_module, ref_trace))
+ if FLAGS.log_all_traces:
+ logging.info(ref_trace)
for module, trace in zip(self._tar_modules, tar_traces):
tf_utils.set_random_seed()
trace_function(TracedModule(module, trace))
+ if FLAGS.log_all_traces:
+ logging.info(trace)
# Compare each target trace of trace_function with the reference trace.
failed_backend_indices = []