blob: 03c20d14a58e11711f078d3181a6cb29c2600b47 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <array>
#include <cstdio>
#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "iree/base/api.h"
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/flags.h"
#include "iree/base/status_cc.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/tooling/device_util.h"
#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/ref_cc.h"
// Path of the module to load; the default "-" makes the tool read it from
// stdin instead of a file.
IREE_FLAG(string, module_file, "-",
"File containing the module to load that contains the entry "
"function. Defaults to stdin.");
// Name of the exported function to invoke. Run() fails with
// IREE_STATUS_INVALID_ARGUMENT when this is left empty.
IREE_FLAG(string, entry_function, "",
"Name of a function contained in the module specified by module_file "
"to run.");
// When true the VM context is created with
// IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION.
IREE_FLAG(bool, trace_execution, false, "Traces VM execution to stderr.");
// Forwarded to PrintVariantList() to cap how many elements of each output are
// printed.
IREE_FLAG(int32_t, print_max_element_count, 1024,
"Prints up to the maximum number of elements of output tensors, "
"eliding the remainder.");
// When true, device allocator statistics are written to stderr after the
// invocation completes successfully.
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
// Parse callback for --function_input. Each occurrence of the flag appends its
// raw (unparsed) value to the std::vector<std::string> passed as |storage|, so
// the vector preserves command-line order. This callback never fails.
static iree_status_t parse_function_input(iree_string_view_t flag_name,
                                          void* storage,
                                          iree_string_view_t value) {
  auto& collected = *static_cast<std::vector<std::string>*>(storage);
  collected.emplace_back(value.data, value.size);
  return iree_ok_status();
}
// Print callback for --function_input used when dumping flag values. With no
// values a commented-out placeholder line is emitted. Otherwise one
// --flag="value" line is emitted per stored value, in insertion order.
static void print_function_input(iree_string_view_t flag_name, void* storage,
                                 FILE* file) {
  const auto& collected =
      *static_cast<const std::vector<std::string>*>(storage);
  const int name_length = static_cast<int>(flag_name.size);
  if (collected.empty()) {
    fprintf(file, "# --%.*s=\n", name_length, flag_name.data);
    return;
  }
  for (const std::string& entry : collected) {
    fprintf(file, "--%.*s=\"%s\"\n", name_length, flag_name.data,
            entry.c_str());
  }
}
// Backing storage for --function_input: one raw string per flag occurrence,
// kept in command-line order. It is filled by parse_function_input(), dumped
// by print_function_input(), and converted into VM values by Run() through
// ParseToVariantList().
static std::vector<std::string> FLAG_function_inputs;
IREE_FLAG_CALLBACK(
parse_function_input, print_function_input, &FLAG_function_inputs,
function_input,
"An input value or buffer of the format:\n"
" [shape]xtype=[value]\n"
" 2x2xi32=1 2 3 4\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
"Raw binary files can be read to provide buffer contents:\n"
" 2x2xi32=@some/file.bin\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
namespace iree {
namespace {
// Loads the module bytes named by --module_file into |out_contents|.
// A path of "-" reads from stdin, and a notice is printed to stdout first.
// Any other value is treated as a file path.
iree_status_t GetModuleContentsFromFlags(iree_file_contents_t** out_contents) {
  IREE_TRACE_SCOPE0("GetModuleContentsFromFlags");
  const std::string path(FLAG_module_file);
  if (path != "-") {
    return iree_file_read_contents(path.c_str(), iree_allocator_system(),
                                   out_contents);
  }
  std::cout << "Reading module contents from stdin...\n";
  return iree_stdin_read_contents(iree_allocator_system(), out_contents);
}
iree_status_t Run() {
IREE_TRACE_SCOPE0("iree-run-module");
IREE_RETURN_IF_ERROR(iree_hal_module_register_types(),
"registering HAL types");
iree_vm_instance_t* instance = nullptr;
IREE_RETURN_IF_ERROR(
iree_vm_instance_create(iree_allocator_system(), &instance),
"creating instance");
iree_file_contents_t* flatbuffer_contents = NULL;
IREE_RETURN_IF_ERROR(GetModuleContentsFromFlags(&flatbuffer_contents));
iree_vm_module_t* input_module = nullptr;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
flatbuffer_contents->const_buffer,
iree_file_contents_deallocator(flatbuffer_contents),
iree_allocator_system(), &input_module));
iree_hal_device_t* device = nullptr;
IREE_RETURN_IF_ERROR(iree_hal_create_device_from_flags(
iree_hal_default_device_uri(), iree_allocator_system(), &device));
iree_vm_module_t* hal_module = nullptr;
IREE_RETURN_IF_ERROR(
iree_hal_module_create(device, iree_allocator_system(), &hal_module));
iree_vm_context_t* context = nullptr;
// Order matters. The input module will likely be dependent on the hal module.
std::array<iree_vm_module_t*, 2> modules = {hal_module, input_module};
IREE_RETURN_IF_ERROR(
iree_vm_context_create_with_modules(
instance,
FLAG_trace_execution ? IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION
: IREE_VM_CONTEXT_FLAG_NONE,
modules.size(), modules.data(), iree_allocator_system(), &context),
"creating context");
std::string function_name = std::string(FLAG_entry_function);
iree_vm_function_t function;
if (function_name.empty()) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"no --entry_function= specified");
} else {
IREE_RETURN_IF_ERROR(
iree_vm_module_lookup_function_by_name(
input_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{function_name.data(), function_name.size()},
&function),
"looking up function '%s'", function_name.c_str());
}
vm::ref<iree_vm_list_t> inputs;
IREE_RETURN_IF_ERROR(ParseToVariantList(
iree_hal_device_allocator(device),
iree::span<const std::string>{FLAG_function_inputs.data(),
FLAG_function_inputs.size()},
&inputs));
vm::ref<iree_vm_list_t> outputs;
IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr, 16,
iree_allocator_system(), &outputs));
std::cout << "EXEC @" << function_name << "\n";
IREE_RETURN_IF_ERROR(
iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(), outputs.get(),
iree_allocator_system()),
"invoking function '%s'", function_name.c_str());
IREE_RETURN_IF_ERROR(
PrintVariantList(outputs.get(), (size_t)FLAG_print_max_element_count),
"printing results");
inputs.reset();
outputs.reset();
iree_vm_module_release(hal_module);
iree_vm_module_release(input_module);
iree_vm_context_release(context);
if (FLAG_print_statistics) {
IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
stderr, iree_hal_device_allocator(device)));
}
iree_hal_device_release(device);
iree_vm_instance_release(instance);
return iree_ok_status();
}
} // namespace
// Tool entry point. Parses IREE flags, rejects stray positional arguments,
// and runs the module, aborting via IREE_CHECK_OK on failure.
extern "C" int main(int argc, char** argv) {
  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
  if (argc > 1) {
    // Avoid iree-run-module spinning endlessly on stdin if the user uses single
    // dashes for flags: such flags are left as positional arguments, so
    // --module_file keeps its stdin default. The diagnostic goes to stderr so
    // it is not mixed into the tool's normal stdout output.
    std::cerr << "Error: unexpected positional argument (expected none)."
                 " Did you pass a flag with a single dash ('-')?"
                 " Use '--' instead.\n";
    return 1;
  }
  IREE_CHECK_OK(Run());
  return 0;
}
} // namespace iree