blob: d4b8a12571757514820851ffe91f9c174f2c22eb [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_replace.h"
#include "third_party/absl/strings/str_split.h"
#include "third_party/absl/strings/string_view.h"
#include "third_party/absl/types/source_location.h"
#include "third_party/mlir_edge/iree/base/file_io.h"
#include "third_party/mlir_edge/iree/base/init.h"
#include "third_party/mlir_edge/iree/base/status.h"
#include "third_party/mlir_edge/iree/hal/buffer_view_string_util.h"
#include "third_party/mlir_edge/iree/hal/driver_registry.h"
#include "third_party/mlir_edge/iree/schemas/module_def_generated.h"
#include "third_party/mlir_edge/iree/vm/bytecode_printer.h"
#include "third_party/mlir_edge/iree/vm/bytecode_tables_sequencer.h"
#include "third_party/mlir_edge/iree/vm/debug/debug_server_flags.h"
#include "third_party/mlir_edge/iree/vm/fiber_state.h"
#include "third_party/mlir_edge/iree/vm/function.h"
#include "third_party/mlir_edge/iree/vm/instance.h"
#include "third_party/mlir_edge/iree/vm/module.h"
#include "third_party/mlir_edge/iree/vm/module_printer.h"
#include "third_party/mlir_edge/iree/vm/sequencer_context.h"
// Path of the module file to load; it provides the function that is run.
ABSL_FLAG(std::string, main_module, "", "Main module with entry point.");
// Name of the exported function to invoke. May be left empty when the main
// module exports exactly one function.
ABSL_FLAG(std::string, main_function, "",
          "Function within the main module to execute.");
// When set, the module dump printed before execution includes source mapping.
ABSL_FLAG(bool, print_source_info, false,
          "Prints source map information in bytecode output.");
// Inputs given inline on the command line; takes precedence over --input_file.
ABSL_FLAG(std::string, input_values, "",
          "Input shapes and optional values, one per line as "
          "[shape]xtype=[value] (e.g. 4x4xi8=0,1,2,3); a literal \\n may be "
          "used as the line separator.");
// Inputs read from a file, one per line in the --input_values format.
ABSL_FLAG(std::string, input_file, "",
          "Input shapes and optional values serialized in a file, one per "
          "line in the same format as --input_values. Ignored if "
          "--input_values is set.");
// Print mode per result; when non-empty it must list exactly one type per
// result returned by the invoked function.
ABSL_FLAG(std::string, output_types, "",
          "Output data types (comma or space delimited list of b/i/u/f for "
          "binary/signed int/unsigned int/float).");
namespace iree {
namespace vm {
namespace {

using ::iree::hal::BufferView;

// Parses a list of input buffer views from --input_values= or, when that flag
// is empty, from the file named by --input_file=.
//
// The contents are expected to hold one value per line, each written as
//   [shape]xtype=[value]
// Example:
//   4x4xi8=0,1,2,3
// In --input_values= the two-character sequence "\n" is replaced by a real
// newline so several values can be passed on one command line. Lines that are
// empty or whitespace-only are skipped.
//
// Returns an empty list when neither flag is set. File read errors and parse
// errors from hal::ParseBufferViewFromString are propagated to the caller.
StatusOr<std::vector<BufferView>> ParseInputsFromFlags(
    hal::Allocator* allocator) {
  const std::string input_values = absl::GetFlag(FLAGS_input_values);
  const std::string input_file = absl::GetFlag(FLAGS_input_file);

  std::string file_contents;
  if (!input_values.empty()) {
    if (!input_file.empty()) {
      // --input_values= takes precedence; say so rather than silently
      // dropping the file the user asked for.
      LOG(WARNING) << "--input_file=" << input_file
                   << " ignored because --input_values= is set";
    }
    file_contents = absl::StrReplaceAll(input_values, {{"\\n", "\n"}});
  } else if (!input_file.empty()) {
    ASSIGN_OR_RETURN(file_contents, file_io::GetFileContents(input_file));
  }

  std::vector<BufferView> inputs;
  for (const auto& line :
       absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) {
    ASSIGN_OR_RETURN(auto input,
                     hal::ParseBufferViewFromString(line, allocator));
    // The parsed view is not used again, so hand it to the vector instead of
    // copying it.
    inputs.push_back(std::move(input));
  }
  return inputs;
}

}  // namespace

// Loads the module named by --main_module= into a sequencer context backed by
// the "interpreter" HAL driver and prints every registered module to stdout.
// It then invokes the selected function with inputs from
// ParseInputsFromFlags() and prints each result to stdout.
//
// The function to run is --main_function= when set. Otherwise the main module
// must export exactly one function, which is used. If --output_types= is
// given, it must name one print mode per result.
//
// Returns the first error encountered, or OK once all results are printed.
Status Run() {
  ASSIGN_OR_RETURN(auto debug_server, debug::CreateDebugServerFromFlags());
  auto instance = std::make_shared<Instance>(std::move(debug_server));
  ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
                                    "interpreter"));
  ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
  RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
  SequencerContext context(instance);

  // Load main module.
  ASSIGN_OR_RETURN(
      auto main_module_file,
      ModuleFile::LoadFile(ModuleDefIdentifier(),
                           absl::GetFlag(FLAGS_main_module)),
      _ << "while loading module file " << absl::GetFlag(FLAGS_main_module));
  ASSIGN_OR_RETURN(auto main_module,
                   Module::FromFile(std::move(main_module_file)));

  // Add native functions for use by the module.
  RETURN_IF_ERROR(context.RegisterNativeFunction(
      "fabsf",
      [](Stack* stack, absl::Span<const BufferView> args,
         absl::Span<BufferView> results) -> Status {
        // TODO(benvanik): example native functions.
        LOG(INFO) << "fabsf";
        return OkStatus();
      }));

  // Register the main module with the context.
  // We could add additional modules (specializations, shared libraries, etc).
  // ModuleFiles are stateless so we could have the same module_file used by
  // multiple contexts simultaneously.
  // Keep a raw pointer before ownership moves into the context; it is used
  // below to inspect the module's exports.
  auto* main_module_ptr = main_module.get();
  RETURN_IF_ERROR(context.RegisterModule(std::move(main_module)));

  // Dump the registered modules.
  PrintModuleFlagBitfield print_flags = PrintModuleFlag::kNone;
  if (absl::GetFlag(FLAGS_print_source_info)) {
    print_flags |= PrintModuleFlag::kIncludeSourceMapping;
  }
  for (const auto& module : context.modules()) {
    RETURN_IF_ERROR(PrintModuleToStream(sequencer_opcode_table(), *module,
                                        print_flags, &std::cout));
  }

  // Setup a new fiber.
  FiberState fiber_state(instance);

  // Resolve the function to call.
  Function main_function;
  if (!absl::GetFlag(FLAGS_main_function).empty()) {
    // User-specified main function.
    ASSIGN_OR_RETURN(main_function,
                     context.LookupExport(absl::GetFlag(FLAGS_main_function)));
  } else {
    // No main function specified; to prevent non-deterministic behavior we
    // require one unless there's exactly one exported function in the module.
    auto* exports = main_module_ptr->function_table().def().exports();
    if (exports && exports->size() == 1) {
      ASSIGN_OR_RETURN(
          main_function,
          main_module_ptr->function_table().LookupFunction(exports->Get(0)));
    } else {
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "--main_function= must be specified to disambiguate the "
                "function to run";
    }
  }

  // Setup arguments and storage for results.
  ASSIGN_OR_RETURN(std::vector<BufferView> args,
                   ParseInputsFromFlags(device->allocator()));
  std::vector<BufferView> results;
  results.resize(main_function.result_count());

  // Call into the main function.
  RETURN_IF_ERROR(context.Invoke(&fiber_state, main_function,
                                 absl::MakeSpan(args),
                                 absl::MakeSpan(results)));

  // Dump all results to stdout.
  std::vector<std::string> output_types =
      absl::StrSplit(absl::GetFlag(FLAGS_output_types),
                     absl::delimiter::AnyOf(", "), absl::SkipWhitespace());
  if (!output_types.empty() && output_types.size() != results.size()) {
    return InvalidArgumentErrorBuilder(ABSL_LOC)
           << "--output_types= specified but has " << output_types.size()
           << " types when the function returns " << results.size();
  }
  for (size_t i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    // Reject a missing buffer before formatting the view:
    // PrintBufferViewToString presumably reads the buffer contents, so it
    // must not be handed a view without one.
    const auto& buffer = result.buffer;
    if (!buffer) {
      return InternalErrorBuilder(ABSL_LOC)
             << "result[" << i << "] unexpectedly has no buffer";
    }
    auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
    if (!output_types.empty()) {
      ASSIGN_OR_RETURN(print_mode,
                       hal::ParseBufferViewPrintMode(output_types[i]));
    }
    // 1024 is the output size limit given to the printer (presumably a
    // maximum element count; confirm against buffer_view_string_util.h).
    ASSIGN_OR_RETURN(auto result_str,
                     PrintBufferViewToString(result, print_mode, 1024));
    LOG(INFO) << "result[" << i << "]: " << buffer->DebugString();
    std::cout << result_str << "\n";
  }
  return OkStatus();
}

namespace {

// Body of the program entry point. It lives inside iree::vm so that
// InitializeEnvironment and Run resolve exactly as they did when main was
// declared in this namespace. CHECK_OK is expected to terminate the process
// when Run() fails.
int RunMain(int argc, char** argv) {
  InitializeEnvironment(&argc, &argv);
  CHECK_OK(Run());
  return 0;
}

}  // namespace
}  // namespace vm
}  // namespace iree

// Program entry point. main must be declared at global scope without a
// linkage specification. Declaring it extern "C" inside a namespace is
// ill-formed, so it only forwards to the namespaced implementation.
int main(int argc, char** argv) { return ::iree::vm::RunMain(argc, argv); }