blob: b5f45b98b616eca0db57e91846797d44ecc109d9 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "base/file_io.h"
#include "base/file_path.h"
#include "base/init.h"
#include "base/source_location.h"
#include "base/status.h"
#include "benchmark/benchmark.h"
#include "hal/buffer_view_string_util.h"
#include "hal/driver_registry.h"
#include "rt/context.h"
#include "rt/debug/debug_server_flags.h"
#include "rt/instance.h"
#include "rt/module_printer.h"
#include "schemas/module_def_generated.h"
#include "vm/sequencer_module.h"
// Path of the module file that is loaded and registered with the context.
ABSL_FLAG(std::string, main_module, "", "Main module with entry point.");
// Name of the exported function to benchmark. May be left empty only when the
// module exports exactly one function; that function is then used.
ABSL_FLAG(std::string, main_function, "",
"Function within the main module to execute.");
// Inline inputs, one per line. Literal "\n" escape sequences in the flag value
// are replaced with real newlines. When non-empty this flag is used and
// --input_file is not read.
ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values.");
// File containing the inputs, one per line. Only read when --input_values is
// empty.
ABSL_FLAG(std::string, input_file, "",
"Input shapes and optional values serialized in a file.");
namespace iree {
namespace {
// Parses the benchmark inputs supplied through --input_values or --input_file
// into buffer views allocated from |allocator|.
//
// Source selection:
//   * If --input_values is non-empty it is used, with every literal "\n"
//     escape sequence replaced by a real newline.
//   * Otherwise, if --input_file is non-empty, the contents of that file are
//     used.
//   * If neither flag is set an empty list is returned.
//
// The text is split on newlines and whitespace-only lines are skipped. Each
// remaining line is expected to hold one value listed as
//   [shape]xtype=[value]
// Example:
//   4x4xi8=0,1,2,3
//
// Returns the parsed buffer views in input order, or the first error produced
// while reading the file or parsing a line.
StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags(
    hal::Allocator* allocator) {
  std::string file_contents;
  if (!absl::GetFlag(FLAGS_input_values).empty()) {
    file_contents =
        absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}});
  } else if (!absl::GetFlag(FLAGS_input_file).empty()) {
    ASSIGN_OR_RETURN(file_contents,
                     file_io::GetFileContents(absl::GetFlag(FLAGS_input_file)));
  }

  std::vector<hal::BufferView> inputs;
  for (const auto& line :
       absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) {
    ASSIGN_OR_RETURN(auto input,
                     hal::ParseBufferViewFromString(line, allocator));
    // |input| is not used again; move it instead of copying the buffer view.
    inputs.push_back(std::move(input));
  }
  return inputs;
}
// Loads the module named by --main_module, resolves the function to benchmark
// and invokes it once per benchmark iteration.
//
// All setup (instance, device, context, module loading, function lookup and
// input parsing) happens once before the benchmark loop. Each loop iteration
// creates a new invocation of the function and waits for it to complete.
//
// Returns OkStatus() when every iteration succeeds, otherwise the first error
// encountered during setup or invocation.
Status Run(benchmark::State& state) {
  // Runtime instance, optionally attached to a debug server.
  ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags());
  auto instance = make_ref<rt::Instance>(std::move(debug_server));

  // Create the "interpreter" driver and register its default device.
  ASSIGN_OR_RETURN(
      auto driver,
      hal::DriverRegistry::shared_registry()->Create("interpreter"));
  ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
  RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));

  auto context_policy = make_ref<rt::Policy>();
  auto context =
      make_ref<rt::Context>(add_ref(instance), std::move(context_policy));

  // Load the main module from the path given on the command line.
  const std::string module_path = absl::GetFlag(FLAGS_main_module);
  ASSIGN_OR_RETURN(
      auto main_module_file,
      vm::ModuleFile::LoadFile(ModuleDefIdentifier(), module_path),
      _ << "while loading module file " << module_path);
  ASSIGN_OR_RETURN(auto main_module,
                   vm::SequencerModule::FromFile(std::move(main_module_file)));

  // Make the module available to the context. Additional modules
  // (specializations, shared libraries, etc) could be registered here too.
  // ModuleFiles are stateless, so one module_file could back several contexts
  // at the same time.
  RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));

  // Pick the function to run.
  rt::Function main_function;
  const std::string function_name = absl::GetFlag(FLAGS_main_function);
  if (function_name.empty()) {
    // Nothing was requested explicitly. To keep behavior deterministic this
    // is only accepted when the module has exactly one exported function.
    if (main_module->signature().export_function_count() != 1) {
      return InvalidArgumentErrorBuilder(IREE_LOC)
             << "--main_function= must be specified to disambiguate the "
                "function to run";
    }
    ASSIGN_OR_RETURN(main_function,
                     main_module->LookupFunctionByOrdinal(
                         rt::Function::Linkage::kExport, 0));
  } else {
    // Function explicitly requested by the user.
    ASSIGN_OR_RETURN(main_function,
                     main_module->LookupFunctionByName(
                         rt::Function::Linkage::kExport, function_name));
  }

  // Inputs are parsed once and shared by every iteration.
  ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator()));

  // Benchmark loop: one invocation per iteration, awaited to completion.
  for (auto _ : state) {
    ASSIGN_OR_RETURN(auto invocation,
                     rt::Invocation::Create(add_ref(context), main_function,
                                            make_ref<rt::Policy>(), {},
                                            absl::MakeConstSpan(arguments)));
    RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
  }
  return OkStatus();
}
// Benchmark body for running the module selected by the command-line flags.
// The work is done by Run(); its returned status is checked with CHECK_OK
// (presumably fatal on a non-OK status -- confirm against the CHECK_OK
// definition).
void BM_RunModule(benchmark::State& state) {
// Delegate to a status-returning function so we can use the status macros.
CHECK_OK(Run(state));
}
// By default only the main thread is included in CPU time. Include all the
// threads instead. To make single and multi-threaded benchmarks more
// comparable, use the wall time to determine how many iterations to run.
// See https://github.com/google/benchmark#cpu-timers.
BENCHMARK(BM_RunModule)->MeasureProcessCPUTime()->UseRealTime();
} // namespace
// Program entry point: initializes flag handling for both the benchmark
// library and the environment, then runs the registered benchmarks.
extern "C" int main(int argc, char** argv) {
  // Order matters here. The benchmark library has its own flag mechanism and
  // removes the arguments it recognizes from argv. Doing this first keeps
  // InitializeEnvironment from failing on flags it does not know about.
  ::benchmark::Initialize(&argc, argv);
  InitializeEnvironment(&argc, &argv);

  // Running zero benchmarks is treated as a failure.
  const size_t run_benchmark_count = ::benchmark::RunSpecifiedBenchmarks();
  CHECK_GT(run_benchmark_count, 0) << "No benchmarks were run";
  return 0;
}
} // namespace iree