Add support for benchmarking all exported functions. (#3403)
If the entry_function is set, the behavior is the same, ie, will
benchmark on the entry_function. Otherwise, will benchmark on all
exported functions which are expected to have no function inputs.
Fixes https://github.com/google/iree/issues/3388
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index 198b876..96dc0c9 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -35,7 +35,8 @@
ABSL_FLAG(std::string, entry_function, "",
"Name of a function contained in the module specified by module_file "
- "to run.");
+ "to run. If this is not set, all the exported functions will be "
+ "benchmarked and they are expected to not have input arguments.");
ABSL_FLAG(std::string, driver, "vmla", "Backend driver to use.");
@@ -58,76 +59,15 @@
namespace iree {
namespace {
-Status GetModuleContentsFromFlags(std::string& module_data) {
- auto module_file = absl::GetFlag(FLAGS_module_file);
- IREE_ASSIGN_OR_RETURN(module_data, file_io::GetFileContents(module_file));
- return iree::OkStatus();
-}
-
-// Creates VM function instance with its inputs.
-Status PrepareIREEVMFunction(
- iree_vm_instance_t** instance, iree_hal_device_t** device,
- iree_vm_module_t** hal_module, iree_vm_context_t** context,
- iree_vm_module_t** input_module, iree_vm_function_t& function,
- iree::vm::ref<iree_vm_list_t>& inputs,
- std::vector<iree::RawSignatureParser::Description>& output_descs,
- const std::string& function_name, const std::string& module_data) {
- IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
- IREE_RETURN_IF_ERROR(
- iree_vm_instance_create(iree_allocator_system(), instance));
-
- // Create IREE's device and module.
- IREE_RETURN_IF_ERROR(iree::CreateDevice(absl::GetFlag(FLAGS_driver), device));
- IREE_RETURN_IF_ERROR(CreateHalModule(*device, hal_module));
- IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, input_module));
-
- // Order matters. The input module will likely be dependent on the hal
- // module.
- std::array<iree_vm_module_t*, 2> modules = {*hal_module, *input_module};
- IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
- *instance, modules.data(), modules.size(), iree_allocator_system(),
- context));
-
- IREE_RETURN_IF_ERROR(
- (*input_module)
- ->lookup_function(
- (*input_module)->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
- iree_string_view_t{function_name.data(), function_name.size()},
- &function));
- IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
-
- // Construct inputs.
- IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
-
- if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
- IREE_ASSIGN_OR_RETURN(inputs,
- ParseToVariantListFromFile(
- input_descs, iree_hal_device_allocator(*device),
- absl::GetFlag(FLAGS_function_inputs_file)));
- } else {
- IREE_ASSIGN_OR_RETURN(
- inputs,
- ParseToVariantList(input_descs, iree_hal_device_allocator(*device),
- absl::GetFlag(FLAGS_function_inputs)));
- }
-
- // Creates output singnature.
- IREE_ASSIGN_OR_RETURN(output_descs, ParseOutputSignature(function));
- return iree::OkStatus();
-}
-
void RegisterModuleBenchmarks(
- Status module_data_status, Status prepare_vm_func_status,
- std::string& function_name, iree_vm_context_t* context,
+ const std::string& function_name, iree_vm_context_t* context,
iree_vm_function_t function, iree_vm_list_t* inputs,
const std::vector<RawSignatureParser::Description>& output_descs) {
auto benchmark_name = "BM_" + function_name;
benchmark::RegisterBenchmark(
benchmark_name.c_str(),
- [module_data_status, prepare_vm_func_status, context, function, inputs,
+ [context, function, inputs,
output_descs](benchmark::State& state) -> void {
- IREE_CHECK_OK(module_data_status);
- IREE_CHECK_OK(prepare_vm_func_status);
// Warmup run step.
{
vm::ref<iree_vm_list_t> outputs;
@@ -165,6 +105,136 @@
// we can make this setting configurable with a custom command line flag.
->Unit(benchmark::kMillisecond);
}
+
+Status GetModuleContentsFromFlags(std::string& module_data) {
+ auto module_file = absl::GetFlag(FLAGS_module_file);
+ IREE_ASSIGN_OR_RETURN(module_data, file_io::GetFileContents(module_file));
+ return iree::OkStatus();
+}
+
+// TODO(hanchung): Consider to refactor this out and reuse in iree-run-module.
+// This class helps organize required resources for IREE. The order of
+// construction and destruction for resources matters. And the lifetime of
+// resources also matters. The lifetime of IREEBenchmark should be as long as
+// ::benchmark::RunSpecifiedBenchmarks() where the resources are used during
+// benchmarking.
+class IREEBenchmark {
+ public:
+ IREEBenchmark()
+ : instance_(nullptr),
+ device_(nullptr),
+ hal_module_(nullptr),
+ context_(nullptr),
+ input_module_(nullptr){};
+ ~IREEBenchmark() {
+ // Order matters.
+ inputs_.reset();
+ iree_vm_module_release(hal_module_);
+ iree_vm_module_release(input_module_);
+ iree_hal_device_release(device_);
+ iree_vm_context_release(context_);
+ iree_vm_instance_release(instance_);
+ };
+
+ Status Register() {
+ if (!instance_ || !device_ || !hal_module_ || !context_ || !input_module_) {
+ IREE_RETURN_IF_ERROR(Init());
+ }
+
+ auto function_name = absl::GetFlag(FLAGS_entry_function);
+ if (!function_name.empty()) {
+ IREE_RETURN_IF_ERROR(RegisterSpecificFunction(function_name));
+ } else {
+ IREE_RETURN_IF_ERROR(RegisterAllExportedFunctions());
+ }
+ return iree::OkStatus();
+ }
+
+ private:
+ Status Init() {
+ IREE_RETURN_IF_ERROR(GetModuleContentsFromFlags(module_data_));
+
+ IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
+ IREE_RETURN_IF_ERROR(
+ iree_vm_instance_create(iree_allocator_system(), &instance_));
+
+ // Create IREE's device and module.
+ IREE_RETURN_IF_ERROR(
+ iree::CreateDevice(absl::GetFlag(FLAGS_driver), &device_));
+ IREE_RETURN_IF_ERROR(CreateHalModule(device_, &hal_module_));
+ IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data_, &input_module_));
+
+ // Order matters. The input module will likely be dependent on the hal
+ // module.
+ std::array<iree_vm_module_t*, 2> modules = {hal_module_, input_module_};
+ IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
+ instance_, modules.data(), modules.size(), iree_allocator_system(),
+ &context_));
+ return iree::OkStatus();
+ }
+
+ Status RegisterSpecificFunction(const std::string& function_name) {
+ iree_vm_function_t function;
+ IREE_RETURN_IF_ERROR(input_module_->lookup_function(
+ input_module_->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ iree_string_view_t{function_name.data(), function_name.size()},
+ &function));
+ IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
+
+ // Construct inputs.
+ IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
+ if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
+ IREE_ASSIGN_OR_RETURN(inputs_,
+ ParseToVariantListFromFile(
+ input_descs, iree_hal_device_allocator(device_),
+ absl::GetFlag(FLAGS_function_inputs_file)));
+ } else {
+ IREE_ASSIGN_OR_RETURN(
+ inputs_,
+ ParseToVariantList(input_descs, iree_hal_device_allocator(device_),
+ absl::GetFlag(FLAGS_function_inputs)));
+ }
+
+ // Creates output singnature.
+ IREE_ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
+ RegisterModuleBenchmarks(function_name, context_, function, inputs_.get(),
+ output_descs);
+ return iree::OkStatus();
+ }
+
+ Status RegisterAllExportedFunctions() {
+ iree_vm_function_t function;
+ iree_vm_module_signature_t signature =
+ input_module_->signature(input_module_->self);
+ for (int i = 0; i < signature.export_function_count; ++i) {
+ iree_string_view_t name;
+ IREE_CHECK_OK(input_module_->get_function(input_module_->self,
+ IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ i, &function, &name, nullptr));
+ if (!ValidateFunctionAbi(function).ok()) continue;
+
+ std::string function_name(name.data, name.size);
+ IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
+ if (!input_descs.empty()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Expect not to have input arguments for '" << function_name
+ << "'";
+ }
+ IREE_ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
+ iree::RegisterModuleBenchmarks(function_name, context_, function,
+ /*inputs=*/nullptr, output_descs);
+ }
+ return iree::OkStatus();
+ }
+
+ std::string module_data_;
+ iree_vm_instance_t* instance_;
+ iree_hal_device_t* device_;
+ iree_vm_module_t* hal_module_;
+ iree_vm_context_t* context_;
+ iree_vm_module_t* input_module_;
+ iree::vm::ref<iree_vm_list_t> inputs_;
+};
} // namespace
} // namespace iree
@@ -182,7 +252,9 @@
"iree-benchmark-module \n"
" --module_file=module.vmfb\n"
" --entry_function=exported_function_to_benchmark\n"
- " [--function_inputs=2xi32=1 2,1x2xf32=2 1 | "
+ " If this is not set, all the exported functions will be \n"
+ " benchmarked and they are expected to not have input arguments\n"
+ " [--function_inputs=2xi32=1 2,1x2xf32=2 1 | \n"
" --function_inputs_file=file_with_function_inputs]\n"
" [--driver=vmla]\n"
"\n\n"
@@ -206,45 +278,12 @@
::benchmark::Initialize(&argc, argv);
iree::InitializeEnvironment(&argc, &argv);
- iree_vm_instance_t* instance = nullptr;
- iree_hal_device_t* device = nullptr;
- iree_vm_module_t* hal_module = nullptr;
- iree_vm_context_t* context = nullptr;
- iree_vm_module_t* input_module = nullptr;
- iree_vm_function_t function;
- iree::vm::ref<iree_vm_list_t> inputs;
- std::vector<iree::RawSignatureParser::Description> output_descs;
- std::string module_data;
- auto function_name = absl::GetFlag(FLAGS_entry_function);
-
- // Capture status checking will be delayed at benchmarking runtime.
- auto module_data_status = iree::GetModuleContentsFromFlags(module_data);
- auto prepare_vm_func_status = iree::PrepareIREEVMFunction(
- &instance, &device, &hal_module, &context, &input_module, function,
- inputs, output_descs, function_name, module_data);
- if (prepare_vm_func_status.ok()) {
- IREE_CHECK(instance);
- IREE_CHECK(device);
- IREE_CHECK(hal_module);
- IREE_CHECK(context);
- IREE_CHECK(input_module);
- IREE_CHECK(inputs.get());
+ iree::IREEBenchmark iree_benchmark;
+ auto status = iree_benchmark.Register();
+ if (!status.ok()) {
+ std::cout << status << std::endl;
+ return static_cast<int>(status.code());
}
-
- // Register function benchmarks...
- iree::RegisterModuleBenchmarks(module_data_status, prepare_vm_func_status,
- function_name, context, function, inputs.get(),
- output_descs);
-
- // Run benchmarks...
::benchmark::RunSpecifiedBenchmarks();
-
- // Cleanup...
- inputs.reset();
- iree_vm_module_release(hal_module);
- iree_vm_module_release(input_module);
- iree_hal_device_release(device);
- iree_vm_context_release(context);
- iree_vm_instance_release(instance);
return 0;
}
diff --git a/iree/tools/test/BUILD b/iree/tools/test/BUILD
index 5e4c337..0978d1c 100644
--- a/iree/tools/test/BUILD
+++ b/iree/tools/test/BUILD
@@ -41,6 +41,7 @@
data = [
"//iree/tools:IreeFileCheck",
"//iree/tools:iree-benchmark-module",
+ "//iree/tools:iree-translate",
],
tags = ["hostonly"],
)
diff --git a/iree/tools/test/CMakeLists.txt b/iree/tools/test/CMakeLists.txt
index 26bad30..331fd76 100644
--- a/iree/tools/test/CMakeLists.txt
+++ b/iree/tools/test/CMakeLists.txt
@@ -38,6 +38,7 @@
DATA
iree::tools::IreeFileCheck
iree::tools::iree-benchmark-module
+ iree::tools::iree-translate
LABELS
"hostonly"
)
diff --git a/iree/tools/test/benchmark_flags.txt b/iree/tools/test/benchmark_flags.txt
index 9a13871..682a699 100644
--- a/iree/tools/test/benchmark_flags.txt
+++ b/iree/tools/test/benchmark_flags.txt
@@ -8,5 +8,18 @@
// RUN: ( iree-benchmark-module --unknown-flag 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s
// RUN: ( iree-benchmark-module --driver=vmla --unknown-flag --benchmark_list_tests 2>&1 || [[ $? == 1 ]] ) | IreeFileCheck --check-prefix=UNKNOWN %s
-// LIST-BENCHMARKS: BM_some_function
-// RUN: iree-benchmark-module --benchmark_list_tests --entry_function=some_function | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s
+// LIST-BENCHMARKS: BM_foo1
+// LIST-BENCHMARKS: BM_foo2
+// RUN: ( iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s -o ${TEST_TMPDIR?}/bc.module && iree-benchmark-module --benchmark_list_tests --module_file=${TEST_TMPDIR?}/bc.module --driver=vmla --benchmark_list_tests ) | IreeFileCheck --check-prefix=LIST-BENCHMARKS %s
+module {
+ func @foo1() -> tensor<4xf32> attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+ %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+ }
+ func @foo2() -> tensor<4xf32> attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+ %result = "mhlo.abs"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+ }
+}
diff --git a/iree/tools/test/multiple_exported_functions.mlir b/iree/tools/test/multiple_exported_functions.mlir
new file mode 100644
index 0000000..1179558
--- /dev/null
+++ b/iree/tools/test/multiple_exported_functions.mlir
@@ -0,0 +1,17 @@
+// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s -o ${TEST_TMPDIR?}/bc.module && iree-benchmark-module --driver=vmla --module_file=${TEST_TMPDIR?}/bc.module | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s -o ${TEST_TMPDIR?}/bc.module && iree-benchmark-module --driver=vulkan --module_file=${TEST_TMPDIR?}/bc.module | IreeFileCheck %s)
+
+module {
+ func @foo1() -> tensor<4xf32> attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+ %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+ }
+ func @foo2() -> tensor<4xf32> attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+ %result = "mhlo.abs"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+ }
+}
+// CHECK: BM_foo1
+// CHECK: BM_foo2