|  | // Copyright 2020 The IREE Authors | 
|  | // | 
|  | // Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | //===----------------------------------------------------------------------===// | 
|  | // iree-benchmark-module: benchmarks public functions in an IREE VM module | 
|  | //===----------------------------------------------------------------------===// | 
|  | // | 
|  | // This runs exported functions using flags specified on the command line. | 
|  | // Each function is measured independently and the numbers reported will be for | 
|  | // the full end-to-end CPU and wall times. | 
|  | // | 
|  | // From an ML perspective this is an integration benchmark for measuring total | 
|  | // user-visible latency of model entry points. It is *not* a microbenchmarking | 
|  | // tool for individual device-side dispatch functions (aka ops aka kernels). | 
|  | // If interested in the precise time of a particular dispatch then tracy, | 
|  | // executable_library_benchmark, and platform/vendor tooling (nsight, perf, etc) | 
|  | // are to be used instead and attaching them to this tool is often useful in | 
|  | // order to get a large sample set. | 
|  | // | 
|  | // By default all functions taking no inputs will be benchmarked. If a function | 
|  | // takes inputs then the user will need to specify them using --input= | 
|  | // flags. Depending on the input program the -iree-flow-export-benchmark-funcs | 
|  | // flag can be passed to the compiler to attempt to wrap each function with | 
|  | // dummy inputs however this will fail in programs with dynamically shaped | 
|  | // inputs. The workaround for avoiding the need for flags is to provide the | 
|  | // input program in a form with no inputs from the start. | 
|  | // | 
|  | // It's important to remember that IREE is not a BLAS library and is meant to | 
|  | // run entire programs. It's not generally appropriate to benchmark a model with | 
|  | // a single matmul, for example, as that's just treating IREE as a BLAS library. | 
|  | // Note also that user-level ops in a frontend environment don't map to the | 
|  | // dispatches that IREE executes: IREE is a compiler like any other and does not | 
|  | // guarantee a source line of code translates into an atomically divisible and | 
|  | // independently measurable execution command. In other words don't expect to be | 
|  | // able to benchmark the cost of a broadcasting elementwise tf.add op within a | 
|  | // model: by the time we are running the program that's fused itself into a | 
|  | // single machine instruction operating as part of some other ops. | 
|  | // | 
|  | // For coarse dispatch testing and triaging it can still be useful to remove | 
|  | // some of the overheads introduced by whole-program execution and the compiler | 
|  | // flag --iree-hal-benchmark-dispatch-repeat-count=N is provided to enable | 
|  | // batching. Whatever N is chosen must then be passed to this tool via | 
|  | // --batch_size=N so that the benchmark reporting properly reflects the | 
|  | // batching. As an example --iree-hal-benchmark-dispatch-repeat-count=32 + | 
|  | // --batch_size=32 will reduce the overheads by 32x. Think of this as a way to | 
|  | // control the p value in Amdahl's law representing the amount of time spent in | 
|  | // dispatches relative to the rest of the program. This isn't representative of | 
|  | // how the full program will run, though, and YMMV. Always verify timings with | 
|  | // an appropriate device-specific tool before trusting the more generic and | 
|  | // higher-level numbers from this tool. | 
|  |  | 
#include <array>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
|  |  | 
|  | #include "benchmark/benchmark.h" | 
|  | #include "iree/base/api.h" | 
|  | #include "iree/base/internal/flags.h" | 
|  | #include "iree/hal/api.h" | 
|  | #include "iree/modules/hal/types.h" | 
|  | #include "iree/tooling/context_util.h" | 
|  | #include "iree/tooling/device_util.h" | 
|  | #include "iree/tooling/function_io.h" | 
|  | #include "iree/vm/api.h" | 
|  |  | 
// Canonical spellings accepted by (and printed for) the --time_unit= flag.
constexpr char kNanosecondsUnitString[] = "ns";
constexpr char kMicrosecondsUnitString[] = "us";
constexpr char kMillisecondsUnitString[] = "ms";

// TODO(hanchung): Extract the batch size using
// iree_vm_function_lookup_attr_by_name.
IREE_FLAG(int32_t, batch_size, 1,
          "Number of invocations per iteration, which for dispatch benchmarks "
          "must match the --iree-hal-benchmark-dispatch-repeat-count value "
          "used during compilation.");
IREE_FLAG(int32_t, batch_concurrency, 1,
          "Number of invocations within a batch that should run concurrently.");

IREE_FLAG(string, function, "",
          "Name of a function contained in the module specified by --module= "
          "to run. If this is not set, all the exported functions will be "
          "benchmarked and they are expected to not have input arguments.");

IREE_FLAG(bool, print_statistics, false,
          "Prints runtime statistics to stderr on exit.");

IREE_FLAG_LIST(
    string, input,
    "An input value or buffer of the format:\n"
    "  [shape]xtype=[value]\n"
    "  2x2xi32=1 2 3 4\n"
    "Optionally, brackets may be used to separate the element values:\n"
    "  2x2xi32=[[1 2][3 4]]\n"
    "Raw binary files can be read to provide buffer contents:\n"
    "  2x2xi32=@some/file.bin\n"
    "numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
    "  @some.npy\n"
    "Each occurrence of the flag indicates an input in the order they were\n"
    "specified on the command line.");
|  |  | 
|  | static iree_status_t parse_time_unit(iree_string_view_t flag_name, | 
|  | void* storage, iree_string_view_t value) { | 
|  | auto* unit = (std::pair<bool, benchmark::TimeUnit>*)storage; | 
|  | auto unit_string = std::string(value.data, value.size); | 
|  | if (unit_string == kMillisecondsUnitString) { | 
|  | *unit = {true, benchmark::kMillisecond}; | 
|  | return iree_ok_status(); | 
|  | } else if (unit_string == kMicrosecondsUnitString) { | 
|  | *unit = {true, benchmark::kMicrosecond}; | 
|  | return iree_ok_status(); | 
|  | } else if (unit_string == kNanosecondsUnitString) { | 
|  | *unit = {true, benchmark::kNanosecond}; | 
|  | return iree_ok_status(); | 
|  | } | 
|  | return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, | 
|  | "unsupported time unit"); | 
|  | } | 
|  | static void print_time_unit(iree_string_view_t flag_name, void* storage, | 
|  | FILE* file) { | 
|  | auto* unit = (std::pair<bool, benchmark::TimeUnit>*)storage; | 
|  | if (!unit->first) { | 
|  | return; | 
|  | } | 
|  | std::string unit_string; | 
|  | switch (unit->second) { | 
|  | case benchmark::kMillisecond: | 
|  | unit_string = kMillisecondsUnitString; | 
|  | break; | 
|  | case benchmark::kMicrosecond: | 
|  | unit_string = kMicrosecondsUnitString; | 
|  | break; | 
|  | case benchmark::kNanosecond: | 
|  | unit_string = kNanosecondsUnitString; | 
|  | break; | 
|  | default: | 
|  | assert(false && "Unexpected time unit."); | 
|  | } | 
|  | fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data, | 
|  | unit_string.c_str()); | 
|  | } | 
// Time unit to be printed. If the first field is false (the flag was never
// parsed), each benchmark falls back to its own default time unit.
static std::pair<bool, benchmark::TimeUnit> FLAG_time_unit = {
    false, benchmark::kNanosecond};
IREE_FLAG_CALLBACK(
    parse_time_unit, print_time_unit, &FLAG_time_unit, time_unit,
    "The time unit to be printed in the results. Can be 'ms', 'us', or 'ns'.");
|  |  | 
|  | namespace iree { | 
|  | namespace { | 
|  |  | 
// Benchmarks a synchronous invocation of |function|: each benchmark iteration
// performs |batch_size| back-to-back blocking iree_vm_invoke calls with the
// same |inputs| list (which may be null for zero-arg functions). |device| may
// be null; when present its profiling data is flushed outside the timed
// region after each invocation.
static void BenchmarkGenericFunction(const std::string& benchmark_name,
                                     int32_t batch_size,
                                     iree_hal_device_t* device,
                                     iree_vm_context_t* context,
                                     iree_vm_function_t function,
                                     iree_vm_list_t* inputs,
                                     benchmark::State& state) {
  IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, benchmark_name.data(),
                                      benchmark_name.size());
  IREE_TRACE_FRAME_MARK();

  // Scratch output list reused (and cleared) across all invocations.
  vm::ref<iree_vm_list_t> outputs;
  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 16,
                                    iree_allocator_system(), &outputs));

  // Benchmarking loop.
  while (state.KeepRunningBatch(batch_size)) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "BenchmarkIteration");
    IREE_TRACE_FRAME_MARK_NAMED("Iteration");
    IREE_CHECK_OK(iree_vm_invoke(
        context, function, IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr,
        inputs, outputs.get(), iree_allocator_system()));
    // Drop results so the next invocation starts with an empty output list.
    IREE_CHECK_OK(iree_vm_list_resize(outputs.get(), 0));
    IREE_TRACE_ZONE_END(z1);
    if (device) {
      // Flush device profiling outside the timed region so it doesn't skew
      // the reported numbers.
      state.PauseTiming();
      IREE_CHECK_OK(iree_hal_device_profiling_flush(device));
      state.ResumeTiming();
    }
  }
  state.SetItemsProcessed(state.iterations());

  IREE_TRACE_ZONE_END(z0);
}
|  |  | 
|  | void RegisterGenericBenchmark(const std::string& function_name, | 
|  | iree_hal_device_t* device, | 
|  | iree_vm_context_t* context, | 
|  | iree_vm_function_t function, | 
|  | iree_vm_list_t* inputs) { | 
|  | auto benchmark_name = "BM_" + function_name; | 
|  | int32_t batch_size = FLAG_batch_size; | 
|  | benchmark::RegisterBenchmark(benchmark_name.c_str(), | 
|  | [=](benchmark::State& state) -> void { | 
|  | BenchmarkGenericFunction( | 
|  | benchmark_name, batch_size, device, | 
|  | context, function, inputs, state); | 
|  | }) | 
|  | // By default only the main thread is included in CPU time. Include all | 
|  | // the threads instead. | 
|  | ->MeasureProcessCPUTime() | 
|  | // To make single and multi-threaded benchmarks more comparable, use the | 
|  | // wall time to determine how many iterations to run. See | 
|  | // https://github.com/google/benchmark#cpu-timers, | 
|  | ->UseRealTime() | 
|  | ->Unit(FLAG_time_unit.first ? FLAG_time_unit.second | 
|  | : benchmark::kMillisecond); | 
|  | } | 
|  |  | 
// Runs up to |batch_size| pipelined invocations in sequence along with
// concurrency. Example:
//   batch_size=1, concurrency=1:
//     [invocation 0]
//   batch_size=2, concurrency=1:
//     [invocation 0] -> [invocation 1]
//   batch_size=2, concurrency=2:
//     [invocation 0]
//     [invocation 1]
//   batch_size=4, concurrency=2:
//     [invocation 0] -> [invocation 2]
//     [invocation 1] -> [invocation 3]
//
// |function| is expected to follow the coarse-fences ABI: its last two
// arguments are a (wait, signal) fence pair that is appended to a clone of
// |common_inputs| for every invocation. Setup/teardown of the fence timeline
// happens with timing paused so only invocation + completion wait is timed.
static void BenchmarkAsyncFunction(
    const std::string& benchmark_name, int32_t batch_size,
    int32_t batch_concurrency, iree_hal_device_t* device,
    iree_vm_context_t* context, iree_vm_function_t function,
    iree_vm_list_t* common_inputs, benchmark::State& state) {
  IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, benchmark_name.data(),
                                      benchmark_name.size());
  IREE_TRACE_FRAME_MARK();
  iree_allocator_t host_allocator = iree_allocator_system();

  // Round up batch size to some multiple of concurrency.
  batch_size = (int32_t)iree_host_align(batch_size, batch_concurrency);

  // Benchmarking loop.
  while (state.KeepRunningBatch(batch_size)) {
    state.PauseTiming();
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "BenchmarkIteration");
    IREE_TRACE_FRAME_MARK_NAMED("Iteration");

    IREE_TRACE_ZONE_BEGIN_NAMED(z_begin, "PrepareBatch");

    // Each concurrent track of execution gets its own semaphore.
    std::vector<vm::ref<iree_hal_semaphore_t>> timeline_semaphores;
    for (int32_t i = 0; i < batch_concurrency; ++i) {
      vm::ref<iree_hal_semaphore_t> timeline_semaphore;
      IREE_CHECK_OK(
          iree_hal_semaphore_create(device, 0ull, &timeline_semaphore));
      timeline_semaphores.push_back(std::move(timeline_semaphore));
    }

    // Preallocate fences and I/O for each invocation.
    // The same inputs are used for each but we need a unique list to hold the
    // unique fences. Each fence represents when the invocation has completed.
    std::vector<vm::ref<iree_hal_fence_t>> invocation_fences;
    std::vector<vm::ref<iree_vm_list_t>> invocation_inputs;
    std::vector<vm::ref<iree_vm_list_t>> invocation_outputs;
    vm::ref<iree_hal_fence_t> completion_fence;
    IREE_CHECK_OK(iree_hal_fence_create(batch_concurrency, host_allocator,
                                        &completion_fence));
    // i walks the sequential minibatches, j the concurrent tracks within one.
    for (int32_t i = 0; i < batch_size / batch_concurrency; ++i) {
      for (int32_t j = 0; j < batch_concurrency; ++j) {
        // Chain each concurrent minibatch to the previous. Note that to start
        // we wait on nothing and begin executing immediately.
        vm::ref<iree_hal_fence_t> wait_fence;
        if (i > 0) {
          wait_fence = vm::retain_ref(
              invocation_fences[(i - 1) * batch_concurrency + j]);
        }
        // Each track advances its own semaphore by 1 per minibatch.
        uint64_t signal_value = i + 1;
        vm::ref<iree_hal_fence_t> signal_fence;
        IREE_CHECK_OK(iree_hal_fence_create_at(timeline_semaphores[j].get(),
                                               signal_value, host_allocator,
                                               &signal_fence));
        invocation_fences.push_back(vm::retain_ref(signal_fence));

        // Join the final minibatch on the completion fence.
        if (i == batch_size / batch_concurrency - 1) {
          IREE_CHECK_OK(iree_hal_fence_insert(completion_fence.get(),
                                              timeline_semaphores[j].get(),
                                              signal_value));
        }

        // Clone common inputs and add the invocation-specific fences.
        vm::ref<iree_vm_list_t> inputs;
        IREE_CHECK_OK(
            iree_vm_list_clone(common_inputs, host_allocator, &inputs));
        IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), wait_fence));
        IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), signal_fence));
        invocation_inputs.push_back(std::move(inputs));

        // Setup empty outputs.
        vm::ref<iree_vm_list_t> outputs;
        IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 16,
                                          host_allocator, &outputs));
        invocation_outputs.push_back(std::move(outputs));
      }
    }

    IREE_TRACE_ZONE_END(z_begin);

    state.ResumeTiming();
    {
      // TODO(benvanik): replace with async invocations. Today if the invocation
      // performs any waits this will block on the initial invoke instead of
      // actually overlapping things.
      for (int32_t i = 0; i < batch_size; ++i) {
        IREE_CHECK_OK(
            iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
                           /*policy=*/nullptr, invocation_inputs[i].get(),
                           invocation_outputs[i].get(), host_allocator));
      }
      // Block until the last minibatch of every track has signaled.
      IREE_CHECK_OK(
          iree_hal_fence_wait(completion_fence.get(), iree_infinite_timeout()));
    }
    state.PauseTiming();

    // Release all per-batch resources with timing paused so cleanup cost does
    // not count against the measured invocations.
    IREE_TRACE_ZONE_BEGIN_NAMED(z_end, "CleanupBatch");
    for (int32_t i = 0; i < batch_size; ++i) {
      iree_vm_list_clear(invocation_outputs[i].get());
    }
    invocation_fences.clear();
    invocation_inputs.clear();
    invocation_outputs.clear();
    completion_fence.reset();
    timeline_semaphores.clear();
    IREE_TRACE_ZONE_END(z_end);

    IREE_TRACE_ZONE_END(z1);
    if (device) {
      IREE_CHECK_OK(iree_hal_device_profiling_flush(device));
    }
    state.ResumeTiming();
  }
  state.SetItemsProcessed(state.iterations());

  IREE_TRACE_ZONE_END(z0);
}
|  |  | 
|  | void RegisterAsyncBenchmark(const std::string& function_name, | 
|  | iree_hal_device_t* device, | 
|  | iree_vm_context_t* context, | 
|  | iree_vm_function_t function, | 
|  | iree_vm_list_t* inputs) { | 
|  | auto benchmark_name = "BM_" + function_name; | 
|  | int32_t batch_size = FLAG_batch_size; | 
|  | int32_t batch_concurrency = FLAG_batch_concurrency; | 
|  | benchmark::RegisterBenchmark( | 
|  | benchmark_name.c_str(), | 
|  | [=](benchmark::State& state) -> void { | 
|  | BenchmarkAsyncFunction(benchmark_name, batch_size, batch_concurrency, | 
|  | device, context, function, inputs, state); | 
|  | }) | 
|  | // By default only the main thread is included in CPU time. Include all | 
|  | // the threads instead. | 
|  | ->MeasureProcessCPUTime() | 
|  | // To make single and multi-threaded benchmarks more comparable, use the | 
|  | // wall time to determine how many iterations to run. See | 
|  | // https://github.com/google/benchmark#cpu-timers, | 
|  | ->UseRealTime() | 
|  | ->Unit(FLAG_time_unit.first ? FLAG_time_unit.second | 
|  | : benchmark::kMillisecond); | 
|  | } | 
|  |  | 
// Benchmarks a compiler-generated dispatch benchmark function (one exported
// with the `iree.benchmark=dispatch` attribute; see
// RegisterAllExportedFunctions). The function takes a single i32 batch-size
// operand that must match the --iree-hal-benchmark-dispatch-repeat-count
// value used during compilation, which is why FLAG_batch_size is both pushed
// as the sole input and used as the KeepRunningBatch size.
static void BenchmarkDispatchFunction(const std::string& benchmark_name,
                                      iree_vm_context_t* context,
                                      iree_vm_function_t function,
                                      benchmark::State& state) {
  IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, benchmark_name.data(),
                                      benchmark_name.size());
  IREE_TRACE_FRAME_MARK();

  // Single i32 input carrying the batch size expected by the function.
  vm::ref<iree_vm_list_t> inputs;
  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 16,
                                    iree_allocator_system(), &inputs));
  iree_vm_value_t batch_size = iree_vm_value_make_i32(FLAG_batch_size);
  IREE_CHECK_OK(iree_vm_list_push_value(inputs.get(), &batch_size));

  // Scratch output list reused (and cleared) across all invocations.
  vm::ref<iree_vm_list_t> outputs;
  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 16,
                                    iree_allocator_system(), &outputs));

  // Benchmarking loop.
  while (state.KeepRunningBatch(FLAG_batch_size)) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "BenchmarkIteration");
    IREE_TRACE_FRAME_MARK_NAMED("Iteration");
    IREE_CHECK_OK(iree_vm_invoke(
        context, function, IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr,
        inputs.get(), outputs.get(), iree_allocator_system()));
    // Drop results so the next invocation starts with an empty output list.
    IREE_CHECK_OK(iree_vm_list_resize(outputs.get(), 0));
    IREE_TRACE_ZONE_END(z1);
  }
  state.SetItemsProcessed(state.iterations());

  IREE_TRACE_ZONE_END(z0);
}
|  |  | 
|  | void RegisterDispatchBenchmark(const std::string& function_name, | 
|  | iree_vm_context_t* context, | 
|  | iree_vm_function_t function) { | 
|  | auto benchmark_name = "BM_" + function_name; | 
|  | benchmark::RegisterBenchmark( | 
|  | benchmark_name.c_str(), | 
|  | [benchmark_name, context, function](benchmark::State& state) -> void { | 
|  | BenchmarkDispatchFunction(benchmark_name, context, function, state); | 
|  | }) | 
|  | // By default only the main thread is included in CPU time. Include all | 
|  | // the threads instead. | 
|  | ->MeasureProcessCPUTime() | 
|  | // To make single and multi-threaded benchmarks more comparable, use the | 
|  | // wall time to determine how many iterations to run. See | 
|  | // https://github.com/google/benchmark#cpu-timers, | 
|  | ->UseRealTime() | 
|  | ->Unit(FLAG_time_unit.first ? FLAG_time_unit.second | 
|  | : benchmark::kMicrosecond); | 
|  | } | 
|  |  | 
|  | // The lifetime of IREEBenchmark should be as long as | 
|  | // ::benchmark::RunSpecifiedBenchmarks() where the resources are used during | 
|  | // benchmarking. | 
|  | class IREEBenchmark { | 
|  | public: | 
|  | IREEBenchmark() { iree_tooling_module_list_initialize(&module_list_); } | 
|  |  | 
|  | ~IREEBenchmark() { | 
|  | IREE_TRACE_SCOPE_NAMED("IREEBenchmark::dtor"); | 
|  |  | 
|  | // Order matters. Tear down modules first to release resources. | 
|  | inputs_.reset(); | 
|  | context_.reset(); | 
|  | iree_tooling_module_list_reset(&module_list_); | 
|  | instance_.reset(); | 
|  |  | 
|  | // Tear down device last in order to get accurate statistics. | 
|  | if (device_allocator_ && FLAG_print_statistics) { | 
|  | IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint( | 
|  | stderr, device_allocator_.get())); | 
|  | } | 
|  | device_allocator_.reset(); | 
|  | device_.reset(); | 
|  | }; | 
|  |  | 
|  | iree_hal_device_t* device() const { return device_.get(); } | 
|  |  | 
|  | iree_status_t Register() { | 
|  | IREE_TRACE_SCOPE_NAMED("IREEBenchmark::Register"); | 
|  |  | 
|  | if (!instance_ || !device_allocator_ || !context_ || !module_list_.count) { | 
|  | IREE_RETURN_IF_ERROR(Init()); | 
|  | } | 
|  |  | 
|  | auto function_name = std::string(FLAG_function); | 
|  | if (!function_name.empty()) { | 
|  | IREE_RETURN_IF_ERROR(RegisterSpecificFunction(function_name)); | 
|  | } else { | 
|  | IREE_RETURN_IF_ERROR(RegisterAllExportedFunctions()); | 
|  | } | 
|  | return iree_ok_status(); | 
|  | } | 
|  |  | 
|  | private: | 
|  | iree_status_t Init() { | 
|  | IREE_TRACE_SCOPE_NAMED("IREEBenchmark::Init"); | 
|  | IREE_TRACE_FRAME_MARK_BEGIN_NAMED("init"); | 
|  |  | 
|  | iree_allocator_t host_allocator = iree_allocator_system(); | 
|  | IREE_RETURN_IF_ERROR( | 
|  | iree_tooling_create_instance(host_allocator, &instance_)); | 
|  |  | 
|  | IREE_RETURN_IF_ERROR(iree_tooling_load_modules_from_flags( | 
|  | instance_.get(), host_allocator, &module_list_)); | 
|  |  | 
|  | IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags( | 
|  | instance_.get(), module_list_.count, module_list_.values, | 
|  | /*default_device_uri=*/iree_string_view_empty(), host_allocator, | 
|  | &context_, &device_, &device_allocator_)); | 
|  |  | 
|  | IREE_TRACE_FRAME_MARK_END_NAMED("init"); | 
|  | return iree_ok_status(); | 
|  | } | 
|  |  | 
|  | iree_status_t RegisterSpecificFunction(const std::string& function_name) { | 
|  | IREE_TRACE_SCOPE_NAMED("IREEBenchmark::RegisterSpecificFunction"); | 
|  |  | 
|  | iree_vm_module_t* main_module = | 
|  | iree_tooling_module_list_back(&module_list_); | 
|  | iree_vm_function_t function; | 
|  | IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name( | 
|  | main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, | 
|  | iree_string_view_t{function_name.data(), | 
|  | (iree_host_size_t)function_name.size()}, | 
|  | &function)); | 
|  | iree_vm_function_signature_t signature = | 
|  | iree_vm_function_signature(&function); | 
|  | iree_string_view_t arguments_cconv, results_cconv; | 
|  | IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( | 
|  | &signature, &arguments_cconv, &results_cconv)); | 
|  |  | 
|  | IREE_CHECK_OK(iree_tooling_parse_variants( | 
|  | arguments_cconv, FLAG_input_list(), device_.get(), | 
|  | device_allocator_.get(), iree_vm_instance_allocator(instance_.get()), | 
|  | &inputs_)); | 
|  |  | 
|  | iree_string_view_t invocation_model = iree_vm_function_lookup_attr_by_name( | 
|  | &function, IREE_SV("iree.abi.model")); | 
|  | if (iree_string_view_equal(invocation_model, IREE_SV("coarse-fences"))) { | 
|  | // Asynchronous invocation. | 
|  | iree::RegisterAsyncBenchmark(function_name, device_.get(), context_.get(), | 
|  | function, inputs_.get()); | 
|  | } else { | 
|  | // Synchronous invocation. | 
|  | iree::RegisterGenericBenchmark(function_name, device_.get(), | 
|  | context_.get(), function, inputs_.get()); | 
|  | } | 
|  | return iree_ok_status(); | 
|  | } | 
|  |  | 
|  | iree_status_t RegisterAllExportedFunctions() { | 
|  | IREE_TRACE_SCOPE_NAMED("IREEBenchmark::RegisterAllExportedFunctions"); | 
|  | iree_vm_module_t* main_module = | 
|  | iree_tooling_module_list_back(&module_list_); | 
|  | iree_vm_module_signature_t signature = | 
|  | iree_vm_module_signature(main_module); | 
|  | for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) { | 
|  | iree_vm_function_t function; | 
|  | IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal( | 
|  | main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function)); | 
|  | iree_string_view_t function_name = iree_vm_function_name(&function); | 
|  |  | 
|  | // We run anything with the 'benchmark' attribute. | 
|  | // If the attribute is not present we'll run anything that looks runnable. | 
|  | iree_string_view_t benchmark_type = iree_vm_function_lookup_attr_by_name( | 
|  | &function, IREE_SV("iree.benchmark")); | 
|  | if (iree_string_view_equal(benchmark_type, IREE_SV("dispatch"))) { | 
|  | iree::RegisterDispatchBenchmark( | 
|  | std::string(function_name.data, function_name.size), context_.get(), | 
|  | function); | 
|  | } else if (iree_string_view_equal(benchmark_type, IREE_SV("entry"))) { | 
|  | iree::RegisterGenericBenchmark( | 
|  | std::string(function_name.data, function_name.size), device_.get(), | 
|  | context_.get(), function, | 
|  | /*inputs=*/nullptr); | 
|  | } else { | 
|  | // Pick up generic () -> () functions. | 
|  | if (iree_string_view_starts_with(function_name, | 
|  | iree_make_cstring_view("__")) || | 
|  | iree_string_view_find_char(function_name, '$', 0) != | 
|  | IREE_STRING_VIEW_NPOS) { | 
|  | // Skip internal or special functions. | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // Query function information to determine how to run it. | 
|  | iree_vm_function_signature_t signature = | 
|  | iree_vm_function_signature(&function); | 
|  | iree_host_size_t argument_count = 0; | 
|  | iree_host_size_t result_count = 0; | 
|  | IREE_RETURN_IF_ERROR(iree_vm_function_call_count_arguments_and_results( | 
|  | &signature, &argument_count, &result_count)); | 
|  | iree_string_view_t invocation_model = | 
|  | iree_vm_function_lookup_attr_by_name(&function, | 
|  | IREE_SV("iree.abi.model")); | 
|  | if (iree_string_view_equal(invocation_model, | 
|  | IREE_SV("coarse-fences"))) { | 
|  | // Asynchronous invocation with coarse fences. Expect just those. | 
|  | if (argument_count == 2) { | 
|  | // Only functions taking a (wait, signal) fence pair are run. | 
|  | iree::RegisterAsyncBenchmark( | 
|  | std::string(function_name.data, function_name.size), | 
|  | device_.get(), context_.get(), function, | 
|  | /*inputs=*/nullptr); | 
|  | } | 
|  | } else { | 
|  | // Basic synchronous invocation. | 
|  | if (argument_count == 0) { | 
|  | // Only functions with no inputs are run (because we can't pass | 
|  | // anything). | 
|  | iree::RegisterGenericBenchmark( | 
|  | std::string(function_name.data, function_name.size), | 
|  | device_.get(), context_.get(), function, | 
|  | /*inputs=*/nullptr); | 
|  | } | 
|  | } | 
|  | } | 
|  | } | 
|  | return iree_ok_status(); | 
|  | } | 
|  |  | 
|  | iree::vm::ref<iree_vm_instance_t> instance_; | 
|  | iree::vm::ref<iree_vm_context_t> context_; | 
|  | iree::vm::ref<iree_hal_device_t> device_; | 
|  | iree::vm::ref<iree_hal_allocator_t> device_allocator_; | 
|  | iree_tooling_module_list_t module_list_; | 
|  | iree::vm::ref<iree_vm_list_t> inputs_; | 
|  | }; | 
|  | }  // namespace | 
|  | }  // namespace iree | 
|  |  | 
|  | int main(int argc, char** argv) { | 
|  | IREE_TRACE_APP_ENTER(); | 
|  | IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree-benchmark-module"); | 
|  |  | 
|  | // Pass through flags to benchmark (allowing --help to fall through). | 
|  | iree_flags_set_usage( | 
|  | "iree-benchmark-module", | 
|  | "Benchmarks a function within a compiled IREE module and handles I/O\n" | 
|  | "parsing. Modules can be provided by file path (`--module=file.vmfb`)\n" | 
|  | "or read from stdin (`--module=-`) and the function to execute\n" | 
|  | "matches the original name provided to the compiler\n" | 
|  | "(`--function=foo` for `func.func @foo`).\n"); | 
|  | iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK | | 
|  | IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, | 
|  | &argc, &argv); | 
|  | ::benchmark::Initialize(&argc, argv); | 
|  |  | 
|  | iree::IREEBenchmark iree_benchmark; | 
|  | iree_status_t status = iree_benchmark.Register(); | 
|  | if (!iree_status_is_ok(status)) { | 
|  | int exit_code = static_cast<int>(iree_status_code(status)); | 
|  | printf("%s\n", iree::Status(std::move(status)).ToString().c_str()); | 
|  | IREE_TRACE_ZONE_END(z0); | 
|  | IREE_TRACE_APP_EXIT(exit_code); | 
|  | return exit_code; | 
|  | } | 
|  | IREE_CHECK_OK(iree_hal_begin_profiling_from_flags(iree_benchmark.device())); | 
|  | ::benchmark::RunSpecifiedBenchmarks(); | 
|  | IREE_CHECK_OK(iree_hal_end_profiling_from_flags(iree_benchmark.device())); | 
|  |  | 
|  | IREE_TRACE_ZONE_END(z0); | 
|  | IREE_TRACE_APP_EXIT(EXIT_SUCCESS); | 
|  | return EXIT_SUCCESS; | 
|  | } |