blob: b567c1d692323524984f619920e1128f275270d2 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/internal/parse.h"
#include "absl/flags/usage.h"
#include "absl/strings/string_view.h"
#include "benchmark/benchmark.h"
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
// Module to benchmark. The default value "-" makes the tool read the module
// contents from stdin instead of from a file.
ABSL_FLAG(std::string, module_file, "-",
          "File containing the module to load that contains the entry "
          "function. Defaults to stdin.");

// Single exported function to benchmark. When left empty every exported
// function of the module is registered as a benchmark.
ABSL_FLAG(std::string, entry_function, "",
          "Name of a function contained in the module specified by module_file "
          "to run. If this is not set, all the exported functions will be "
          "benchmarked and they are expected to not have input arguments.");

// Name of the driver used to create the device the module runs on.
ABSL_FLAG(std::string, driver, "vmla", "Backend driver to use.");

// Inline description of the inputs passed to --entry_function. Only consulted
// when --function_inputs_file is empty.
ABSL_FLAG(std::vector<std::string>, function_inputs, {},
          "A comma-separated list of input buffers of the format: "
          "[shape]xtype=[value]\n"
          "2x2xi32=1 2 3 4\n"
          "Optionally, brackets may be used to separate the element values. "
          "They are ignored by the parser.\n"
          "2x2xi32=[[1 2][3 4]]\n"
          "Due to the absence of repeated flags in absl, commas should not be "
          "used to separate elements. They are reserved for separating input "
          "values:\n"
          "2x2xi32=[[1 2][3 4]], 1x2xf32=[[1 2]]");

// File-based alternative to --function_inputs; takes precedence over it when
// set.
ABSL_FLAG(std::string, function_inputs_file, "",
          "Provides a file for input shapes and optional values (see "
          "ParseToVariantListFromFile in vm_util.h for details)");
namespace iree {
namespace {
// Body of a registered benchmark: invokes `function` in `context` once per
// benchmark iteration.
//
// Every iteration builds a new output list (sized from `output_descs.size()`)
// and then calls the function with `inputs`, which callers may pass as nullptr
// for functions without arguments. Any failing IREE call aborts the process
// through IREE_CHECK_OK. The run is traced under `benchmark_name`, with one
// named frame mark per iteration.
static void BenchmarkFunction(
    const std::string& benchmark_name, iree_vm_context_t* context,
    iree_vm_function_t function, iree_vm_list_t* inputs,
    const std::vector<RawSignatureParser::Description>& output_descs,
    benchmark::State& state) {
  IREE_TRACE_SCOPE_DYNAMIC(benchmark_name.c_str());
  IREE_TRACE_FRAME_MARK();

  // The number of outputs never changes between iterations.
  const auto output_count = output_descs.size();

  // Timed region: everything inside this loop is measured.
  for (auto _ : state) {
    IREE_TRACE_SCOPE0("BenchmarkIteration");
    IREE_TRACE_FRAME_MARK_NAMED("Iteration");

    // A fresh list per iteration; the vm::ref releases it at end of scope.
    vm::ref<iree_vm_list_t> results;
    IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, output_count,
                                      iree_allocator_system(), &results));
    IREE_CHECK_OK(iree_vm_invoke(context, function, /*policy=*/nullptr, inputs,
                                 results.get(), iree_allocator_system()));
  }
}
// Registers a Google Benchmark entry named "BM_<function_name>" whose body
// runs BenchmarkFunction for `function`.
//
// The registered callable keeps its own copies of the benchmark name and of
// `output_descs`, but only raw pointers to `context` and `inputs`; both must
// therefore stay valid until the benchmarks have finished running.
void RegisterModuleBenchmarks(
    const std::string& function_name, iree_vm_context_t* context,
    iree_vm_function_t function, iree_vm_list_t* inputs,
    const std::vector<RawSignatureParser::Description>& output_descs) {
  std::string benchmark_name("BM_");
  benchmark_name += function_name;

  auto run_benchmark = [benchmark_name, context, function, inputs,
                        output_descs](benchmark::State& state) -> void {
    BenchmarkFunction(benchmark_name, context, function, inputs, output_descs,
                      state);
  };
  auto* registration =
      benchmark::RegisterBenchmark(benchmark_name.c_str(), run_benchmark);

  // CPU time covers only the main thread by default; account for every thread
  // of the process instead.
  registration->MeasureProcessCPUTime();

  // Let wall time decide how many iterations are run so that single- and
  // multi-threaded benchmarks stay comparable. See
  // https://github.com/google/benchmark#cpu-timers
  registration->UseRealTime();

  // Model runs are generally in the millisecond range, so report in that
  // unit. The framework prints between 0 and 3 decimal places while aiming
  // for three significant digits. Should precision beyond microseconds ever
  // be needed, this could become configurable through a command line flag.
  registration->Unit(benchmark::kMillisecond);
}
// Loads the module selected by --module_file into `out_contents`.
//
// The flag value "-" means the module is read from stdin until end of stream.
// Any other value is handed to file_io::GetFileContents as a path, and a
// failure of that call is returned to the caller.
Status GetModuleContentsFromFlags(std::string* out_contents) {
  IREE_TRACE_SCOPE0("GetModuleContentsFromFlags");
  std::string module_path = absl::GetFlag(FLAGS_module_file);

  // Regular file: delegate to the file helper.
  if (module_path != "-") {
    IREE_RETURN_IF_ERROR(file_io::GetFileContents(module_path, out_contents));
    return OkStatus();
  }

  // "-": the module is piped in, so consume all of stdin.
  out_contents->assign(std::istreambuf_iterator<char>(std::cin),
                       std::istreambuf_iterator<char>());
  return OkStatus();
}
// TODO(hanchung): Consider to refactor this out and reuse in iree-run-module.
// This class helps organize required resources for IREE. The order of
// construction and destruction for resources matters. And the lifetime of
// resources also matters. The lifetime of IREEBenchmark should be as long as
// ::benchmark::RunSpecifiedBenchmarks() where the resources are used during
// benchmarking.
//
// Typical use: construct one instance, call Register() once, then run the
// benchmarks while the instance is still alive.
class IREEBenchmark {
 public:
  IREEBenchmark() = default;

  // The handles held by this class are released exactly once in the
  // destructor, so a copy would release each of them a second time.
  IREEBenchmark(const IREEBenchmark&) = delete;
  IREEBenchmark& operator=(const IREEBenchmark&) = delete;

  ~IREEBenchmark() {
    IREE_TRACE_SCOPE0("IREEBenchmark::dtor");
    // Order matters.
    inputs_.reset();
    iree_vm_context_release(context_);
    iree_vm_module_release(hal_module_);
    iree_vm_module_release(input_module_);
    iree_hal_device_release(device_);
    iree_vm_instance_release(instance_);
  }

  // Initializes the IREE resources if any of them is still missing and then
  // registers benchmarks according to --entry_function: the named function if
  // the flag is set, every exported function otherwise.
  //
  // Returns the first error encountered; OkStatus() on success.
  Status Register() {
    IREE_TRACE_SCOPE0("IREEBenchmark::Register");

    if (!instance_ || !device_ || !hal_module_ || !context_ || !input_module_) {
      IREE_RETURN_IF_ERROR(Init());
    }

    auto function_name = absl::GetFlag(FLAGS_entry_function);
    if (!function_name.empty()) {
      IREE_RETURN_IF_ERROR(RegisterSpecificFunction(function_name));
    } else {
      IREE_RETURN_IF_ERROR(RegisterAllExportedFunctions());
    }
    return iree::OkStatus();
  }

 private:
  // Loads the module named by the flags and creates the VM instance, device,
  // HAL module, bytecode module and the VM context that ties them together.
  //
  // NOTE: on an early error return the "init" frame END marker is not emitted.
  Status Init() {
    IREE_TRACE_SCOPE0("IREEBenchmark::Init");
    IREE_TRACE_FRAME_MARK_BEGIN_NAMED("init");

    IREE_RETURN_IF_ERROR(GetModuleContentsFromFlags(&module_data_));

    IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
    IREE_RETURN_IF_ERROR(
        iree_vm_instance_create(iree_allocator_system(), &instance_));

    // Create IREE's device and module.
    IREE_RETURN_IF_ERROR(
        iree::CreateDevice(absl::GetFlag(FLAGS_driver), &device_));
    IREE_RETURN_IF_ERROR(CreateHalModule(device_, &hal_module_));
    IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data_, &input_module_));

    // Order matters. The input module will likely be dependent on the hal
    // module.
    std::array<iree_vm_module_t*, 2> modules = {hal_module_, input_module_};
    IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
        instance_, modules.data(), modules.size(), iree_allocator_system(),
        &context_));

    IREE_TRACE_FRAME_MARK_END_NAMED("init");
    return iree::OkStatus();
  }

  // Looks up the exported function `function_name`, builds its inputs from
  // --function_inputs_file (preferred when non-empty) or --function_inputs,
  // and registers a benchmark for it. The parsed inputs are stored in
  // `inputs_` so that they outlive the registration.
  Status RegisterSpecificFunction(const std::string& function_name) {
    IREE_TRACE_SCOPE0("IREEBenchmark::RegisterSpecificFunction");

    iree_vm_function_t function;
    IREE_RETURN_IF_ERROR(input_module_->lookup_function(
        input_module_->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
        iree_string_view_t{function_name.data(), function_name.size()},
        &function));
    IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));

    // Construct inputs.
    std::vector<RawSignatureParser::Description> input_descs;
    IREE_RETURN_IF_ERROR(ParseInputSignature(function, &input_descs));
    if (!absl::GetFlag(FLAGS_function_inputs_file).empty()) {
      IREE_RETURN_IF_ERROR(ParseToVariantListFromFile(
          input_descs, iree_hal_device_allocator(device_),
          absl::GetFlag(FLAGS_function_inputs_file), &inputs_));
    } else {
      IREE_RETURN_IF_ERROR(
          ParseToVariantList(input_descs, iree_hal_device_allocator(device_),
                             absl::GetFlag(FLAGS_function_inputs), &inputs_));
    }

    // Creates output signature.
    std::vector<RawSignatureParser::Description> output_descs;
    IREE_RETURN_IF_ERROR(ParseOutputSignature(function, &output_descs));

    RegisterModuleBenchmarks(function_name, context_, function, inputs_.get(),
                             output_descs);
    return iree::OkStatus();
  }

  // Registers one benchmark per exported function of the input module.
  //
  // Functions that fail ValidateFunctionAbi are skipped. A function that takes
  // input arguments makes the whole registration fail with INVALID_ARGUMENT,
  // because no inputs are provided in this mode.
  Status RegisterAllExportedFunctions() {
    IREE_TRACE_SCOPE0("IREEBenchmark::RegisterAllExportedFunctions");

    iree_vm_module_signature_t signature =
        input_module_->signature(input_module_->self);
    for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) {
      iree_vm_function_t function;
      iree_string_view_t name;
      // Propagate lookup failures to the caller instead of aborting, in line
      // with the error handling used everywhere else in this class.
      IREE_RETURN_IF_ERROR(input_module_->get_function(
          input_module_->self, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function,
          &name, nullptr));

      if (!ValidateFunctionAbi(function).ok()) continue;

      std::string function_name(name.data, name.size);
      std::vector<RawSignatureParser::Description> input_descs;
      IREE_RETURN_IF_ERROR(ParseInputSignature(function, &input_descs));
      if (!input_descs.empty()) {
        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "expect not to have input arguments for '%.*s'",
                                static_cast<int>(name.size), name.data);
      }

      std::vector<RawSignatureParser::Description> output_descs;
      IREE_RETURN_IF_ERROR(ParseOutputSignature(function, &output_descs));
      RegisterModuleBenchmarks(function_name, context_, function,
                               /*inputs=*/nullptr, output_descs);
    }
    return iree::OkStatus();
  }

  // Raw contents of the module; filled by Init() and passed to
  // LoadBytecodeModule.
  std::string module_data_;
  // Handles created in Init() and released in the destructor.
  iree_vm_instance_t* instance_ = nullptr;
  iree_hal_device_t* device_ = nullptr;
  iree_vm_module_t* hal_module_ = nullptr;
  iree_vm_context_t* context_ = nullptr;
  iree_vm_module_t* input_module_ = nullptr;
  // Inputs parsed for --entry_function; stays empty when all exported
  // functions are benchmarked.
  iree::vm::ref<iree_vm_list_t> inputs_;
};
} // namespace
} // namespace iree
// Entry point of iree-benchmark-module.
//
// Parses the absl, benchmark and IREE flags, registers all available HAL
// drivers, registers the requested benchmarks and runs them. Returns 0 on
// success; if registration fails, the status is printed to stderr and its
// code is returned as the exit code.
int main(int argc, char** argv) {
  IREE_TRACE_SCOPE0("main");

  // We have to contend with two flag parsing libraries here: absl's and
  // benchmark's. To make matters worse, both define the `--help` flag. To
  // ensure that each is able to parse its own flags, we use an absl "internal"
  // function (still with public visibility) to parse while ignoring undefined
  // flags. If it sees `--help` it will exit here, so we include the benchmark
  // library usage information in the manually-set help output. Then we let
  // benchmark parse its flags. Finally we call the normal initialization
  // function to do other IREE initialization including flag parsing with
  // normal options. Any remaining flags will be unknown and result in an error.
  absl::SetProgramUsageMessage(
      "iree-benchmark-module \n"
      " --module_file=module.vmfb\n"
      " --entry_function=exported_function_to_benchmark\n"
      " If this is not set, all the exported functions will be \n"
      " benchmarked and they are expected to not have input arguments\n"
      " [--function_inputs=2xi32=1 2,1x2xf32=2 1 | \n"
      " --function_inputs_file=file_with_function_inputs]\n"
      " [--driver=vmla]\n"
      "\n\n"
      " Optional flags from third_party/benchmark/src/benchmark.cc:\n"
      " [--benchmark_list_tests={true|false}]\n"
      " [--benchmark_filter=<regex>]\n"
      " [--benchmark_min_time=<min_time>]\n"
      " [--benchmark_repetitions=<num_repetitions>]\n"
      " [--benchmark_report_aggregates_only={true|false}]\n"
      " [--benchmark_display_aggregates_only={true|false}]\n"
      " [--benchmark_format=<console|json|csv>]\n"
      " [--benchmark_out=<filename>]\n"
      " [--benchmark_out_format=<json|console|csv>]\n"
      " [--benchmark_color={auto|true|false}]\n"
      " [--benchmark_counters_tabular={true|false}]\n"
      " [--v=<verbosity>]\n");
  absl::flags_internal::ParseCommandLineImpl(
      argc, argv, absl::flags_internal::ArgvListAction::kRemoveParsedArgs,
      absl::flags_internal::UsageFlagsAction::kHandleUsage,
      absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined);
  ::benchmark::Initialize(&argc, argv);
  iree_flags_parse_checked(&argc, &argv);

  IREE_CHECK_OK(iree_hal_register_all_available_drivers(
      iree_hal_driver_registry_default()));

  // Must outlive RunSpecifiedBenchmarks(): the registered benchmarks use the
  // resources owned by this object.
  iree::IREEBenchmark iree_benchmark;
  auto status = iree_benchmark.Register();
  if (!status.ok()) {
    // Report on stderr so that the diagnostic does not get mixed into the
    // benchmark report, which the benchmark library writes to stdout.
    std::cerr << status << '\n';
    return static_cast<int>(status.code());
  }

  ::benchmark::RunSpecifiedBenchmarks();
  return 0;
}