Benchmarking only iree_vm_invoke moving everything else outside (#3324)
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index d5a4293..da15e5b 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -58,61 +58,47 @@
namespace iree {
namespace {
-StatusOr<std::string> GetModuleContentsFromFlags() {
- IREE_TRACE_SCOPE0("GetModuleContentsFromFlags");
+Status GetModuleContentsFromFlags(std::string& module_data) {
auto module_file = absl::GetFlag(FLAGS_module_file);
- if (module_file.empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "module_file must be specified";
- }
- return file_io::GetFileContents(module_file);
+ IREE_ASSIGN_OR_RETURN(module_data, file_io::GetFileContents(module_file));
+ return iree::OkStatus();
}
-Status RunFunction(::benchmark::State& state,
- const std::string& function_name) {
- IREE_TRACE_SCOPE0("iree-benchmark-module");
-
- IREE_RETURN_IF_ERROR(iree_hal_module_register_types())
- << "registering HAL types";
- iree_vm_instance_t* instance = nullptr;
+// Creates VM function instance with its inputs.
+Status PrepareIREEVMFunction(
+ iree_vm_instance_t* instance, iree_hal_device_t* device,
+ iree_vm_module_t* hal_module, iree_vm_context_t** context,
+ iree_vm_module_t* input_module, iree_vm_function_t& function,
+ iree::vm::ref<iree_vm_list_t>& inputs,
+ std::vector<iree::RawSignatureParser::Description>& output_descs,
+ const std::string& function_name, const std::string& module_data) {
+ IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
IREE_RETURN_IF_ERROR(
- iree_vm_instance_create(iree_allocator_system(), &instance))
- << "creating instance";
+ iree_vm_instance_create(iree_allocator_system(), &instance));
- IREE_ASSIGN_OR_RETURN(auto module_data, GetModuleContentsFromFlags());
- iree_vm_module_t* input_module = nullptr;
+ // Create IREE's device and module.
+ IREE_RETURN_IF_ERROR(
+ iree::CreateDevice(absl::GetFlag(FLAGS_driver), &device));
+ IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module));
- iree_hal_device_t* device = nullptr;
- IREE_RETURN_IF_ERROR(CreateDevice(absl::GetFlag(FLAGS_driver), &device));
- iree_vm_module_t* hal_module = nullptr;
- IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
-
- iree_vm_context_t* context = nullptr;
- // Order matters. The input module will likely be dependent on the hal module.
+ // Order matters. The input module will likely be dependent on the hal
+ // module.
std::array<iree_vm_module_t*, 2> modules = {hal_module, input_module};
IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
instance, modules.data(), modules.size(), iree_allocator_system(),
- &context))
- << "creating context";
+ context));
- iree_vm_function_t function;
IREE_RETURN_IF_ERROR(input_module->lookup_function(
input_module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{function_name.data(), function_name.size()},
- &function))
- << "looking up function '" << function_name << "'";
-
+ &function));
IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
+
+ // Construct inputs.
IREE_ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
- vm::ref<iree_vm_list_t> inputs;
if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
- if (!absl::GetFlag(FLAGS_function_inputs).empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Expected only one of function_inputs and function_inputs_file "
- "to be set";
- }
IREE_ASSIGN_OR_RETURN(inputs,
ParseToVariantListFromFile(
input_descs, iree_hal_device_allocator(device),
@@ -124,60 +110,46 @@
absl::GetFlag(FLAGS_function_inputs)));
}
- IREE_ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
-
- // Execute once to make sure any first-iteration outliers are eliminated (e.g.
- // JITing the SPIR-V) and clearly separate out benchmark-related problems in
- // future debugging.
- {
- vm::ref<iree_vm_list_t> outputs;
- IREE_RETURN_IF_ERROR(
- iree_vm_list_create(/*element_type=*/nullptr, output_descs.size(),
- iree_allocator_system(), &outputs));
- IREE_RETURN_IF_ERROR(iree_vm_invoke(context, function, /*policy=*/nullptr,
- inputs.get(), outputs.get(),
- iree_allocator_system()));
- }
-
- for (auto _ : state) {
- // No status conversions and conditional returns in the benchmarked inner
- // loop.
- vm::ref<iree_vm_list_t> outputs;
- IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr,
- output_descs.size(),
- iree_allocator_system(), &outputs));
- IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr,
- inputs.get(), outputs.get(),
- iree_allocator_system()));
- }
-
- inputs.reset();
- iree_vm_module_release(hal_module);
- iree_vm_module_release(input_module);
- iree_hal_device_release(device);
- iree_vm_context_release(context);
- iree_vm_instance_release(instance);
- return OkStatus();
+ // Creates output singnature.
+ IREE_ASSIGN_OR_RETURN(output_descs, ParseOutputSignature(function));
+ return iree::OkStatus();
}
-void BM_RunModule(benchmark::State& state, const std::string& function_name) {
- // Delegate to a status-returning function so we can use the status macros.
- IREE_CHECK_OK(RunFunction(state, function_name));
-}
-
-} // namespace
-
-Status RegisterModuleBenchmarks() {
- auto function_name = absl::GetFlag(FLAGS_entry_function);
- if (function_name.empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Must specify an entry_function";
- }
+void RegisterModuleBenchmarks(
+ Status module_data_status, Status prepare_vm_func_status,
+ std::string& function_name, iree_vm_context_t* context,
+ iree_vm_function_t function, iree_vm_list_t* inputs,
+ const std::vector<RawSignatureParser::Description>& output_descs) {
auto benchmark_name = "BM_" + function_name;
- benchmark::RegisterBenchmark(benchmark_name.c_str(),
- [function_name](benchmark::State& state) {
- BM_RunModule(state, function_name);
- })
+ benchmark::RegisterBenchmark(
+ benchmark_name.c_str(),
+ [module_data_status, prepare_vm_func_status, context, function, inputs,
+ output_descs](benchmark::State& state) -> void {
+ IREE_CHECK_OK(module_data_status);
+ IREE_CHECK_OK(prepare_vm_func_status);
+ // Warmup run step.
+ {
+ vm::ref<iree_vm_list_t> outputs;
+ IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr,
+ output_descs.size(),
+ iree_allocator_system(), &outputs));
+ IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr,
+ inputs, outputs.get(),
+ iree_allocator_system()));
+ }
+ // Benchmarking loop.
+ for (auto _ : state) {
+ // No status conversions and conditional returns in the benchmarked
+ // inner loop.
+ vm::ref<iree_vm_list_t> outputs;
+ IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr,
+ output_descs.size(),
+ iree_allocator_system(), &outputs));
+ IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr,
+ inputs, outputs.get(),
+ iree_allocator_system()));
+ }
+ })
// By default only the main thread is included in CPU time. Include all
// the threads instead.
->MeasureProcessCPUTime()
@@ -191,8 +163,8 @@
// significant digits. If we end up wanting precision beyond microseconds,
// we can make this setting configurable with a custom command line flag.
->Unit(benchmark::kMillisecond);
- return OkStatus();
}
+} // namespace
} // namespace iree
int main(int argc, char** argv) {
@@ -232,7 +204,37 @@
absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined);
::benchmark::Initialize(&argc, argv);
iree::InitializeEnvironment(&argc, &argv);
- IREE_CHECK_OK(iree::RegisterModuleBenchmarks());
+
+ iree_vm_instance_t* instance = nullptr;
+ iree_hal_device_t* device = nullptr;
+ iree_vm_module_t* hal_module = nullptr;
+ iree_vm_context_t* context = nullptr;
+ iree_vm_module_t* input_module = nullptr;
+ iree_vm_function_t function;
+ iree::vm::ref<iree_vm_list_t> inputs;
+ std::vector<iree::RawSignatureParser::Description> output_descs;
+ std::string module_data;
+ auto function_name = absl::GetFlag(FLAGS_entry_function);
+
+ // Capture status checking will be delayed at benchmarking runtime.
+ auto module_data_status = iree::GetModuleContentsFromFlags(module_data);
+ auto prepare_vm_func_status = iree::PrepareIREEVMFunction(
+ instance, device, hal_module, &context, input_module, function, inputs,
+ output_descs, function_name, module_data);
+
+ // Register function benchmarks...
+ iree::RegisterModuleBenchmarks(module_data_status, prepare_vm_func_status,
+ function_name, context, function, inputs.get(),
+ output_descs);
+
+ // Run benchmarks...
::benchmark::RunSpecifiedBenchmarks();
+
+ // Cleanup...
+ iree_vm_module_release(hal_module);
+ iree_vm_module_release(input_module);
+ iree_hal_device_release(device);
+ iree_vm_context_release(context);
+ iree_vm_instance_release(instance);
return 0;
}