blob: ab1ac43a36fbe590a0be5afe704a000867c0f593 [file] [edit]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/run_module.h"
#include "iree/base/api.h"
#include "iree/base/tooling/flags.h"
#include "iree/hal/api.h"
#include "iree/hal/replay/recorder.h"
#include "iree/io/stdio_stream.h"
#include "iree/modules/hal/types.h"
#include "iree/tooling/comparison.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
#include "iree/tooling/function_io.h"
#include "iree/tooling/function_util.h"
#include "iree/tooling/instrument_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module.h"
IREE_FLAG(string, function, "",
"Name of a function contained in the module specified by --module= "
"to run.");
IREE_FLAG_LIST(
string, input,
"An input (a) value or (b) buffer of the format:\n"
" (a) scalar value\n"
" value\n"
" e.g.: --input=\"3.14\"\n"
" (b) buffer:\n"
" [shape]xtype=[value]\n"
" e.g.: --input=\"2x2xi32=1 2 3 4\"\n"
"Optionally, brackets may be used to separate the element values:\n"
" e.g.: --input=\"2x2xi32=[[1 2][3 4]]\"\n"
"Raw binary files can be read to provide buffer contents:\n"
" e.g.: --input=2x2xi32=@some/file.bin\n"
"\n"
"Numpy npy files from numpy.save can be read to provide 1+ values:\n"
" e.g.: --input=@some.npy\n"
"\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
IREE_FLAG_LIST(
string, output,
"Specifies how to handle an output from the invocation:\n"
" `` (empty): ignore output\n"
" e.g.: --output=\n"
" `-`: print textual form to stdout\n"
" e.g.: --output=-\n"
" `@file.npy`: create/overwrite a numpy npy file and write an ndarray\n"
" e.g.: --output=@file.npy\n"
" `+file.npy`: create/append a numpy npy file and write an ndarray\n"
" e.g.: --output=+file.npy\n"
" `@file.bin`: create/overwrite a binary file and write value contents\n"
" e.g.: --output=@file.bin\n"
" `+file.bin`: create/append a binary file and write value contents\n"
" e.g.: --output=+file.bin\n"
"\n"
"Numpy npy files can be read in Python using numpy.load, for example an\n"
"invocation producing two outputs can be concatenated as:\n"
" --output=@file.npy --output=+file.npy\n"
"And then loaded in Python by reading from the same file:\n"
" with open('file.npy', 'rb') as f:\n"
" print(numpy.load(f))\n"
" print(numpy.load(f))\n"
"Primitive values are written as shape=() ndarrays and buffers are\n"
"written as i8 arrays with the length of the buffer.\n"
"\n"
"Binary files contain only the contents of the values/buffers provided\n"
"without metadata; users must know the shape/type of the output.\n"
"\n"
"Each occurrence of the flag indicates an output in the order they were\n"
"specified on the command line.");
IREE_FLAG_LIST(
string, expected_output,
"An expected function output following the same format as `--input=`.\n"
"When present the results of the invocation will be compared against\n"
"these values and the tool will return non-zero if any differ. If the\n"
"value of a particular output is not of interest provide `(ignored)`.");
IREE_FLAG(
int32_t, output_max_element_count, 1024,
"Prints up to the maximum number of elements of output tensors and elides\n"
"the remainder.");
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
static iree_status_t iree_tooling_process_results(
iree_hal_device_t* device, iree_string_view_t results_cconv,
iree_vm_list_t* results, iree_io_stream_t* stream,
iree_allocator_t host_allocator, int* out_exit_code);
static iree_status_t iree_tooling_create_run_context(
iree_vm_instance_t* instance, iree_string_view_t default_device_uri,
iree_const_byte_span_t module_contents, iree_allocator_t host_allocator,
iree_vm_context_t** out_context, iree_vm_function_t* out_function,
iree_hal_device_t** out_device, iree_hal_allocator_t** out_device_allocator,
iree_hal_replay_recorder_t** out_replay_recorder) {
// Load all modules specified by --module= flags.
iree_tooling_module_list_t module_list;
iree_tooling_module_list_initialize(&module_list);
IREE_RETURN_IF_ERROR(iree_tooling_load_modules_from_flags(
instance, host_allocator, &module_list),
"loading modules and dependencies");
// Load the optional bytecode module from the provided flatbuffer data.
// Note that we do this after all other --module= flags are processed so that
// we ensure any dependent types are registered with the instance.
iree_status_t status = iree_ok_status();
if (!iree_const_byte_span_is_empty(module_contents)) {
iree_vm_module_t* module = NULL;
status = iree_status_annotate_f(
iree_vm_bytecode_module_create(
instance, IREE_VM_BYTECODE_MODULE_FLAG_NONE, module_contents,
iree_allocator_null(), host_allocator, &module),
"loading custom bytecode module from memory");
if (iree_status_is_ok(status)) {
status = iree_tooling_module_list_push_back(&module_list, module);
}
iree_vm_module_release(module); // Now tracked in module_list.
}
if (!iree_status_is_ok(status)) {
iree_tooling_module_list_reset(&module_list);
return status;
}
// There's concept of a "main module" in the VM but here we use it to allow
// for a shorthand --function= flag that doesn't need the module name.
iree_vm_module_t* main_module = iree_tooling_module_list_back(&module_list);
if (!main_module) {
status = iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"no user module specified; use --module=file.vmfb to load from a "
"file or --module=- to load from stdin");
}
// Create the VM context with all of the modules. Dependent modules will be
// loaded (like the HAL) and special things like the HAL device and allocator
// are returned for convenience. Note that not all programs need the HAL.
iree_vm_context_t* context = NULL;
iree_hal_device_t* device = NULL;
iree_hal_allocator_t* device_allocator = NULL;
iree_hal_replay_recorder_t* replay_recorder = NULL;
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_tooling_create_context_from_flags(
instance, module_list.count, module_list.values, default_device_uri,
host_allocator, &context, &device, &device_allocator,
&replay_recorder),
"creating VM context");
}
iree_tooling_module_list_reset(&module_list);
if (!iree_status_is_ok(status)) {
return status;
}
// Choose which function to run - either the one specified in the flag or the
// only exported non-internal function.
iree_vm_function_t function = {0};
if (strlen(FLAG_function) == 0) {
status = iree_tooling_find_single_exported_function(main_module, &function);
} else {
status = iree_status_annotate_f(
iree_vm_module_lookup_function_by_name(
main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_make_cstring_view(FLAG_function), &function),
"looking up function '%s'", FLAG_function);
}
if (iree_status_is_ok(status)) {
*out_context = context;
*out_function = function;
*out_device = device;
*out_device_allocator = device_allocator;
*out_replay_recorder = replay_recorder;
} else {
iree_vm_context_release(context);
if (replay_recorder) {
status = iree_status_join(
status, iree_status_annotate_f(
iree_hal_replay_recorder_close(replay_recorder),
"closing HAL replay capture"));
iree_hal_replay_recorder_release(replay_recorder);
}
iree_hal_allocator_release(device_allocator);
iree_hal_device_release(device);
}
return status;
}
static iree_status_t iree_tooling_annotate_status_with_function_decl(
iree_status_t base_status, iree_vm_function_t function) {
iree_string_view_t decl = iree_vm_function_lookup_attr_by_name(
&function, IREE_SV("iree.abi.declaration"));
if (!iree_string_view_is_empty(decl)) {
return iree_status_annotate_f(base_status, "`%.*s`", (int)decl.size,
decl.data);
}
return base_status;
}
static iree_status_t iree_tooling_run_function(
iree_vm_context_t* context, iree_vm_function_t function,
iree_hal_device_t* device, iree_hal_allocator_t* device_allocator,
iree_hal_replay_recorder_t* replay_recorder,
iree_allocator_t host_allocator, int* out_exit_code) {
iree_string_view_t function_name = iree_vm_function_name(&function);
(void)function_name;
iree_vm_function_signature_t signature =
iree_vm_function_signature(&function);
iree_string_view_t arguments_cconv, results_cconv;
iree_status_t status = iree_vm_function_call_get_cconv_fragments(
&signature, &arguments_cconv, &results_cconv);
bool replay_execute_scope_open = false;
if (iree_status_is_ok(status) && replay_recorder) {
status = iree_hal_replay_recorder_scope_begin(replay_recorder,
IREE_SV("execute"));
replay_execute_scope_open = iree_status_is_ok(status);
}
// Parse --input= values into device buffers.
iree_vm_list_t* inputs = NULL;
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_tooling_parse_variants(arguments_cconv, FLAG_input_list(), device,
device_allocator, host_allocator, &inputs),
"parsing function inputs");
}
// If the function is async add fences so we can invoke it synchronously.
iree_hal_fence_t* finish_fence = NULL;
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_tooling_append_async_fences(inputs, function, device,
/*wait_fence=*/NULL, &finish_fence),
"setting up async-external fence inputs");
}
// Empty output list to be populated by the invocation.
iree_vm_list_t* outputs = NULL;
if (iree_status_is_ok(status)) {
status = iree_vm_list_create(iree_vm_make_undefined_type_def(), 16,
host_allocator, &outputs);
}
// TODO(benvanik): move behind a --verbose flag and add more logging.
if (iree_status_is_ok(status)) {
fprintf(stdout, "EXEC @%.*s\n", (int)function_name.size,
function_name.data);
fflush(stdout);
}
// Begin profiling immediately prior to invocation.
iree_hal_profiling_from_flags_t* profiling = NULL;
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_hal_begin_profiling_from_flags(device, host_allocator, &profiling),
"beginning device profiling");
}
// Invoke the function with the provided inputs.
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/NULL, inputs, outputs, host_allocator),
"invoking function '%.*s'", (int)function_name.size,
function_name.data);
}
iree_vm_list_release(inputs);
// If the function is async we need to wait for it to complete.
if (iree_status_is_ok(status) && finish_fence) {
status = iree_status_annotate_f(
iree_hal_fence_wait(finish_fence, iree_infinite_timeout(),
IREE_ASYNC_WAIT_FLAG_NONE),
"waiting on finish fence");
}
iree_hal_fence_release(finish_fence);
// End profiling after waiting for the invocation to finish.
if (profiling) {
status = iree_status_join(
status,
iree_status_annotate_f(iree_hal_end_profiling_from_flags(profiling),
"ending device profiling"));
}
// Grab any instrumentation data present in the context and write it to disk.
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_tooling_process_instrument_data(context, host_allocator),
"processing instrument data");
}
// Transfer outputs to the host so they can be processed. Only required when
// using full HAL device-based execution.
if (iree_status_is_ok(status) && device != NULL) {
iree_hal_buffer_params_t target_params = {
.usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
.access = IREE_HAL_MEMORY_ACCESS_ALL,
.type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
.queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY,
.min_alignment = 0,
};
status = iree_tooling_transfer_variants(
outputs, device, device_allocator, target_params,
/*wait_fence=*/NULL, /*signal_fence=*/NULL);
}
// Wrap stdout for printing results.
iree_io_stream_t* stdout_stream = NULL;
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_io_stdio_stream_wrap(IREE_IO_STREAM_MODE_WRITABLE, stdout,
/*owns_handle=*/false, host_allocator,
&stdout_stream),
"opening stdout stream");
}
// Handle either printing/writing the outputs or checking them against
// expected values (basic pass/fail testing).
if (iree_status_is_ok(status)) {
status = iree_status_annotate_f(
iree_tooling_process_results(device, results_cconv, outputs,
stdout_stream, host_allocator,
out_exit_code),
"processing function outputs");
}
iree_vm_list_release(outputs);
iree_io_stream_release(stdout_stream);
fflush(stdout);
if (replay_execute_scope_open) {
status = iree_status_join(status, iree_hal_replay_recorder_scope_end(
replay_recorder, IREE_SV("execute")));
}
return status;
}
static iree_status_t iree_tooling_process_results(
iree_hal_device_t* device, iree_string_view_t results_cconv,
iree_vm_list_t* results, iree_io_stream_t* stream,
iree_allocator_t host_allocator, int* out_exit_code) {
*out_exit_code = EXIT_SUCCESS;
// Basic output handling to route to the console or files.
if (FLAG_expected_output_list().count == 0) {
if (FLAG_output_list().count == 0) {
// Print all outputs.
return iree_status_annotate_f(
iree_tooling_print_variants(
IREE_SV("result"), results,
(iree_host_size_t)FLAG_output_max_element_count, stream,
host_allocator),
"printing results");
} else {
// Write (or ignore) all outputs.
return iree_status_annotate_f(
iree_tooling_write_variants(
results, FLAG_output_list(),
(iree_host_size_t)FLAG_output_max_element_count, stream,
host_allocator),
"outputting results");
}
}
// Compare against contents in host-local memory. This avoids polluting
// device memory statistics.
iree_hal_allocator_t* heap_allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
IREE_SV("heap"), host_allocator, host_allocator, &heap_allocator));
// Parse expected list into host-local memory that we can easily access.
iree_vm_list_t* expected_list = NULL;
iree_status_t status = iree_status_annotate_f(
iree_tooling_parse_variants(results_cconv, FLAG_expected_output_list(),
device, heap_allocator, host_allocator,
&expected_list),
"parsing expected function outputs");
// Compare expected vs actual lists and output diffs.
if (iree_status_is_ok(status)) {
bool did_match = iree_tooling_compare_variant_lists(expected_list, results,
host_allocator, stdout);
if (did_match) {
fprintf(
stdout,
"[SUCCESS] all function outputs matched their expected values.\n");
}
// Exit code 0 if all results matched the expected values.
*out_exit_code = did_match ? EXIT_SUCCESS : EXIT_FAILURE;
}
iree_vm_list_release(expected_list);
iree_hal_allocator_release(heap_allocator);
return status;
}
iree_status_t iree_tooling_run_module_from_flags(
iree_vm_instance_t* instance, iree_allocator_t host_allocator,
int* out_exit_code) {
return iree_tooling_run_module_with_data(instance, iree_string_view_empty(),
iree_const_byte_span_empty(),
host_allocator, out_exit_code);
}
iree_status_t iree_tooling_run_module_with_data(
iree_vm_instance_t* instance, iree_string_view_t default_device_uri,
iree_const_byte_span_t module_contents, iree_allocator_t host_allocator,
int* out_exit_code) {
IREE_TRACE_ZONE_BEGIN(z0);
// Setup the VM context with all required modules and get the function to run.
// This also returns the HAL device and allocator (if any) for I/O handling.
iree_vm_context_t* context = NULL;
iree_vm_function_t function = {0};
iree_hal_device_t* device = NULL;
iree_hal_allocator_t* device_allocator = NULL;
iree_hal_replay_recorder_t* replay_recorder = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_tooling_create_run_context(
instance, default_device_uri, module_contents, host_allocator,
&context, &function, &device, &device_allocator, &replay_recorder),
"creating run context");
// Parse inputs, run the function, and process outputs.
iree_status_t status =
iree_tooling_run_function(context, function, device, device_allocator,
replay_recorder, host_allocator, out_exit_code);
// Annotate errors with the function description.
if (!iree_status_is_ok(status)) {
status = iree_tooling_annotate_status_with_function_decl(status, function);
}
bool replay_deinit_scope_open = false;
if (replay_recorder) {
iree_status_t begin_status = iree_hal_replay_recorder_scope_begin(
replay_recorder, IREE_SV("deinit"));
replay_deinit_scope_open = iree_status_is_ok(begin_status);
status = iree_status_join(status, begin_status);
}
// Release the context and all retained resources (variables, constants, etc).
iree_vm_context_release(context);
if (replay_deinit_scope_open) {
status = iree_status_join(status, iree_hal_replay_recorder_scope_end(
replay_recorder, IREE_SV("deinit")));
}
if (replay_recorder) {
status = iree_status_join(
status,
iree_status_annotate_f(iree_hal_replay_recorder_close(replay_recorder),
"closing HAL replay capture"));
iree_hal_replay_recorder_release(replay_recorder);
}
// Print statistics after we've released the inputs/outputs and the context
// which may be holding on to resources like constants/variables.
if (device_allocator && FLAG_print_statistics) {
status = iree_status_join(
status, iree_hal_allocator_statistics_fprint(stderr, device_allocator));
}
iree_hal_allocator_release(device_allocator);
iree_hal_device_release(device);
IREE_TRACE_ZONE_END(z0);
return status;
}