Add support for benchmarking all exported functions. (#3403)

If the entry_function is set, the behavior is the same, ie, will
benchmark on the entry_function. Otherwise, will benchmark on all
exported functions which are expected to have no function inputs.

Fixes https://github.com/google/iree/issues/3388
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index 198b876..96dc0c9 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -35,7 +35,8 @@
 
 ABSL_FLAG(std::string, entry_function, "",
           "Name of a function contained in the module specified by module_file "
-          "to run.");
+          "to run. If this is not set, all the exported functions will be "
+          "benchmarked and they are expected to not have input arguments.");
 
 ABSL_FLAG(std::string, driver, "vmla", "Backend driver to use.");
 
@@ -58,76 +59,15 @@
 namespace iree {
 namespace {
 
-Status GetModuleContentsFromFlags(std::string& module_data) {
-  auto module_file = absl::GetFlag(FLAGS_module_file);
-  IREE_ASSIGN_OR_RETURN(module_data, file_io::GetFileContents(module_file));
-  return iree::OkStatus();
-}
-
-// Creates VM function instance with its inputs.
-Status PrepareIREEVMFunction(
-    iree_vm_instance_t** instance, iree_hal_device_t** device,
-    iree_vm_module_t** hal_module, iree_vm_context_t** context,
-    iree_vm_module_t** input_module, iree_vm_function_t& function,
-    iree::vm::ref<iree_vm_list_t>& inputs,
-    std::vector<iree::RawSignatureParser::Description>& output_descs,
-    const std::string& function_name, const std::string& module_data) {
-  IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
-  IREE_RETURN_IF_ERROR(
-      iree_vm_instance_create(iree_allocator_system(), instance));
-
-  // Create IREE's device and module.
-  IREE_RETURN_IF_ERROR(iree::CreateDevice(absl::GetFlag(FLAGS_driver), device));
-  IREE_RETURN_IF_ERROR(CreateHalModule(*device, hal_module));
-  IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, input_module));
-
-  // Order matters. The input module will likely be dependent on the hal
-  // module.
-  std::array<iree_vm_module_t*, 2> modules = {*hal_module, *input_module};
-  IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
-      *instance, modules.data(), modules.size(), iree_allocator_system(),
-      context));
-
-  IREE_RETURN_IF_ERROR(
-      (*input_module)
-          ->lookup_function(
-              (*input_module)->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
-              iree_string_view_t{function_name.data(), function_name.size()},
-              &function));
-  IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
-
-  // Construct inputs.
-  IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
-
-  if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
-    IREE_ASSIGN_OR_RETURN(inputs,
-                          ParseToVariantListFromFile(
-                              input_descs, iree_hal_device_allocator(*device),
-                              absl::GetFlag(FLAGS_function_inputs_file)));
-  } else {
-    IREE_ASSIGN_OR_RETURN(
-        inputs,
-        ParseToVariantList(input_descs, iree_hal_device_allocator(*device),
-                           absl::GetFlag(FLAGS_function_inputs)));
-  }
-
-  // Creates output singnature.
-  IREE_ASSIGN_OR_RETURN(output_descs, ParseOutputSignature(function));
-  return iree::OkStatus();
-}
-
 void RegisterModuleBenchmarks(
-    Status module_data_status, Status prepare_vm_func_status,
-    std::string& function_name, iree_vm_context_t* context,
+    const std::string& function_name, iree_vm_context_t* context,
     iree_vm_function_t function, iree_vm_list_t* inputs,
     const std::vector<RawSignatureParser::Description>& output_descs) {
   auto benchmark_name = "BM_" + function_name;
   benchmark::RegisterBenchmark(
       benchmark_name.c_str(),
-      [module_data_status, prepare_vm_func_status, context, function, inputs,
+      [context, function, inputs,
        output_descs](benchmark::State& state) -> void {
-        IREE_CHECK_OK(module_data_status);
-        IREE_CHECK_OK(prepare_vm_func_status);
         // Warmup run step.
         {
           vm::ref<iree_vm_list_t> outputs;
@@ -165,6 +105,136 @@
       // we can make this setting configurable with a custom command line flag.
       ->Unit(benchmark::kMillisecond);
 }
+
+Status GetModuleContentsFromFlags(std::string& module_data) {
+  auto module_file = absl::GetFlag(FLAGS_module_file);
+  IREE_ASSIGN_OR_RETURN(module_data, file_io::GetFileContents(module_file));
+  return iree::OkStatus();
+}
+
+// TODO(hanchung): Consider to refactor this out and reuse in iree-run-module.
+// This class helps organize required resources for IREE. The order of
+// construction and destruction for resources matters. And the lifetime of
+// resources also matters. The lifetime of IREEBenchmark should be as long as
+// ::benchmark::RunSpecifiedBenchmarks() where the resources are used during
+// benchmarking.
+class IREEBenchmark {
+ public:
+  IREEBenchmark()
+      : instance_(nullptr),
+        device_(nullptr),
+        hal_module_(nullptr),
+        context_(nullptr),
+        input_module_(nullptr){};
+  ~IREEBenchmark() {
+    // Order matters.
+    inputs_.reset();
+    iree_vm_module_release(hal_module_);
+    iree_vm_module_release(input_module_);
+    iree_hal_device_release(device_);
+    iree_vm_context_release(context_);
+    iree_vm_instance_release(instance_);
+  };
+
+  Status Register() {
+    if (!instance_ || !device_ || !hal_module_ || !context_ || !input_module_) {
+      IREE_RETURN_IF_ERROR(Init());
+    }
+
+    auto function_name = absl::GetFlag(FLAGS_entry_function);
+    if (!function_name.empty()) {
+      IREE_RETURN_IF_ERROR(RegisterSpecificFunction(function_name));
+    } else {
+      IREE_RETURN_IF_ERROR(RegisterAllExportedFunctions());
+    }
+    return iree::OkStatus();
+  }
+
+ private:
+  Status Init() {
+    IREE_RETURN_IF_ERROR(GetModuleContentsFromFlags(module_data_));
+
+    IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
+    IREE_RETURN_IF_ERROR(
+        iree_vm_instance_create(iree_allocator_system(), &instance_));
+
+    // Create IREE's device and module.
+    IREE_RETURN_IF_ERROR(
+        iree::CreateDevice(absl::GetFlag(FLAGS_driver), &device_));
+    IREE_RETURN_IF_ERROR(CreateHalModule(device_, &hal_module_));
+    IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data_, &input_module_));
+
+    // Order matters. The input module will likely be dependent on the hal
+    // module.
+    std::array<iree_vm_module_t*, 2> modules = {hal_module_, input_module_};
+    IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
+        instance_, modules.data(), modules.size(), iree_allocator_system(),
+        &context_));
+    return iree::OkStatus();
+  }
+
+  Status RegisterSpecificFunction(const std::string& function_name) {
+    iree_vm_function_t function;
+    IREE_RETURN_IF_ERROR(input_module_->lookup_function(
+        input_module_->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+        iree_string_view_t{function_name.data(), function_name.size()},
+        &function));
+    IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
+
+    // Construct inputs.
+    IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
+    if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
+      IREE_ASSIGN_OR_RETURN(inputs_,
+                            ParseToVariantListFromFile(
+                                input_descs, iree_hal_device_allocator(device_),
+                                absl::GetFlag(FLAGS_function_inputs_file)));
+    } else {
+      IREE_ASSIGN_OR_RETURN(
+          inputs_,
+          ParseToVariantList(input_descs, iree_hal_device_allocator(device_),
+                             absl::GetFlag(FLAGS_function_inputs)));
+    }
+
+    // Creates output singnature.
+    IREE_ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
+    RegisterModuleBenchmarks(function_name, context_, function, inputs_.get(),
+                             output_descs);
+    return iree::OkStatus();
+  }
+
+  Status RegisterAllExportedFunctions() {
+    iree_vm_function_t function;
+    iree_vm_module_signature_t signature =
+        input_module_->signature(input_module_->self);
+    for (int i = 0; i < signature.export_function_count; ++i) {
+      iree_string_view_t name;
+      IREE_CHECK_OK(input_module_->get_function(input_module_->self,
+                                                IREE_VM_FUNCTION_LINKAGE_EXPORT,
+                                                i, &function, &name, nullptr));
+      if (!ValidateFunctionAbi(function).ok()) continue;
+
+      std::string function_name(name.data, name.size);
+      IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
+      if (!input_descs.empty()) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << "Expect not to have input arguments for '" << function_name
+               << "'";
+      }
+      IREE_ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
+      iree::RegisterModuleBenchmarks(function_name, context_, function,
+                                     /*inputs=*/nullptr, output_descs);
+    }
+    return iree::OkStatus();
+  }
+
+  std::string module_data_;
+  iree_vm_instance_t* instance_;
+  iree_hal_device_t* device_;
+  iree_vm_module_t* hal_module_;
+  iree_vm_context_t* context_;
+  iree_vm_module_t* input_module_;
+  iree::vm::ref<iree_vm_list_t> inputs_;
+};
 }  // namespace
 }  // namespace iree
 
@@ -182,7 +252,9 @@
       "iree-benchmark-module \n"
       "    --module_file=module.vmfb\n"
       "    --entry_function=exported_function_to_benchmark\n"
-      "    [--function_inputs=2xi32=1 2,1x2xf32=2 1 | "
+      "      If this is not set, all the exported functions will be \n"
+      "      benchmarked and they are expected to not have input arguments\n"
+      "    [--function_inputs=2xi32=1 2,1x2xf32=2 1 | \n"
       "     --function_inputs_file=file_with_function_inputs]\n"
       "    [--driver=vmla]\n"
       "\n\n"
@@ -206,45 +278,12 @@
   ::benchmark::Initialize(&argc, argv);
   iree::InitializeEnvironment(&argc, &argv);
 
-  iree_vm_instance_t* instance = nullptr;
-  iree_hal_device_t* device = nullptr;
-  iree_vm_module_t* hal_module = nullptr;
-  iree_vm_context_t* context = nullptr;
-  iree_vm_module_t* input_module = nullptr;
-  iree_vm_function_t function;
-  iree::vm::ref<iree_vm_list_t> inputs;
-  std::vector<iree::RawSignatureParser::Description> output_descs;
-  std::string module_data;
-  auto function_name = absl::GetFlag(FLAGS_entry_function);
-
-  // Capture status checking will be delayed at benchmarking runtime.
-  auto module_data_status = iree::GetModuleContentsFromFlags(module_data);
-  auto prepare_vm_func_status = iree::PrepareIREEVMFunction(
-      &instance, &device, &hal_module, &context, &input_module, function,
-      inputs, output_descs, function_name, module_data);
-  if (prepare_vm_func_status.ok()) {
-    IREE_CHECK(instance);
-    IREE_CHECK(device);
-    IREE_CHECK(hal_module);
-    IREE_CHECK(context);
-    IREE_CHECK(input_module);
-    IREE_CHECK(inputs.get());
+  iree::IREEBenchmark iree_benchmark;
+  auto status = iree_benchmark.Register();
+  if (!status.ok()) {
+    std::cout << status << std::endl;
+    return static_cast<int>(status.code());
   }
-
-  // Register function benchmarks...
-  iree::RegisterModuleBenchmarks(module_data_status, prepare_vm_func_status,
-                                 function_name, context, function, inputs.get(),
-                                 output_descs);
-
-  // Run benchmarks...
   ::benchmark::RunSpecifiedBenchmarks();
-
-  // Cleanup...
-  inputs.reset();
-  iree_vm_module_release(hal_module);
-  iree_vm_module_release(input_module);
-  iree_hal_device_release(device);
-  iree_vm_context_release(context);
-  iree_vm_instance_release(instance);
   return 0;
 }
diff --git a/iree/tools/test/BUILD b/iree/tools/test/BUILD
index 5e4c337..0978d1c 100644
--- a/iree/tools/test/BUILD
+++ b/iree/tools/test/BUILD
@@ -41,6 +41,7 @@
     data = [
         "//iree/tools:IreeFileCheck",
         "//iree/tools:iree-benchmark-module",
+        "//iree/tools:iree-translate",
     ],
     tags = ["hostonly"],
 )
diff --git a/iree/tools/test/CMakeLists.txt b/iree/tools/test/CMakeLists.txt
index 26bad30..331fd76 100644
--- a/iree/tools/test/CMakeLists.txt
+++ b/iree/tools/test/CMakeLists.txt
@@ -38,6 +38,7 @@
   DATA
     iree::tools::IreeFileCheck
     iree::tools::iree-benchmark-module
+    iree::tools::iree-translate
   LABELS
     "hostonly"
 )
diff --git a/iree/tools/test/benchmark_flags.txt b/iree/tools/test/benchmark_flags.txt
index 9a13871..682a699 100644
--- a/iree/tools/test/benchmark_flags.txt
+++ b/iree/tools/test/benchmark_flags.txt
@@ -8,5 +8,18 @@
 // RUN: ( iree-benchmark-module --unknown-flag 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s
 // RUN: ( iree-benchmark-module --driver=vmla --unknown-flag --benchmark_list_tests 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s
 
-// LIST-BENCHMARKS: BM_some_function
-// RUN: iree-benchmark-module --benchmark_list_tests --entry_function=some_function | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s
+// LIST-BENCHMARKS: BM_foo1
+// LIST-BENCHMARKS: BM_foo2
+// RUN: ( iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s -o ${TEST_TMPDIR?}/bc.module && iree-benchmark-module --benchmark_list_tests --module_file=${TEST_TMPDIR?}/bc.module --driver=vmla --benchmark_list_tests ) | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s
+module {
+  func @foo1() -> tensor<4xf32> attributes { iree.module.export } {
+    %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+    %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+    return %result : tensor<4xf32>
+  }
+  func @foo2() -> tensor<4xf32> attributes { iree.module.export } {
+    %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+    %result = "mhlo.abs"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+    return %result : tensor<4xf32>
+  }
+}
diff --git a/iree/tools/test/multiple_exported_functions.mlir b/iree/tools/test/multiple_exported_functions.mlir
new file mode 100644
index 0000000..1179558
--- /dev/null
+++ b/iree/tools/test/multiple_exported_functions.mlir
@@ -0,0 +1,17 @@
+// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s -o ${TEST_TMPDIR?}/bc.module && iree-benchmark-module --driver=vmla --module_file=${TEST_TMPDIR?}/bc.module | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s -o ${TEST_TMPDIR?}/bc.module && iree-benchmark-module --driver=vulkan --module_file=${TEST_TMPDIR?}/bc.module | IreeFileCheck %s)
+
+module {
+  func @foo1() -> tensor<4xf32> attributes { iree.module.export } {
+    %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+    %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+    return %result : tensor<4xf32>
+  }
+  func @foo2() -> tensor<4xf32> attributes { iree.module.export } {
+    %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+    %result = "mhlo.abs"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+    return %result : tensor<4xf32>
+  }
+}
+// CHECK: BM_foo1
+// CHECK: BM_foo2