Merge pull request #3278 from hanhanW:main-to-google PiperOrigin-RevId: 334170175
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py index 4042e36..87a4c8a 100644 --- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py +++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -259,9 +259,10 @@ if _load_dict is None: # Extract metadata from module and function. self.module_name = module.module_name - self.compiled_path = module.compiled_path + self.compiled_paths = module.compiled_paths self.backend_name = module.backend - self.supports_cxx_serialization = module.supports_cxx_serialization() + self.iree_serializable = module.iree_serializable() + self.tflite_serializable = module.tflite_serializable() self.backend_driver = module.backend_driver self.function_name = function.__name__ self.function_sourcefile = inspect.getsourcefile(function) @@ -272,9 +273,10 @@ self.calls = [] else: self.module_name = _load_dict["module_name"] - self.compiled_path = _load_dict["compiled_path"] + self.compiled_paths = _load_dict["compiled_paths"] self.backend_name = _load_dict["backend_name"] - self.supports_cxx_serialization = _load_dict["supports_cxx_serialization"] + self.iree_serializable = _load_dict["iree_serializable"] + self.tflite_serializable = _load_dict["tflite_serializable"] self.backend_driver = _load_dict["backend_driver"] self.function_name = _load_dict["function_name"] self.function_sourcefile = _load_dict["function_sourcefile"] @@ -434,9 +436,10 @@ # Python serialization. metadata = { "module_name": self.module_name, - "compiled_path": self.compiled_path, + "compiled_paths": self.compiled_paths, "backend_name": self.backend_name, - "supports_cxx_serialization": self.supports_cxx_serialization, + "iree_serializable": self.iree_serializable, + "tflite_serializable": self.tflite_serializable, "backend_driver": self.backend_driver, "function_name": self.function_name, "function_sourcefile": self.function_sourcefile, @@ -451,18 +454,24 @@ call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}") call.serialize(call_dir) - # C++ Serialization. 
- if self.supports_cxx_serialization: - flaglines = [] - if self.compiled_path is not None: - flaglines.append(f"--input_file={self.compiled_path}") - flaglines.append(f"--driver={self.backend_driver}") - inputs_str = ", ".join(self.calls[0].serialized_inputs) - flaglines.append(f"--inputs={inputs_str}") - flaglines.append(f"--entry_function={self.calls[0].method}") + # C++ benchmark serialization (skipped when no calls were recorded). + if self.calls and (self.iree_serializable or self.tflite_serializable): + entry_function = self.calls[0].method + compiled_path = self.compiled_paths[entry_function] - with open(os.path.join(trace_dir, "flagfile"), "w") as f: - f.writelines(line + "\n" for line in flaglines) + if self.iree_serializable: + serialized_inputs = ", ".join(self.calls[0].serialized_inputs) + flagfile = [ + f"--input_file={compiled_path}", + f"--driver={self.backend_driver}", + f"--inputs={serialized_inputs}", + f"--entry_function={entry_function}" + ] + with open(os.path.join(trace_dir, "flagfile"), "w") as f: + f.writelines(line + "\n" for line in flagfile) + else: + with open(os.path.join(trace_dir, "graph_path"), "w") as f: + f.write(compiled_path + "\n") @staticmethod def load(trace_dir: str) -> "Trace":
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py index 3093028..6157c68 100644 --- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py +++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -191,7 +191,7 @@ trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace) trace.serialize(trace_function_dir) self.assertTrue( - os.path.exists(os.path.join(trace_function_dir, 'flagfile'))) + os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl'))) loaded_trace = tf_test_utils.Trace.load(trace_function_dir) # Check all calls match.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py index d8e1b8c..b017506 100644 --- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py +++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -157,8 +157,9 @@ try: # Convert the tf_module into raw TF input MLIR. - compiler_module = compiler.tf_module_to_compiler_module( - tf_module, exported_names, pass_pipeline=()) + compiler_module = compiler.tf_module_to_compiler_module(tf_module, + exported_names, + pass_pipeline=()) if artifacts_dir is not None: tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir") @@ -214,15 +215,25 @@ self.backend = self._backend_info.name self.backend_driver = self._backend_info.driver self.module_name = self._module_class.__name__ - self.compiled_path = None + self.compiled_paths = None def reinitialize(self): """Reinitializes to the initial state of the passed module_class.""" raise NotImplementedError() - @staticmethod - def supports_cxx_serialization(): - raise NotImplementedError() + def iree_serializable(self): + return False + + def tflite_serializable(self): + return False + + +def _get_non_inhereted_function_names(cls): + """Gets all methods that cls has that its parents don't have.""" + names = set(dir(cls)) + for parent in cls.__bases__: + names -= set(dir(parent)) + return list(names) class _FunctionWrapper(object): @@ -269,7 +280,7 @@ super().__init__(module_class, backend_info, exported_names, artifacts_dir) set_random_seed() - self._module_blob, self.compiled_path = compile_tf_module( + self._module_blob, compiled_path = compile_tf_module( tf_module=module_class(), backend_infos=[backend_info], exported_names=exported_names, @@ -277,14 +288,26 @@ self._module = rt.VmModule.from_flatbuffer(self._module_blob) self._config = rt.Config(driver_name=backend_info.driver) + self.compiled_paths = None + if compiled_path is not None: + if not len(exported_names): + # Get all method names on 'module_class' that aren't on 'tf.Module'. + # This doesn't address all possbile scenarios. + # TODO(meadowlark): Figure out how to get a list of all of the functions + # that this module has access to via `pyiree.rt.system_api.BoundModule`. 
+ exported_names = _get_non_inhereted_function_names(module_class) + self.compiled_paths = dict([ + (method, compiled_path) for method in exported_names + ]) + self.reinitialize() def reinitialize(self): """Reinitializes to the initial state of the passed module_class.""" # set_random_seed is not needed here because the model_class.__init__ is not # called. - self._context = rt.SystemContext( - modules=[self._module], config=self._config) + self._context = rt.SystemContext(modules=[self._module], + config=self._config) def __getattr__(self, attr: str) -> _IreeFunctionWrapper: # Try to resolve it as a function. @@ -292,9 +315,8 @@ f = m[attr] return _IreeFunctionWrapper(self._context, f) - @staticmethod - def supports_cxx_serialization() -> bool: - return True + def iree_serializable(self) -> bool: + return self.compiled_paths is not None def _normalize_numpy(result: np.ndarray): @@ -326,8 +348,9 @@ # which is sad). if not isinstance(results, tuple): results = (results,) - return tf.nest.map_structure( - self._convert_to_numpy, *results, check_types=False) + return tf.nest.map_structure(self._convert_to_numpy, + *results, + check_types=False) class TfCompiledModule(CompiledModule): @@ -372,25 +395,13 @@ f"The TensorFlow module does not have a callable attr '{attr}'") return _TfFunctionWrapper(f) - @staticmethod - def supports_cxx_serialization() -> bool: - return False - -def get_non_inhereted_function_names(cls): - """Gets all methods that cls has that its parents don't have.""" - names = set(dir(cls)) - for parent in cls.__bases__: - names -= set(dir(parent)) - return list(names) - - -def get_concrete_functions(module_class: Type[tf.Module], - exported_names: Sequence[str] = ()): +def _get_concrete_functions(module_class: Type[tf.Module], + exported_names: Sequence[str] = ()): """Get concrete functions from non-inherited methods or exported_names.""" if not len(exported_names): # Get all method names on 'module_class' that aren't on 'tf.Module'. 
- exported_names = get_non_inhereted_function_names(module_class) + exported_names = _get_non_inhereted_function_names(module_class) instance = module_class() functions = [] for name in exported_names: @@ -402,8 +413,11 @@ exported_names: Sequence[str] = (), artifacts_dir: str = None): """Compile a dict of TFLite interpreters for the methods on module_class.""" - functions, names = get_concrete_functions(module_class, exported_names) + functions, names = _get_concrete_functions(module_class, exported_names) interpreters = dict() + compiled_paths = None + if artifacts_dir is not None: + compiled_paths = dict() def _interpret_bytes(tflite_module: bytes, base_dir: str): """Save compiled TFLite module bytes and convert into an interpreter.""" @@ -412,7 +426,10 @@ tflite_path = os.path.join(tflite_dir, f"{name}.tflite") with open(tflite_path, "wb") as f: f.write(tflite_module) + interpreters[name] = tf.lite.Interpreter(tflite_path) + if artifacts_dir is not None: + compiled_paths[name] = tflite_path for name, function in zip(names, functions): converter = tf.lite.TFLiteConverter.from_concrete_functions([function]) @@ -424,7 +441,7 @@ else: _interpret_bytes(tflite_module, artifacts_dir) - return interpreters + return interpreters, compiled_paths class _TfLiteFunctionWrapper(_FunctionWrapper): @@ -465,8 +482,8 @@ artifacts_dir: str = None): super().__init__(module_class, backend_info, exported_names, artifacts_dir) set_random_seed() - self._interpreters = compile_to_tflite(module_class, exported_names, - artifacts_dir) + self._interpreters, self.compiled_paths = compile_to_tflite( + module_class, exported_names, artifacts_dir) def reinitialize(self): """Reinitializes to the initial state of the passed module_class.""" @@ -480,9 +497,8 @@ f"The TFLite module does not have an interpreter for '{attr}'") return _TfLiteFunctionWrapper(self._interpreters[attr]) - @staticmethod - def supports_cxx_serialization() -> bool: - return False + def tflite_serializable(self) -> bool: + 
return self.compiled_paths is not None class BackendInfo:
diff --git a/iree/tools/BUILD b/iree/tools/BUILD index 530e835..9ccc7bf 100644 --- a/iree/tools/BUILD +++ b/iree/tools/BUILD
@@ -40,6 +40,8 @@ deps = [ ":vm_util", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:usage", "@com_google_absl//absl/strings", "@com_google_benchmark//:benchmark", "//iree/base:init",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt index 9a073c6..9d73449 100644 --- a/iree/tools/CMakeLists.txt +++ b/iree/tools/CMakeLists.txt
@@ -59,6 +59,8 @@ DEPS ::vm_util absl::flags + absl::flags_parse + absl::flags_usage absl::strings benchmark iree::base::init
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc index 188a170..0c260be 100644 --- a/iree/tools/iree-benchmark-module-main.cc +++ b/iree/tools/iree-benchmark-module-main.cc
@@ -13,6 +13,8 @@ // limitations under the License. #include "absl/flags/flag.h" +#include "absl/flags/internal/parse.h" +#include "absl/flags/usage.h" #include "absl/strings/string_view.h" #include "benchmark/benchmark.h" #include "iree/base/file_io.h" @@ -164,8 +166,12 @@ } // namespace -void RegisterModuleBenchmarks() { +Status RegisterModuleBenchmarks() { auto function_name = absl::GetFlag(FLAGS_entry_function); + if (function_name.empty()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Must specify an entry_function"; + } auto benchmark_name = "BM_" + function_name; benchmark::RegisterBenchmark(benchmark_name.c_str(), [function_name](benchmark::State& state) { @@ -184,13 +190,47 @@ // significant digits. If we end up wanting precision beyond microseconds, // we can make this setting configurable with a custom command line flag. ->Unit(benchmark::kMillisecond); + return OkStatus(); } } // namespace iree int main(int argc, char** argv) { + // We have to contend with two flag parsing libraries here: absl's and + // benchmark's. To make matters worse, both define the `--help` flag. To + // ensure that each is able to parse its own flags, we use an absl "internal" + // function (still with public visibility) to parse while ignoring undefined + // flags. If it sees `--help` it will exit here, so we include the benchmark + // library usage information in the manually-set help output. Then we let + // benchmark parse its flags. Finally we call the normal initialization + // function to do other IREE initialization including flag parsing with + // normal options. Any remaining flags will be unknown and result in an error. 
+ absl::SetProgramUsageMessage( + "iree-benchmark-module \n" + " --input_file=module.vmfb\n" + " --entry_function=exported_function_to_benchmark\n" + " [--inputs=2xi32=1 2,1x2xf32=2 1 | --inputs_file=file_with_inputs]\n" + " [--driver=vmla]\n" + "\n\n" + " Optional flags from third_party/benchmark/src/benchmark.cc:\n" + " [--benchmark_list_tests={true|false}]\n" + " [--benchmark_filter=<regex>]\n" + " [--benchmark_min_time=<min_time>]\n" + " [--benchmark_repetitions=<num_repetitions>]\n" + " [--benchmark_report_aggregates_only={true|false}]\n" + " [--benchmark_display_aggregates_only={true|false}]\n" + " [--benchmark_format=<console|json|csv>]\n" + " [--benchmark_out=<filename>]\n" + " [--benchmark_out_format=<json|console|csv>]\n" + " [--benchmark_color={auto|true|false}]\n" + " [--benchmark_counters_tabular={true|false}]\n" + " [--v=<verbosity>]\n"); + absl::flags_internal::ParseCommandLineImpl( + argc, argv, absl::flags_internal::ArgvListAction::kRemoveParsedArgs, + absl::flags_internal::UsageFlagsAction::kHandleUsage, + absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined); ::benchmark::Initialize(&argc, argv); iree::InitializeEnvironment(&argc, &argv); - iree::RegisterModuleBenchmarks(); + IREE_CHECK_OK(iree::RegisterModuleBenchmarks()); ::benchmark::RunSpecifiedBenchmarks(); return 0; }
diff --git a/iree/tools/test/BUILD b/iree/tools/test/BUILD index 2446837..5e4c337 100644 --- a/iree/tools/test/BUILD +++ b/iree/tools/test/BUILD
@@ -34,3 +34,13 @@ ], tags = ["hostonly"], ) + +iree_lit_test_suite( + name = "benchmark_flags", + srcs = ["benchmark_flags.txt"], + data = [ + "//iree/tools:IreeFileCheck", + "//iree/tools:iree-benchmark-module", + ], + tags = ["hostonly"], +)
diff --git a/iree/tools/test/CMakeLists.txt b/iree/tools/test/CMakeLists.txt index 658aa8e..26bad30 100644 --- a/iree/tools/test/CMakeLists.txt +++ b/iree/tools/test/CMakeLists.txt
@@ -29,3 +29,15 @@ LABELS "hostonly" ) + +iree_lit_test_suite( + NAME + benchmark_flags + SRCS + "benchmark_flags.txt" + DATA + iree::tools::IreeFileCheck + iree::tools::iree-benchmark-module + LABELS + "hostonly" +)
diff --git a/iree/tools/test/benchmark_flags.txt b/iree/tools/test/benchmark_flags.txt new file mode 100644 index 0000000..70f279c --- /dev/null +++ b/iree/tools/test/benchmark_flags.txt
@@ -0,0 +1,12 @@ +// HELP: iree-benchmark-module +// HELP: --input_file +// HELP: --benchmark_list_tests +// RUN: ( iree-benchmark-module --help || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=HELP %s +// RUN: ( iree-benchmark-module --helpshort || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=HELP %s + +// UNKNOWN: unknown-flag +// RUN: ( iree-benchmark-module --unknown-flag 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s +// RUN: ( iree-benchmark-module --driver=vmla --unknown-flag --benchmark_list_tests 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s + +// LIST-BENCHMARKS: BM_some_function +// RUN: iree-benchmark-module --benchmark_list_tests --entry_function=some_function | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s