Merge pull request #3278 from hanhanW:main-to-google

PiperOrigin-RevId: 334170175
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 4042e36..87a4c8a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -259,9 +259,10 @@
     if _load_dict is None:
       # Extract metadata from module and function.
       self.module_name = module.module_name
-      self.compiled_path = module.compiled_path
+      self.compiled_paths = module.compiled_paths
       self.backend_name = module.backend
-      self.supports_cxx_serialization = module.supports_cxx_serialization()
+      self.iree_serializable = module.iree_serializable()
+      self.tflite_serializable = module.tflite_serializable()
       self.backend_driver = module.backend_driver
       self.function_name = function.__name__
       self.function_sourcefile = inspect.getsourcefile(function)
@@ -272,9 +273,10 @@
       self.calls = []
     else:
       self.module_name = _load_dict["module_name"]
-      self.compiled_path = _load_dict["compiled_path"]
+      self.compiled_paths = _load_dict["compiled_paths"]
       self.backend_name = _load_dict["backend_name"]
-      self.supports_cxx_serialization = _load_dict["supports_cxx_serialization"]
+      self.iree_serializable = _load_dict["iree_serializable"]
+      self.tflite_serializable = _load_dict["tflite_serializable"]
       self.backend_driver = _load_dict["backend_driver"]
       self.function_name = _load_dict["function_name"]
       self.function_sourcefile = _load_dict["function_sourcefile"]
@@ -434,9 +436,10 @@
     # Python serialization.
     metadata = {
         "module_name": self.module_name,
-        "compiled_path": self.compiled_path,
+        "compiled_paths": self.compiled_paths,
         "backend_name": self.backend_name,
-        "supports_cxx_serialization": self.supports_cxx_serialization,
+        "iree_serializable": self.iree_serializable,
+        "tflite_serializable": self.tflite_serializable,
         "backend_driver": self.backend_driver,
         "function_name": self.function_name,
         "function_sourcefile": self.function_sourcefile,
@@ -451,18 +454,24 @@
       call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
       call.serialize(call_dir)
 
-    # C++ Serialization.
-    if self.supports_cxx_serialization:
-      flaglines = []
-      if self.compiled_path is not None:
-        flaglines.append(f"--input_file={self.compiled_path}")
-      flaglines.append(f"--driver={self.backend_driver}")
-      inputs_str = ", ".join(self.calls[0].serialized_inputs)
-      flaglines.append(f"--inputs={inputs_str}")
-      flaglines.append(f"--entry_function={self.calls[0].method}")
+    # C++ benchmark serialization (IREE flagfile or TFLite graph_path).
+    if self.iree_serializable or self.tflite_serializable:
+      entry_function = self.calls[0].method
+      compiled_path = self.compiled_paths[entry_function]
 
-      with open(os.path.join(trace_dir, "flagfile"), "w") as f:
-        f.writelines(line + "\n" for line in flaglines)
+      if self.iree_serializable:
+        serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
+        flagfile = [
+            f"--input_file={compiled_path}",
+            f"--driver={self.backend_driver}",
+            f"--inputs={serialized_inputs}",
+            f"--entry_function={entry_function}"
+        ]
+        with open(os.path.join(trace_dir, "flagfile"), "w") as f:
+          f.writelines(line + "\n" for line in flagfile)
+      else:
+        with open(os.path.join(trace_dir, "graph_path"), "w") as f:
+          f.write(compiled_path + "\n")
 
   @staticmethod
   def load(trace_dir: str) -> "Trace":
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 3093028..6157c68 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -191,7 +191,7 @@
       trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
       trace.serialize(trace_function_dir)
       self.assertTrue(
-          os.path.exists(os.path.join(trace_function_dir, 'flagfile')))
+          os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
       loaded_trace = tf_test_utils.Trace.load(trace_function_dir)
 
       # Check all calls match.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index d8e1b8c..b017506 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -157,8 +157,9 @@
 
   try:
     # Convert the tf_module into raw TF input MLIR.
-    compiler_module = compiler.tf_module_to_compiler_module(
-        tf_module, exported_names, pass_pipeline=())
+    compiler_module = compiler.tf_module_to_compiler_module(tf_module,
+                                                            exported_names,
+                                                            pass_pipeline=())
 
     if artifacts_dir is not None:
       tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
@@ -214,15 +215,25 @@
     self.backend = self._backend_info.name
     self.backend_driver = self._backend_info.driver
     self.module_name = self._module_class.__name__
-    self.compiled_path = None
+    self.compiled_paths = None
 
   def reinitialize(self):
     """Reinitializes to the initial state of the passed module_class."""
     raise NotImplementedError()
 
-  @staticmethod
-  def supports_cxx_serialization():
-    raise NotImplementedError()
+  def iree_serializable(self):
+    return False
+
+  def tflite_serializable(self):
+    return False
+
+
+def _get_non_inhereted_function_names(cls):
+  """Returns names in dir(cls) absent from dir() of all direct bases."""
+  names = set(dir(cls))
+  for parent in cls.__bases__:
+    names -= set(dir(parent))
+  return list(names)
 
 
 class _FunctionWrapper(object):
@@ -269,7 +280,7 @@
     super().__init__(module_class, backend_info, exported_names, artifacts_dir)
 
     set_random_seed()
-    self._module_blob, self.compiled_path = compile_tf_module(
+    self._module_blob, compiled_path = compile_tf_module(
         tf_module=module_class(),
         backend_infos=[backend_info],
         exported_names=exported_names,
@@ -277,14 +288,26 @@
     self._module = rt.VmModule.from_flatbuffer(self._module_blob)
     self._config = rt.Config(driver_name=backend_info.driver)
 
+    self.compiled_paths = None
+    if compiled_path is not None:
+      if not len(exported_names):
+        # Get all method names on 'module_class' that aren't on 'tf.Module'.
+        # This heuristic doesn't address all possible scenarios.
+        # TODO(meadowlark): Figure out how to get a list of all of the functions
+        # that this module has access to via `pyiree.rt.system_api.BoundModule`.
+        exported_names = _get_non_inhereted_function_names(module_class)
+      self.compiled_paths = {
+          method: compiled_path for method in exported_names
+      }
+
     self.reinitialize()
 
   def reinitialize(self):
     """Reinitializes to the initial state of the passed module_class."""
     # set_random_seed is not needed here because the model_class.__init__ is not
     # called.
-    self._context = rt.SystemContext(
-        modules=[self._module], config=self._config)
+    self._context = rt.SystemContext(modules=[self._module],
+                                     config=self._config)
 
   def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
     # Try to resolve it as a function.
@@ -292,9 +315,8 @@
     f = m[attr]
     return _IreeFunctionWrapper(self._context, f)
 
-  @staticmethod
-  def supports_cxx_serialization() -> bool:
-    return True
+  def iree_serializable(self) -> bool:
+    return self.compiled_paths is not None
 
 
 def _normalize_numpy(result: np.ndarray):
@@ -326,8 +348,9 @@
     # which is sad).
     if not isinstance(results, tuple):
       results = (results,)
-    return tf.nest.map_structure(
-        self._convert_to_numpy, *results, check_types=False)
+    return tf.nest.map_structure(self._convert_to_numpy,
+                                 *results,
+                                 check_types=False)
 
 
 class TfCompiledModule(CompiledModule):
@@ -372,25 +395,13 @@
           f"The TensorFlow module does not have a callable attr '{attr}'")
     return _TfFunctionWrapper(f)
 
-  @staticmethod
-  def supports_cxx_serialization() -> bool:
-    return False
 
-
-def get_non_inhereted_function_names(cls):
-  """Gets all methods that cls has that its parents don't have."""
-  names = set(dir(cls))
-  for parent in cls.__bases__:
-    names -= set(dir(parent))
-  return list(names)
-
-
-def get_concrete_functions(module_class: Type[tf.Module],
-                           exported_names: Sequence[str] = ()):
+def _get_concrete_functions(module_class: Type[tf.Module],
+                            exported_names: Sequence[str] = ()):
   """Get concrete functions from non-inherited methods or exported_names."""
   if not len(exported_names):
     # Get all method names on 'module_class' that aren't on 'tf.Module'.
-    exported_names = get_non_inhereted_function_names(module_class)
+    exported_names = _get_non_inhereted_function_names(module_class)
   instance = module_class()
   functions = []
   for name in exported_names:
@@ -402,8 +413,11 @@
                       exported_names: Sequence[str] = (),
                       artifacts_dir: str = None):
   """Compile a dict of TFLite interpreters for the methods on module_class."""
-  functions, names = get_concrete_functions(module_class, exported_names)
+  functions, names = _get_concrete_functions(module_class, exported_names)
   interpreters = dict()
+  compiled_paths = None
+  if artifacts_dir is not None:
+    compiled_paths = dict()
 
   def _interpret_bytes(tflite_module: bytes, base_dir: str):
     """Save compiled TFLite module bytes and convert into an interpreter."""
@@ -412,7 +426,10 @@
     tflite_path = os.path.join(tflite_dir, f"{name}.tflite")
     with open(tflite_path, "wb") as f:
       f.write(tflite_module)
+
     interpreters[name] = tf.lite.Interpreter(tflite_path)
+    if artifacts_dir is not None:
+      compiled_paths[name] = tflite_path
 
   for name, function in zip(names, functions):
     converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
@@ -424,7 +441,7 @@
     else:
       _interpret_bytes(tflite_module, artifacts_dir)
 
-  return interpreters
+  return interpreters, compiled_paths
 
 
 class _TfLiteFunctionWrapper(_FunctionWrapper):
@@ -465,8 +482,8 @@
                artifacts_dir: str = None):
     super().__init__(module_class, backend_info, exported_names, artifacts_dir)
     set_random_seed()
-    self._interpreters = compile_to_tflite(module_class, exported_names,
-                                           artifacts_dir)
+    self._interpreters, self.compiled_paths = compile_to_tflite(
+        module_class, exported_names, artifacts_dir)
 
   def reinitialize(self):
     """Reinitializes to the initial state of the passed module_class."""
@@ -480,9 +497,8 @@
           f"The TFLite module does not have an interpreter for '{attr}'")
     return _TfLiteFunctionWrapper(self._interpreters[attr])
 
-  @staticmethod
-  def supports_cxx_serialization() -> bool:
-    return False
+  def tflite_serializable(self) -> bool:
+    return self.compiled_paths is not None
 
 
 class BackendInfo:
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 530e835..9ccc7bf 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -40,6 +40,8 @@
     deps = [
         ":vm_util",
         "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/flags:parse",
+        "@com_google_absl//absl/flags:usage",
         "@com_google_absl//absl/strings",
         "@com_google_benchmark//:benchmark",
         "//iree/base:init",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 9a073c6..9d73449 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -59,6 +59,8 @@
   DEPS
     ::vm_util
     absl::flags
+    absl::flags_parse
+    absl::flags_usage
     absl::strings
     benchmark
     iree::base::init
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index 188a170..0c260be 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "absl/flags/flag.h"
+#include "absl/flags/internal/parse.h"
+#include "absl/flags/usage.h"
 #include "absl/strings/string_view.h"
 #include "benchmark/benchmark.h"
 #include "iree/base/file_io.h"
@@ -164,8 +166,12 @@
 
 }  // namespace
 
-void RegisterModuleBenchmarks() {
+Status RegisterModuleBenchmarks() {
   auto function_name = absl::GetFlag(FLAGS_entry_function);
+  if (function_name.empty()) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "Must specify an entry_function";
+  }
   auto benchmark_name = "BM_" + function_name;
   benchmark::RegisterBenchmark(benchmark_name.c_str(),
                                [function_name](benchmark::State& state) {
@@ -184,13 +190,47 @@
       // significant digits. If we end up wanting precision beyond microseconds,
       // we can make this setting configurable with a custom command line flag.
       ->Unit(benchmark::kMillisecond);
+  return OkStatus();
 }
 }  // namespace iree
 
 int main(int argc, char** argv) {
+  // We have to contend with two flag parsing libraries here: absl's and
+  // benchmark's. To make matters worse, both define the `--help` flag. To
+  // ensure that each is able to parse its own flags, we use an absl "internal"
+  // function (still with public visibility) to parse while ignoring undefined
+  // flags. If it sees `--help` it will exit here, so we include the benchmark
+  // library usage information in the manually-set help output. Then we let
+  // benchmark parse its flags. Finally we call the normal initialization
+  // function to do other IREE initialization including flag parsing with
+  // normal options. Any remaining flags will be unknown and result in an error.
+  absl::SetProgramUsageMessage(
+      "iree-benchmark-module \n"
+      "    --input_file=module.vmfb\n"
+      "    --entry_function=exported_function_to_benchmark\n"
+      "    [--inputs=2xi32=1 2,1x2xf32=2 1 | --inputs_file=file_with_inputs]\n"
+      "    [--driver=vmla]\n"
+      "\n\n"
+      "  Optional flags from third_party/benchmark/src/benchmark.cc:\n"
+      "    [--benchmark_list_tests={true|false}]\n"
+      "    [--benchmark_filter=<regex>]\n"
+      "    [--benchmark_min_time=<min_time>]\n"
+      "    [--benchmark_repetitions=<num_repetitions>]\n"
+      "    [--benchmark_report_aggregates_only={true|false}]\n"
+      "    [--benchmark_display_aggregates_only={true|false}]\n"
+      "    [--benchmark_format=<console|json|csv>]\n"
+      "    [--benchmark_out=<filename>]\n"
+      "    [--benchmark_out_format=<json|console|csv>]\n"
+      "    [--benchmark_color={auto|true|false}]\n"
+      "    [--benchmark_counters_tabular={true|false}]\n"
+      "    [--v=<verbosity>]\n");
+  absl::flags_internal::ParseCommandLineImpl(
+      argc, argv, absl::flags_internal::ArgvListAction::kRemoveParsedArgs,
+      absl::flags_internal::UsageFlagsAction::kHandleUsage,
+      absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined);
   ::benchmark::Initialize(&argc, argv);
   iree::InitializeEnvironment(&argc, &argv);
-  iree::RegisterModuleBenchmarks();
+  IREE_CHECK_OK(iree::RegisterModuleBenchmarks());
   ::benchmark::RunSpecifiedBenchmarks();
   return 0;
 }
diff --git a/iree/tools/test/BUILD b/iree/tools/test/BUILD
index 2446837..5e4c337 100644
--- a/iree/tools/test/BUILD
+++ b/iree/tools/test/BUILD
@@ -34,3 +34,13 @@
     ],
     tags = ["hostonly"],
 )
+
+iree_lit_test_suite(
+    name = "benchmark_flags",
+    srcs = ["benchmark_flags.txt"],
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-benchmark-module",
+    ],
+    tags = ["hostonly"],
+)
diff --git a/iree/tools/test/CMakeLists.txt b/iree/tools/test/CMakeLists.txt
index 658aa8e..26bad30 100644
--- a/iree/tools/test/CMakeLists.txt
+++ b/iree/tools/test/CMakeLists.txt
@@ -29,3 +29,15 @@
   LABELS
     "hostonly"
 )
+
+iree_lit_test_suite(
+  NAME
+    benchmark_flags
+  SRCS
+    "benchmark_flags.txt"
+  DATA
+    iree::tools::IreeFileCheck
+    iree::tools::iree-benchmark-module
+  LABELS
+    "hostonly"
+)
diff --git a/iree/tools/test/benchmark_flags.txt b/iree/tools/test/benchmark_flags.txt
new file mode 100644
index 0000000..70f279c
--- /dev/null
+++ b/iree/tools/test/benchmark_flags.txt
@@ -0,0 +1,12 @@
+// HELP: iree-benchmark-module
+// HELP: --input_file
+// HELP: --benchmark_list_tests
+// RUN: ( iree-benchmark-module --help || [[ $? == 1 ]] )  | IreeFileCheck --check-prefix=HELP %s
+// RUN: ( iree-benchmark-module --helpshort || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=HELP %s
+
+// UNKNOWN: unknown-flag
+// RUN: ( iree-benchmark-module --unknown-flag 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s
+// RUN: ( iree-benchmark-module --driver=vmla --unknown-flag --benchmark_list_tests 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s
+
+// LIST-BENCHMARKS: BM_some_function
+// RUN: iree-benchmark-module --benchmark_list_tests --entry_function=some_function | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s