Benchmarking only iree_vm_invoke, moving everything else outside (#3324)

diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index d5a4293..da15e5b 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -58,61 +58,47 @@
 namespace iree {
 namespace {
 
-StatusOr<std::string> GetModuleContentsFromFlags() {
-  IREE_TRACE_SCOPE0("GetModuleContentsFromFlags");
+Status GetModuleContentsFromFlags(std::string& module_data) {
   auto module_file = absl::GetFlag(FLAGS_module_file);
-  if (module_file.empty()) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "module_file must be specified";
-  }
-  return file_io::GetFileContents(module_file);
+  IREE_ASSIGN_OR_RETURN(module_data, file_io::GetFileContents(module_file));
+  return iree::OkStatus();
 }
 
-Status RunFunction(::benchmark::State& state,
-                   const std::string& function_name) {
-  IREE_TRACE_SCOPE0("iree-benchmark-module");
-
-  IREE_RETURN_IF_ERROR(iree_hal_module_register_types())
-      << "registering HAL types";
-  iree_vm_instance_t* instance = nullptr;
+// Creates VM function instance with its inputs.
+Status PrepareIREEVMFunction(
+    iree_vm_instance_t* instance, iree_hal_device_t* device,
+    iree_vm_module_t* hal_module, iree_vm_context_t** context,
+    iree_vm_module_t* input_module, iree_vm_function_t& function,
+    iree::vm::ref<iree_vm_list_t>& inputs,
+    std::vector<iree::RawSignatureParser::Description>& output_descs,
+    const std::string& function_name, const std::string& module_data) {
+  IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
   IREE_RETURN_IF_ERROR(
-      iree_vm_instance_create(iree_allocator_system(), &instance))
-      << "creating instance";
+      iree_vm_instance_create(iree_allocator_system(), &instance));
 
-  IREE_ASSIGN_OR_RETURN(auto module_data, GetModuleContentsFromFlags());
-  iree_vm_module_t* input_module = nullptr;
+  // Create IREE's device and module.
+  IREE_RETURN_IF_ERROR(
+      iree::CreateDevice(absl::GetFlag(FLAGS_driver), &device));
+  IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
   IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module));
 
-  iree_hal_device_t* device = nullptr;
-  IREE_RETURN_IF_ERROR(CreateDevice(absl::GetFlag(FLAGS_driver), &device));
-  iree_vm_module_t* hal_module = nullptr;
-  IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
-
-  iree_vm_context_t* context = nullptr;
-  // Order matters. The input module will likely be dependent on the hal module.
+  // Order matters. The input module will likely be dependent on the hal
+  // module.
   std::array<iree_vm_module_t*, 2> modules = {hal_module, input_module};
   IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
       instance, modules.data(), modules.size(), iree_allocator_system(),
-      &context))
-      << "creating context";
+      context));
 
-  iree_vm_function_t function;
   IREE_RETURN_IF_ERROR(input_module->lookup_function(
       input_module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
       iree_string_view_t{function_name.data(), function_name.size()},
-      &function))
-      << "looking up function '" << function_name << "'";
-
+      &function));
   IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
+
+  // Construct inputs.
   IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
 
-  vm::ref<iree_vm_list_t> inputs;
   if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
-    if (!absl::GetFlag(FLAGS_function_inputs).empty()) {
-      return InvalidArgumentErrorBuilder(IREE_LOC)
-             << "Expected only one of function_inputs and function_inputs_file "
-                "to be set";
-    }
     IREE_ASSIGN_OR_RETURN(inputs,
                           ParseToVariantListFromFile(
                               input_descs, iree_hal_device_allocator(device),
@@ -124,60 +110,46 @@
                            absl::GetFlag(FLAGS_function_inputs)));
   }
 
-  IREE_ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
-
-  // Execute once to make sure any first-iteration outliers are eliminated (e.g.
-  // JITing the SPIR-V) and clearly separate out benchmark-related problems in
-  // future debugging.
-  {
-    vm::ref<iree_vm_list_t> outputs;
-    IREE_RETURN_IF_ERROR(
-        iree_vm_list_create(/*element_type=*/nullptr, output_descs.size(),
-                            iree_allocator_system(), &outputs));
-    IREE_RETURN_IF_ERROR(iree_vm_invoke(context, function, /*policy=*/nullptr,
-                                        inputs.get(), outputs.get(),
-                                        iree_allocator_system()));
-  }
-
-  for (auto _ : state) {
-    // No status conversions and conditional returns in the benchmarked inner
-    // loop.
-    vm::ref<iree_vm_list_t> outputs;
-    IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr,
-                                      output_descs.size(),
-                                      iree_allocator_system(), &outputs));
-    IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr,
-                                 inputs.get(), outputs.get(),
-                                 iree_allocator_system()));
-  }
-
-  inputs.reset();
-  iree_vm_module_release(hal_module);
-  iree_vm_module_release(input_module);
-  iree_hal_device_release(device);
-  iree_vm_context_release(context);
-  iree_vm_instance_release(instance);
-  return OkStatus();
+  // Creates output signature.
+  IREE_ASSIGN_OR_RETURN(output_descs, ParseOutputSignature(function));
+  return iree::OkStatus();
 }
 
-void BM_RunModule(benchmark::State& state, const std::string& function_name) {
-  // Delegate to a status-returning function so we can use the status macros.
-  IREE_CHECK_OK(RunFunction(state, function_name));
-}
-
-}  // namespace
-
-Status RegisterModuleBenchmarks() {
-  auto function_name = absl::GetFlag(FLAGS_entry_function);
-  if (function_name.empty()) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Must specify an entry_function";
-  }
+void RegisterModuleBenchmarks(
+    Status module_data_status, Status prepare_vm_func_status,
+    std::string& function_name, iree_vm_context_t* context,
+    iree_vm_function_t function, iree_vm_list_t* inputs,
+    const std::vector<RawSignatureParser::Description>& output_descs) {
   auto benchmark_name = "BM_" + function_name;
-  benchmark::RegisterBenchmark(benchmark_name.c_str(),
-                               [function_name](benchmark::State& state) {
-                                 BM_RunModule(state, function_name);
-                               })
+  benchmark::RegisterBenchmark(
+      benchmark_name.c_str(),
+      [module_data_status, prepare_vm_func_status, context, function, inputs,
+       output_descs](benchmark::State& state) -> void {
+        IREE_CHECK_OK(module_data_status);
+        IREE_CHECK_OK(prepare_vm_func_status);
+        // Warmup run step.
+        {
+          vm::ref<iree_vm_list_t> outputs;
+          IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr,
+                                            output_descs.size(),
+                                            iree_allocator_system(), &outputs));
+          IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr,
+                                       inputs, outputs.get(),
+                                       iree_allocator_system()));
+        }
+        // Benchmarking loop.
+        for (auto _ : state) {
+          // No status conversions and conditional returns in the benchmarked
+          // inner loop.
+          vm::ref<iree_vm_list_t> outputs;
+          IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr,
+                                            output_descs.size(),
+                                            iree_allocator_system(), &outputs));
+          IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr,
+                                       inputs, outputs.get(),
+                                       iree_allocator_system()));
+        }
+      })
       // By default only the main thread is included in CPU time. Include all
       // the threads instead.
       ->MeasureProcessCPUTime()
@@ -191,8 +163,8 @@
       // significant digits. If we end up wanting precision beyond microseconds,
       // we can make this setting configurable with a custom command line flag.
       ->Unit(benchmark::kMillisecond);
-  return OkStatus();
 }
+}  // namespace
 }  // namespace iree
 
 int main(int argc, char** argv) {
@@ -232,7 +204,37 @@
       absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined);
   ::benchmark::Initialize(&argc, argv);
   iree::InitializeEnvironment(&argc, &argv);
-  IREE_CHECK_OK(iree::RegisterModuleBenchmarks());
+
+  iree_vm_instance_t* instance = nullptr;
+  iree_hal_device_t* device = nullptr;
+  iree_vm_module_t* hal_module = nullptr;
+  iree_vm_context_t* context = nullptr;
+  iree_vm_module_t* input_module = nullptr;
+  iree_vm_function_t function;
+  iree::vm::ref<iree_vm_list_t> inputs;
+  std::vector<iree::RawSignatureParser::Description> output_descs;
+  std::string module_data;
+  auto function_name = absl::GetFlag(FLAGS_entry_function);
+
+  // Capture statuses now; checking is deferred until the benchmark runs.
+  auto module_data_status = iree::GetModuleContentsFromFlags(module_data);
+  auto prepare_vm_func_status = iree::PrepareIREEVMFunction(
+      instance, device, hal_module, &context, input_module, function, inputs,
+      output_descs, function_name, module_data);
+
+  // Register function benchmarks...
+  iree::RegisterModuleBenchmarks(module_data_status, prepare_vm_func_status,
+                                 function_name, context, function, inputs.get(),
+                                 output_descs);
+
+  // Run benchmarks...
   ::benchmark::RunSpecifiedBenchmarks();
+
+  // Cleanup...
+  iree_vm_module_release(hal_module);
+  iree_vm_module_release(input_module);
+  iree_hal_device_release(device);
+  iree_vm_context_release(context);
+  iree_vm_instance_release(instance);
   return 0;
 }