blob: 9ba5e84a220590e2b365fdc147bd059084db7eb5 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "vm/bytecode_module.h"
#include "absl/memory/memory.h"
#include "base/status.h"
#include "base/tracing.h"
#include "hal/buffer_view.h"
#include "vm/bytecode_disassembler.h"
namespace iree {
namespace vm {
namespace {
using ::iree::hal::BufferView;
using ::iree::rt::Function;
using ::iree::rt::FunctionSignature;
using ::iree::rt::Module;
using ::iree::rt::ModuleSignature;
// Checks that |element_bit_width| matches the bit width declared by
// |expected_element_type|.
//
// Only FloatTypeDef and IntegerTypeDef declare a width. Every other union
// member (UnknownTypeDef, NONE, or anything not listed here) is rejected as an
// unsupported element type.
//
// Returns OkStatus on a match, otherwise InvalidArgument.
Status ValidateElementSize(int element_bit_width,
                           const ElementTypeDef& expected_element_type) {
  const auto type_kind = expected_element_type.type_union_type();

  if (type_kind == ElementTypeDefUnion::FloatTypeDef) {
    const auto float_width =
        expected_element_type.type_union_as_FloatTypeDef()->width();
    if (element_bit_width == float_width) return OkStatus();
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Has element bit width " << element_bit_width
           << " but expected " << float_width;
  }

  if (type_kind == ElementTypeDefUnion::IntegerTypeDef) {
    const auto integer_width =
        expected_element_type.type_union_as_IntegerTypeDef()->width();
    if (element_bit_width == integer_width) return OkStatus();
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Has element bit width " << element_bit_width
           << " but expected " << integer_width;
  }

  // No declared width to compare against.
  return InvalidArgumentErrorBuilder(IREE_LOC)
         << "Defined type has unsupported element type "
         << EnumNameElementTypeDefUnion(type_kind);
}
// Structural validation hook for a single FunctionTypeDef.
//
// NOTE(review): this is currently a no-op that accepts every type; none of
// |type_def|'s fields are inspected yet.
Status ValidateTypeStructure(const FunctionTypeDef& type_def) {
  (void)type_def;  // Not inspected yet; see note above.
  return OkStatus();
}
// Validates the structure of |function_table_def| so that later accessors can
// dereference its contents without further checks.
//
// Verifies that:
//  - the function listing is present;
//  - every function entry is present and has a type (itself passed through
//    ValidateTypeStructure);
//  - every import and export entry refers to an ordinal inside the function
//    listing, and the referenced function has a name. Imports need a name so
//    they can be resolved; exports need one so others can look them up.
//
// Returns InvalidArgument describing the first violation found.
Status ValidateFunctionTableStructure(
    const FunctionTableDef& function_table_def) {
  if (!function_table_def.functions()) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Function table is missing the function listing";
  }

  // All functions must contain a valid type.
  const auto& functions = *function_table_def.functions();
  const int function_count = static_cast<int>(functions.size());
  for (int i = 0; i < function_count; ++i) {
    const auto* function = functions[i];
    if (!function) {
      return InvalidArgumentErrorBuilder(IREE_LOC)
             << "Function ordinal " << i << " is missing its contents";
    }
    if (!function->type()) {
      return InvalidArgumentErrorBuilder(IREE_LOC)
             << "Function ordinal " << i << " is missing its type";
    }
    RETURN_IF_ERROR(ValidateTypeStructure(*function->type()));
  }

  // Imports must also have a name (that we can use to resolve it). The index
  // comes straight from the serialized module, so range-check it before using
  // it to index the function listing.
  if (function_table_def.imports()) {
    const auto& imports = *function_table_def.imports();
    for (int i = 0; i < static_cast<int>(imports.size()); ++i) {
      const int function_index = imports[i];
      if (function_index < 0 || function_index >= function_count) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Import ordinal " << i << " references function ordinal "
               << function_index << " outside the valid range [0, "
               << function_count << ")";
      }
      if (!functions[function_index]->name()) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Import ordinal " << i << " is missing its contents";
      }
    }
  }

  // Exports must also have a name (that others will use to look it up). As
  // with imports, the index is untrusted and must be range-checked first.
  if (function_table_def.exports()) {
    const auto& exports = *function_table_def.exports();
    for (int i = 0; i < static_cast<int>(exports.size()); ++i) {
      const int function_index = exports[i];
      if (function_index < 0 || function_index >= function_count) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Export ordinal " << i << " references function ordinal "
               << function_index << " outside the valid range [0, "
               << function_count << ")";
      }
      if (!functions[function_index]->name()) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Export ordinal " << i << " is missing its contents";
      }
    }
  }

  return OkStatus();
}
// Validates the structure of |executable_table_def|.
//
// The multi-arch executable list is optional; when it is present, each entry
// must exist and hold at least one device-specific executable.
//
// Returns InvalidArgument naming the first offending ordinal.
Status ValidateExecutableTableStructure(
    const ExecutableTableDef& executable_table_def) {
  const auto* multi_arch_executables =
      executable_table_def.multi_arch_executables();
  // May have sequencer only fns. Fine to not have dispatchable executables.
  if (multi_arch_executables == nullptr) return OkStatus();

  int ordinal = 0;
  for (const auto* multi_arch_executable : *multi_arch_executables) {
    const bool has_contents =
        multi_arch_executable != nullptr &&
        multi_arch_executable->executables() != nullptr &&
        multi_arch_executable->executables()->size() != 0;
    if (!has_contents) {
      return InvalidArgumentErrorBuilder(IREE_LOC)
             << "Multi-arch executable ordinal " << ordinal
             << " is missing its contents";
    }
    ++ordinal;
  }
  return OkStatus();
}
} // namespace
// static
// Validates the top-level structure of |module_def|.
//
// Both the function table and the executable table are required. The
// function table is validated before the executable table is inspected.
//
// Returns InvalidArgument for the first missing or malformed piece.
Status BytecodeModule::ValidateStructure(const ModuleDef& module_def) {
  IREE_TRACE_SCOPE0("BytecodeModule::ValidateStructure");

  const auto* function_table = module_def.function_table();
  if (!function_table) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "ModuleDef is missing a function table";
  }
  RETURN_IF_ERROR(ValidateFunctionTableStructure(*function_table));

  const auto* executable_table = module_def.executable_table();
  if (!executable_table) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "ModuleDef is missing an executable table";
  }
  RETURN_IF_ERROR(ValidateExecutableTableStructure(*executable_table));

  return OkStatus();
}
// Takes ownership of |module_file| and exposes its root ModuleDef as
// module_def_. It also builds a SourceMapResolver from that ModuleDef, and a
// BytecodeDisassembler from |opcode_table|.
//
// NOTE(review): the initializers depend on one another. module_def_ reads
// module_file_, and source_resolver_ reads module_def_. This is only correct
// if those members are declared in that order in the header; confirm there.
// The ModuleDef is not validated here. Presumably callers run
// ValidateStructure first (TODO confirm).
BytecodeModule::BytecodeModule(std::unique_ptr<ModuleFile> module_file,
OpcodeTable opcode_table)
: module_file_(std::move(module_file)),
module_def_(*module_file_->root()),
source_resolver_(SourceMapResolver::FromModule(module_def_)),
disassembler_(absl::make_unique<BytecodeDisassembler>(opcode_table)) {}
// Defaulted out of line. Presumably this lets the header forward-declare
// owned member types such as BytecodeDisassembler (TODO confirm).
BytecodeModule::~BytecodeModule() = default;
// Returns the module signature. It contains the import, export, and internal
// function counts, followed by a trailing 0.
//
// The import and export lists are optional in the FlatBuffer (structural
// validation only checks them when present), so an absent list counts as 0
// instead of being dereferenced.
const ModuleSignature BytecodeModule::signature() const {
  const auto* imports = function_table_def().imports();
  const auto* exports = function_table_def().exports();
  return ModuleSignature(imports ? imports->size() : 0,
                         exports ? exports->size() : 0,
                         function_table_def().functions()->size(), 0);
}
// Returns a short, human-readable description of the module: its name.
std::string BytecodeModule::DebugStringShort() const {
  std::string short_form(name());
  return short_form;
}
// Maps an |ordinal| with the given |linkage| to an ordinal in the internal
// function table.
//
// - kImport: |ordinal| indexes the import list, which stores internal ordinals.
// - kExport: |ordinal| indexes the export list, which stores internal ordinals.
// - any other linkage: |ordinal| is already an internal ordinal.
//
// The import and export lists are optional in the FlatBuffer. A missing list
// is treated as empty, so every ordinal is out of range, rather than being
// dereferenced.
//
// Returns InvalidArgument for an out-of-range import/export ordinal. Returns
// OutOfRange if the resulting internal ordinal is outside the function table.
StatusOr<int32_t> BytecodeModule::MapFunctionOrdinal(Function::Linkage linkage,
                                                     int32_t ordinal) const {
  const auto& function_table = function_table_def();
  switch (linkage) {
    case Function::Linkage::kImport: {
      const auto* imports = function_table.imports();
      const int import_count = imports ? static_cast<int>(imports->size()) : 0;
      if (ordinal < 0 || ordinal >= import_count) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Import ordinal " << ordinal
               << " is outside the valid range [0, " << import_count << ")";
      }
      ordinal = imports->Get(ordinal);
      break;
    }
    case Function::Linkage::kExport: {
      const auto* exports = function_table.exports();
      const int export_count = exports ? static_cast<int>(exports->size()) : 0;
      if (ordinal < 0 || ordinal >= export_count) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Export ordinal " << ordinal
               << " is outside the valid range [0, " << export_count << ")";
      }
      ordinal = exports->Get(ordinal);
      break;
    }
    default:
      break;
  }
  const int function_count =
      static_cast<int>(function_table.functions()->size());
  if (ordinal < 0 || ordinal >= function_count) {
    return OutOfRangeErrorBuilder(IREE_LOC)
           << "Function ordinal " << ordinal
           << " is outside the valid range [0, " << function_count << ")";
  }
  return ordinal;
}
// Returns a Function handle for |ordinal| under |linkage|.
//
// The ordinal is first resolved through MapFunctionOrdinal, so the returned
// handle always uses internal linkage. Propagates any mapping error.
StatusOr<const Function> BytecodeModule::LookupFunctionByOrdinal(
    Function::Linkage linkage, int32_t ordinal) const {
  ASSIGN_OR_RETURN(int32_t internal_ordinal,
                   MapFunctionOrdinal(linkage, ordinal));
  return Function(this, Function::Linkage::kInternal, internal_ordinal);
}
// Finds the first function in the internal function table whose name equals
// |name>, and returns it with internal linkage.
//
// |linkage| is not consulted; every function in the table is searched.
// Returns NotFound if no function matches, which is also the result when names
// have been stripped from the module.
StatusOr<const Function> BytecodeModule::LookupFunctionByName(
    Function::Linkage linkage, absl::string_view name) const {
  const auto& function_defs = *function_table_def().functions();
  int ordinal = 0;
  for (const auto* function_def : function_defs) {
    if (WrapString(function_def->name()) == name) {
      return LookupFunctionByOrdinal(Function::Linkage::kInternal, ordinal);
    }
    ++ordinal;
  }
  return NotFoundErrorBuilder(IREE_LOC)
         << "Function '" << name
         << "' not found in function table (or names have been stripped)";
}
// Returns the name of the function at |ordinal| under |linkage|.
//
// The returned view refers to storage owned by the module's FlatBuffer.
// Propagates any ordinal-mapping error.
StatusOr<absl::string_view> BytecodeModule::GetFunctionName(
    Function::Linkage linkage, int32_t ordinal) const {
  ASSIGN_OR_RETURN(int32_t internal_ordinal,
                   MapFunctionOrdinal(linkage, ordinal));
  return WrapString(
      function_table_def().functions()->Get(internal_ordinal)->name());
}
// Returns the signature (input count, result count) of the function at
// |ordinal| under |linkage|.
//
// A missing inputs or results list counts as zero entries. Propagates any
// ordinal-mapping error.
StatusOr<const FunctionSignature> BytecodeModule::GetFunctionSignature(
    Function::Linkage linkage, int32_t ordinal) const {
  ASSIGN_OR_RETURN(int32_t internal_ordinal,
                   MapFunctionOrdinal(linkage, ordinal));
  const auto* type_def =
      function_table_def().functions()->Get(internal_ordinal)->type();
  const auto* inputs = type_def->inputs();
  const auto* results = type_def->results();
  return FunctionSignature(inputs ? inputs->size() : 0,
                           results ? results->size() : 0);
}
// Returns the FunctionDef at |ordinal| under |linkage|.
//
// MapFunctionOrdinal already range-checks the internal ordinal. The explicit
// check here is kept as a second guard and returns OutOfRange on failure.
StatusOr<const FunctionDef*> BytecodeModule::GetFunctionDef(
    rt::Function::Linkage linkage, int32_t ordinal) const {
  ASSIGN_OR_RETURN(int32_t internal_ordinal,
                   MapFunctionOrdinal(linkage, ordinal));
  const auto& function_defs = *function_table_def().functions();
  if (internal_ordinal < function_defs.size()) {
    return function_defs.Get(internal_ordinal);
  }
  return OutOfRangeErrorBuilder(IREE_LOC)
         << "Internal function ordinal " << internal_ordinal
         << " out of range of table (" << function_defs.size() << ")";
}
// Returns the multi-arch executable at |executable_ordinal|.
//
// The multi-arch executable list is optional; modules with only sequencer
// functions may omit it, and ValidateStructure accepts that. A missing list
// is treated as empty rather than dereferenced.
//
// Returns InvalidArgument if the ordinal is out of range.
StatusOr<const MultiArchExecutableDef*>
BytecodeModule::LookupMultiArchExecutable(int executable_ordinal) const {
  const auto* executables = executable_table_def().multi_arch_executables();
  const int executable_count =
      executables ? static_cast<int>(executables->size()) : 0;
  if (executable_ordinal < 0 || executable_ordinal >= executable_count) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Invalid multi-arch executable ordinal " << executable_ordinal;
  }
  return executables->Get(executable_ordinal);
}
// static
// Validates that buffer view |arg| matches |expected_type| in element bit
// width, rank, and every dimension size.
//
// FlatBuffers returns null for absent fields. A missing element type is
// reported as InvalidArgument instead of being dereferenced. A missing shape
// vector is treated as rank 0.
// NOTE(review): confirm that the serializer omits the shape for scalars
// rather than writing an empty vector.
//
// Returns InvalidArgument describing the first mismatch.
Status BytecodeModule::ValidateArgType(const BufferView& arg,
                                       const MemRefTypeDef& expected_type) {
  const auto* expected_element_type = expected_type.element_type();
  if (!expected_element_type) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Expected argument type is missing its element type";
  }
  // element_size is in bytes; the type definition declares widths in bits.
  RETURN_IF_ERROR(
      ValidateElementSize(arg.element_size * 8, *expected_element_type));

  const auto* expected_shape = expected_type.shape();
  const int expected_rank =
      expected_shape ? static_cast<int>(expected_shape->size()) : 0;
  const int actual_rank = static_cast<int>(arg.shape.size());
  if (actual_rank != expected_rank) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Argument should have rank " << expected_rank
           << " but has rank " << actual_rank;
  }
  for (int i = 0; i < expected_rank; ++i) {
    auto dim_size = arg.shape[i];
    auto expected_dim_size = expected_shape->Get(i);
    if (dim_size != expected_dim_size) {
      return InvalidArgumentErrorBuilder(IREE_LOC)
             << "Argument dimension " << i << " should have size "
             << expected_dim_size << " but has size " << dim_size;
    }
  }
  return OkStatus();
}
} // namespace vm
} // namespace iree