blob: c6474d7adeeb7c44c7302f37c17b87d757abb731 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/vm_util.h"
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <ostream>
#include <type_traits>
#include <vector>
#include "iree/base/api.h"
#include "iree/base/status_cc.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/tooling/numpy_io.h"
#include "iree/vm/ref_cc.h"
// TODO(benvanik): drop use of stdio and make an iree_io_stream_t.
#if defined(IREE_PLATFORM_WINDOWS)
static uint64_t GetFileLength(FILE* file) {
_fseeki64(file, 0, SEEK_END);
uint64_t file_length = _ftelli64(file);
_fseeki64(file, 0, SEEK_SET);
return file_length;
}
static bool IsEOF(FILE* file, uint64_t file_length) {
return _ftelli64(file) == file_length;
}
#else
static uint64_t GetFileLength(FILE* file) {
fseeko(file, 0, SEEK_END);
uint64_t file_length = ftello(file);
fseeko(file, 0, SEEK_SET);
return file_length;
}
static bool IsEOF(FILE* file, uint64_t file_length) {
return ftello(file) == file_length;
}
#endif // IREE_PLATFORM_*
namespace iree {
// Appends to |variant_list| every ndarray read back-to-back from the .npy file
// at |file_path|, each loaded as a buffer view allocated from
// |device_allocator|.
//
// Reading stops at the first error or once the file position reaches the end
// of the file. Arrays pushed before an error remain in |variant_list|.
static iree_status_t LoadNdarraysFromFile(
    iree_string_view_t file_path, iree_hal_allocator_t* device_allocator,
    iree_vm_list_t* variant_list) {
  // fopen needs a NUL-terminated path; string views are not guaranteed to be.
  const std::string path(file_path.data, file_path.size);
  FILE* stream = std::fopen(path.c_str(), "rb");
  if (stream == nullptr) {
    return iree_make_status(iree_status_code_from_errno(errno),
                            "failed to open file '%.*s'", (int)file_path.size,
                            file_path.data);
  }
  const uint64_t stream_length = GetFileLength(stream);

  iree_hal_buffer_params_t params = {};
  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
  params.access = IREE_HAL_MEMORY_ACCESS_READ;
  params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;

  iree_status_t status = iree_ok_status();
  while (iree_status_is_ok(status)) {
    if (IsEOF(stream, stream_length)) break;
    iree_hal_buffer_view_t* ndarray = NULL;
    status = iree_numpy_npy_load_ndarray(
        stream, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, params, device_allocator,
        &ndarray);
    if (iree_status_is_ok(status)) {
      // The list receives its own reference; the local one is dropped below.
      auto ndarray_ref = iree_hal_buffer_view_retain_ref(ndarray);
      status = iree_vm_list_push_ref_move(variant_list, &ndarray_ref);
    }
    iree_hal_buffer_view_release(ndarray);
  }

  std::fclose(stream);
  return status;
}
// Creates a HAL buffer view with the given |metadata| and reads the contents
// from the file at |file_path|.
//
// The file contents are directly read in to memory with no processing.
static iree_status_t CreateBufferViewFromFile(
    iree_string_view_t metadata, iree_string_view_t file_path,
    iree_hal_allocator_t* device_allocator,
    iree_hal_buffer_view_t** out_buffer_view) {
  // Creates a HAL buffer view whose shape and element type are parsed from
  // |metadata| and whose contents are read verbatim (no processing) from the
  // file at |file_path|.
  //
  // Returns OUT_OF_RANGE if the file holds fewer bytes than the buffer view
  // requires and RESOURCE_EXHAUSTED if the parsed rank exceeds 128. On success
  // |out_buffer_view| receives a new buffer view the caller must release.
  *out_buffer_view = NULL;

  // Parse shape and element type used to allocate the buffer view. The first
  // query passes zero shape capacity only to learn the rank; an OUT_OF_RANGE
  // result means "capacity too small" and is expected here.
  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
  iree_host_size_t shape_rank = 0;
  iree_status_t shape_result = iree_hal_parse_shape_and_element_type(
      metadata, 0, &shape_rank, NULL, &element_type);
  if (!iree_status_is_ok(shape_result) &&
      !iree_status_is_out_of_range(shape_result)) {
    return shape_result;
  }
  // Consume the (possibly OUT_OF_RANGE) status before any further early
  // return so it is never leaked.
  iree_status_ignore(shape_result);
  if (shape_rank > 128) {
    return iree_make_status(
        IREE_STATUS_RESOURCE_EXHAUSTED,
        "a shape rank of %zu is just a little bit excessive, eh?", shape_rank);
  }
  // The rank bound above keeps this stack allocation small.
  iree_hal_dim_t* shape =
      (iree_hal_dim_t*)iree_alloca(shape_rank * sizeof(iree_hal_dim_t));
  IREE_RETURN_IF_ERROR(iree_hal_parse_shape_and_element_type(
      metadata, shape_rank, &shape_rank, shape, &element_type));

  // TODO(benvanik): allow specifying the encoding.
  iree_hal_encoding_type_t encoding_type =
      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;

  // Open the file for reading.
  std::string file_path_str(file_path.data, file_path.size);
  FILE* file = std::fopen(file_path_str.c_str(), "rb");
  if (!file) {
    return iree_make_status(iree_status_code_from_errno(errno),
                            "failed to open file '%.*s'", (int)file_path.size,
                            file_path.data);
  }

  iree_hal_buffer_params_t buffer_params = {};
  buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
  buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
  struct read_params_t {
    FILE* file;
  } read_params = {
      file,
  };
  // The callback fills the mapped buffer straight from the file; a short read
  // means the file is smaller than the shape/element type implies.
  iree_status_t status = iree_hal_buffer_view_generate_buffer(
      device_allocator, shape_rank, shape, element_type, encoding_type,
      buffer_params,
      +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
        auto* read_params = reinterpret_cast<read_params_t*>(user_data);
        size_t bytes_read =
            std::fread(mapping->contents.data, 1, mapping->contents.data_length,
                       read_params->file);
        if (bytes_read != mapping->contents.data_length) {
          return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                                  "file contents truncated; expected %zu bytes "
                                  "based on buffer view size",
                                  mapping->contents.data_length);
        }
        return iree_ok_status();
      },
      &read_params, out_buffer_view);
  std::fclose(file);
  return status;
}
Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
                          iree::span<const std::string> input_strings,
                          iree_allocator_t host_allocator,
                          iree_vm_list_t** out_list) {
  // Builds a new variant list with one or more entries per input string.
  // Each (whitespace-trimmed) input is interpreted as:
  //   `@path`            every ndarray in the .npy file at `path`
  //   `(null)`           a null ref
  //   contains `=`/`x`   a buffer view parsed from the text; `shape=@file` (or
  //                      `shape@file`) reads raw contents from `file`, and a
  //                      leading `&` keeps only the underlying buffer
  //   otherwise          a scalar: f32 if it contains `.`, else i32
  // On success ownership of the list is transferred to |out_list|.
  *out_list = NULL;
  vm::ref<iree_vm_list_t> list;
  IREE_RETURN_IF_ERROR(iree_vm_list_create(
      /*element_type=*/nullptr, input_strings.size(), host_allocator, &list));

  for (size_t idx = 0, count = input_strings.size(); idx < count; ++idx) {
    const std::string& raw_input = input_strings[idx];
    iree_string_view_t input = iree_string_view_trim(
        iree_make_string_view(raw_input.data(), raw_input.size()));

    // `@path`: append every ndarray stored in the file.
    if (iree_string_view_consume_prefix(&input, IREE_SV("@"))) {
      IREE_RETURN_IF_ERROR(
          LoadNdarraysFromFile(input, device_allocator, list.get()));
      continue;
    }

    // `(null)`: append a null ref.
    if (iree_string_view_equal(input, IREE_SV("(null)"))) {
      iree_vm_ref_t null_ref = iree_vm_ref_null();
      IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_retain(list.get(), &null_ref));
      continue;
    }

    const bool looks_like_buffer =
        iree_string_view_find_char(input, '=', 0) != IREE_STRING_VIEW_NPOS ||
        iree_string_view_find_char(input, 'x', 0) != IREE_STRING_VIEW_NPOS;

    if (!looks_like_buffer) {
      // Scalar: a '.' selects f32, anything else is parsed as i32.
      iree_vm_value_t value;
      if (iree_string_view_find_char(input, '.', 0) != IREE_STRING_VIEW_NPOS) {
        value = iree_vm_value_make_f32(0.0f);
        if (!iree_string_view_atof(input, &value.f32)) {
          return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                  "parsing value '%.*s' as f32",
                                  (int)input.size, input.data);
        }
      } else {
        value = iree_vm_value_make_i32(0);
        if (!iree_string_view_atoi_int32(input, &value.i32)) {
          return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                  "parsing value '%.*s' as i32",
                                  (int)input.size, input.data);
        }
      }
      IREE_RETURN_IF_ERROR(iree_vm_list_push_value(list.get(), &value));
      continue;
    }

    // Buffer view (shape only or shape=value) or, with a leading `&`, buffer.
    const bool wants_storage_only =
        iree_string_view_consume_prefix(&input, iree_make_cstring_view("&"));
    iree_hal_buffer_view_t* buffer_view = nullptr;
    if (iree_string_view_find_char(input, '@', 0) == IREE_STRING_VIEW_NPOS) {
      IREE_RETURN_IF_ERROR(
          iree_hal_buffer_view_parse(input, device_allocator, &buffer_view),
          "parsing value '%.*s'", (int)input.size, input.data);
    } else {
      // Split into the metadata used to create the buffer view and the path
      // of the file supplying its raw contents.
      iree_string_view_t metadata, file_path;
      iree_string_view_split(input, '@', &metadata, &file_path);
      iree_string_view_consume_suffix(&metadata, iree_make_cstring_view("="));
      IREE_RETURN_IF_ERROR(CreateBufferViewFromFile(
          metadata, file_path, device_allocator, &buffer_view));
    }

    if (!wants_storage_only) {
      auto buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
      IREE_RETURN_IF_ERROR(
          iree_vm_list_push_ref_move(list.get(), &buffer_view_ref));
      continue;
    }

    // Keep only the storage: the contents survive but the shape/type metadata
    // of the buffer view is discarded.
    auto buffer_ref =
        iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(buffer_view));
    iree_hal_buffer_view_release(buffer_view);
    IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(list.get(), &buffer_ref));
  }

  *out_list = list.release();
  return OkStatus();
}
Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count,
                        std::ostream* os) {
  // Prints each element of |variant_list| to |os| as a `result[i]: ...` line.
  //
  // Primitive values print as `<type>=<value>`. Refs print their type name;
  // buffer views additionally print their contents as formatted by
  // iree_hal_buffer_view_format (|max_element_count| is passed through to it).
  // Empty variants print `(null)`.
  for (iree_host_size_t i = 0; i < iree_vm_list_size(variant_list); ++i) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(variant_list, i, &variant),
                         "variant %zu not present", i);
    *os << "result[" << i << "]: ";
    if (iree_vm_variant_is_empty(variant)) {
      *os << "(null)\n";
    } else if (iree_vm_variant_is_value(variant)) {
      switch (variant.type.value_type) {
        case IREE_VM_VALUE_TYPE_I8:
          // int8_t is a character type to iostreams; widen it so the numeric
          // value is printed instead of a (possibly unprintable) character.
          *os << "i8=" << static_cast<int32_t>(variant.i8) << "\n";
          break;
        case IREE_VM_VALUE_TYPE_I16:
          *os << "i16=" << variant.i16 << "\n";
          break;
        case IREE_VM_VALUE_TYPE_I32:
          *os << "i32=" << variant.i32 << "\n";
          break;
        case IREE_VM_VALUE_TYPE_I64:
          *os << "i64=" << variant.i64 << "\n";
          break;
        case IREE_VM_VALUE_TYPE_F32:
          *os << "f32=" << variant.f32 << "\n";
          break;
        case IREE_VM_VALUE_TYPE_F64:
          *os << "f64=" << variant.f64 << "\n";
          break;
        default:
          *os << "?\n";
          break;
      }
    } else if (iree_vm_variant_is_ref(variant)) {
      iree_string_view_t type_name =
          iree_vm_ref_type_name(variant.type.ref_type);
      *os << std::string(type_name.data, type_name.size) << "\n";
      if (iree_hal_buffer_view_isa(variant.ref)) {
        auto* buffer_view = iree_hal_buffer_view_deref(variant.ref);
        std::string result_str(4096, '\0');
        iree_status_t status = iree_ok_status();
        while (true) {
          iree_host_size_t actual_length = 0;
          // Capacity includes the string's trailing NUL slot.
          status = iree_hal_buffer_view_format(buffer_view, max_element_count,
                                               result_str.size() + 1,
                                               &result_str[0], &actual_length);
          result_str.resize(actual_length);
          if (!iree_status_is_out_of_range(status)) break;
          // Too small: |actual_length| is the required size and the string was
          // grown to it. Consume the status before retrying so it isn't leaked.
          iree_status_ignore(status);
        }
        IREE_RETURN_IF_ERROR(status);
        *os << result_str << "\n";
      } else {
        // TODO(benvanik): a way for ref types to describe themselves.
        *os << "(no printer)\n";
      }
    } else {
      *os << "(null)\n";
    }
  }
  return OkStatus();
}
} // namespace iree