blob: 4f01524c38e812dd8885901daf03de3851436047 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tools/utils/trace_replay.h"
#include <ctype.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/file_path.h"
#include "iree/base/tracing.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/bytecode_module.h"
// Prepares |out_replay| for replaying trace events.
//
// Registers the HAL module types, records |root_path| (used to resolve
// relative bytecode module paths), |context_flags|, and |host_allocator|, and
// retains |instance|; that reference is dropped by
// iree_trace_replay_deinitialize. |out_replay| is zeroed first, so it is left
// zeroed if type registration fails.
iree_status_t iree_trace_replay_initialize(
    iree_string_view_t root_path, iree_vm_instance_t* instance,
    iree_vm_context_flags_t context_flags, iree_allocator_t host_allocator,
    iree_trace_replay_t* out_replay) {
  memset(out_replay, 0, sizeof(*out_replay));
  IREE_RETURN_IF_ERROR(iree_hal_module_register_types());

  iree_vm_instance_retain(instance);
  out_replay->instance = instance;
  out_replay->root_path = root_path;
  out_replay->context_flags = context_flags;
  out_replay->host_allocator = host_allocator;
  return iree_ok_status();
}
// Drops the device, context, and instance references held by |replay| (in that
// order) and leaves |replay| zeroed.
void iree_trace_replay_deinitialize(iree_trace_replay_t* replay) {
  iree_hal_device_t* device = replay->device;
  iree_vm_context_t* context = replay->context;
  iree_vm_instance_t* instance = replay->instance;
  memset(replay, 0, sizeof(*replay));
  iree_hal_device_release(device);
  iree_vm_context_release(context);
  iree_vm_instance_release(instance);
}
// Sets the HAL driver name used when a trace's `hal` builtin module does not
// name a driver itself. The string view is stored as-is (not copied), so its
// storage presumably must outlive |replay|.
void iree_trace_replay_set_hal_driver_override(iree_trace_replay_t* replay,
                                               iree_string_view_t driver) {
  iree_trace_replay_t* target = replay;
  target->driver = driver;
}
// Handles a `context_load` event: drops the current device and VM context (if
// any) and creates a fresh VM context on the replay's instance using the
// replay's context flags. |document| and |event_node| are currently unused.
iree_status_t iree_trace_replay_event_context_load(iree_trace_replay_t* replay,
                                                   yaml_document_t* document,
                                                   yaml_node_t* event_node) {
  // Detach the previous state before releasing it.
  iree_hal_device_t* previous_device = replay->device;
  iree_vm_context_t* previous_context = replay->context;
  replay->device = NULL;
  replay->context = NULL;
  iree_hal_device_release(previous_device);
  iree_vm_context_release(previous_context);

  // TODO(benvanik): allow setting flags from the trace files.
  return iree_vm_context_create(replay->instance, replay->context_flags,
                                replay->host_allocator, &replay->context);
}
// Creates the default device of a HAL driver into |out_device|.
//
// The driver name comes from |driver_node| when it is present and non-empty;
// otherwise the replay's driver override (set with
// iree_trace_replay_set_hal_driver_override) is used as a fallback. Returns
// INVALID_ARGUMENT when neither source provides a name.
static iree_status_t iree_trace_replay_create_device(
    iree_trace_replay_t* replay, yaml_node_t* driver_node,
    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
  // |driver_node| comes from an optional mapping key and may be NULL; treat a
  // missing key the same as an empty name.
  iree_string_view_t driver_name = {0};
  if (driver_node) {
    driver_name = iree_yaml_node_as_string(driver_node);
  }
  if (iree_string_view_is_empty(driver_name)) {
    driver_name = replay->driver;
  }
  if (iree_string_view_is_empty(driver_name)) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "no HAL driver specified by the trace or by the driver override");
  }

  // Try to create a device from the driver; the driver itself is only needed
  // for the duration of device creation.
  iree_hal_driver_t* driver = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
      iree_hal_driver_registry_default(), driver_name, host_allocator,
      &driver));
  iree_status_t status =
      iree_hal_driver_create_default_device(driver, host_allocator, out_device);
  iree_hal_driver_release(driver);
  return status;
}
// Loads the builtin module named by the `name` field of |module_node| and
// registers it with the current context.
//
// Only `hal` is currently supported. It creates a device from the optional
// `driver` field (falling back to the replay's driver override), stores that
// device in |replay->device|, and creates the HAL module on it. Any other name
// returns NOT_FOUND.
static iree_status_t iree_trace_replay_load_builtin_module(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* module_node) {
  iree_vm_module_t* module = NULL;
  yaml_node_t* name_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, module_node, iree_make_cstring_view("name"), &name_node));
  if (iree_yaml_string_equal(name_node, iree_make_cstring_view("hal"))) {
    yaml_node_t* driver_node = NULL;
    IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
        document, module_node, iree_make_cstring_view("driver"), &driver_node));
    // Create into a local first: if `hal` is loaded more than once in a
    // context, the previously held device reference is released rather than
    // leaked by the overwrite, and a failed creation leaves it in place.
    iree_hal_device_t* device = NULL;
    IREE_RETURN_IF_ERROR(iree_trace_replay_create_device(
        replay, driver_node, replay->host_allocator, &device));
    iree_hal_device_release(replay->device);
    replay->device = device;
    IREE_RETURN_IF_ERROR(iree_hal_module_create(
        replay->device, replay->host_allocator, &module));
  }
  if (!module) {
    return iree_make_status(
        IREE_STATUS_NOT_FOUND, "builtin module '%.*s' not registered",
        (int)name_node->data.scalar.length, name_node->data.scalar.value);
  }
  // The context holds its own reference once registration succeeds.
  iree_status_t status =
      iree_vm_context_register_modules(replay->context, &module, 1);
  iree_vm_module_release(module);
  return status;
}
// Loads a bytecode module from the `path` field of |module_node| and registers
// it with the current context.
//
// The special path `<stdin>` reads the module from standard input. Any other
// path is joined onto |replay->root_path| and read from disk.
static iree_status_t iree_trace_replay_load_bytecode_module(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* module_node) {
  yaml_node_t* path_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, module_node, iree_make_cstring_view("path"), &path_node));

  // Read the module bytes into a host allocation.
  iree_byte_span_t contents;
  iree_status_t status;
  const bool from_stdin =
      iree_yaml_string_equal(path_node, iree_make_cstring_view("<stdin>"));
  if (from_stdin) {
    status = iree_stdin_read_contents(replay->host_allocator, &contents);
  } else {
    char* full_path = NULL;
    IREE_RETURN_IF_ERROR(iree_file_path_join(
        replay->root_path, iree_yaml_node_as_string(path_node),
        replay->host_allocator, &full_path));
    status = iree_file_read_contents(full_path, replay->host_allocator,
                                     &contents);
    iree_allocator_free(replay->host_allocator, full_path);
  }
  if (!iree_status_is_ok(status)) return status;

  // Load and verify the module. The bytes are passed along with the host
  // allocator; they are freed here only when creation fails (the module is
  // expected to own them on success).
  iree_vm_module_t* module = NULL;
  status = iree_vm_bytecode_module_create(
      iree_make_const_byte_span(contents.data, contents.data_length),
      replay->host_allocator, replay->host_allocator, &module);
  if (!iree_status_is_ok(status)) {
    iree_allocator_free(replay->host_allocator, contents.data);
    iree_vm_module_release(module);
    return status;
  }

  status = iree_vm_context_register_modules(replay->context, &module, 1);
  iree_vm_module_release(module);
  return status;
}
// Handles a `module_load` event. Reads the `module` mapping and dispatches on
// its `type` field: `builtin` or `bytecode`. Any other type returns
// UNIMPLEMENTED.
iree_status_t iree_trace_replay_event_module_load(iree_trace_replay_t* replay,
                                                  yaml_document_t* document,
                                                  yaml_node_t* event_node) {
  yaml_node_t* module_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, event_node, iree_make_cstring_view("module"), &module_node));

  yaml_node_t* type_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, module_node, iree_make_cstring_view("type"), &type_node));
  iree_string_view_t module_type = iree_yaml_node_as_string(type_node);

  if (iree_string_view_equal(module_type, iree_make_cstring_view("builtin"))) {
    return iree_trace_replay_load_builtin_module(replay, document, module_node);
  }
  if (iree_string_view_equal(module_type,
                             iree_make_cstring_view("bytecode"))) {
    return iree_trace_replay_load_bytecode_module(replay, document,
                                                  module_node);
  }
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "module type '%.*s' not recognized",
                          (int)module_type.size, module_type.data);
}
// Forward declarations: item parsing is mutually recursive. A vm.list item
// holds a sequence of items, and each item in it may itself be a vm.list.
static iree_status_t iree_trace_replay_parse_item(iree_trace_replay_t* replay,
                                                  yaml_document_t* document,
                                                  yaml_node_t* value_node,
                                                  iree_vm_list_t* target_list);
static iree_status_t iree_trace_replay_parse_item_sequence(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* sequence_node, iree_vm_list_t* target_list);
// Parses a scalar value and appends it to |target_list|.
//
// ```yaml
// i8: 7
// ```
//
// The first of the keys i8, i16, i32, i64, f32, f64 present in |value_node|
// (checked in that order) selects the value type.
//
// i8 and i16 accept any value representable in either the signed or the
// unsigned integer of that width, so both -1 and 255 are valid i8 spellings of
// the same bit pattern. Values outside that range are rejected rather than
// silently truncated. If none of the keys is present, UNIMPLEMENTED is
// returned.
static iree_status_t iree_trace_replay_parse_scalar(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* value_node, iree_vm_list_t* target_list) {
  yaml_node_t* data_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("i8"), &data_node));
  if (data_node) {
    int32_t value = 0;
    if (!iree_string_view_atoi_int32(iree_yaml_node_as_string(data_node),
                                     &value) ||
        value < INT8_MIN || value > UINT8_MAX) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "failed to parse i8 value: '%.*s'",
          (int)data_node->data.scalar.length, data_node->data.scalar.value);
    }
    iree_vm_variant_t variant = iree_vm_variant_empty();
    variant.type.value_type = IREE_VM_VALUE_TYPE_I8;
    variant.i8 = (int8_t)value;
    return iree_vm_list_push_variant(target_list, &variant);
  }
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("i16"), &data_node));
  if (data_node) {
    int32_t value = 0;
    if (!iree_string_view_atoi_int32(iree_yaml_node_as_string(data_node),
                                     &value) ||
        value < INT16_MIN || value > UINT16_MAX) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "failed to parse i16 value: '%.*s'",
          (int)data_node->data.scalar.length, data_node->data.scalar.value);
    }
    iree_vm_variant_t variant = iree_vm_variant_empty();
    variant.type.value_type = IREE_VM_VALUE_TYPE_I16;
    variant.i16 = (int16_t)value;
    return iree_vm_list_push_variant(target_list, &variant);
  }
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("i32"), &data_node));
  if (data_node) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    variant.type.value_type = IREE_VM_VALUE_TYPE_I32;
    if (!iree_string_view_atoi_int32(iree_yaml_node_as_string(data_node),
                                     &variant.i32)) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "failed to parse i32 value: '%.*s'",
          (int)data_node->data.scalar.length, data_node->data.scalar.value);
    }
    return iree_vm_list_push_variant(target_list, &variant);
  }
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("i64"), &data_node));
  if (data_node) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    variant.type.value_type = IREE_VM_VALUE_TYPE_I64;
    if (!iree_string_view_atoi_int64(iree_yaml_node_as_string(data_node),
                                     &variant.i64)) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "failed to parse i64 value: '%.*s'",
          (int)data_node->data.scalar.length, data_node->data.scalar.value);
    }
    return iree_vm_list_push_variant(target_list, &variant);
  }
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("f32"), &data_node));
  if (data_node) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    variant.type.value_type = IREE_VM_VALUE_TYPE_F32;
    if (!iree_string_view_atof(iree_yaml_node_as_string(data_node),
                               &variant.f32)) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "failed to parse f32 value: '%.*s'",
          (int)data_node->data.scalar.length, data_node->data.scalar.value);
    }
    return iree_vm_list_push_variant(target_list, &variant);
  }
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("f64"), &data_node));
  if (data_node) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    variant.type.value_type = IREE_VM_VALUE_TYPE_F64;
    if (!iree_string_view_atod(iree_yaml_node_as_string(data_node),
                               &variant.f64)) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "failed to parse f64 value: '%.*s'",
          (int)data_node->data.scalar.length, data_node->data.scalar.value);
    }
    return iree_vm_list_push_variant(target_list, &variant);
  }
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "(%zu): unimplemented scalar type parser",
                          value_node->start_mark.line);
}
// Parses a !vm.list and appends it to |target_list|.
//
// ```yaml
// items:
// - type: value
//   i8: 7
// - type: vm.list
//   items: ...
// ```
//
// |value_node| must be a mapping. Its optional `items` key must be a sequence
// whose entries are parsed with iree_trace_replay_parse_item; a missing
// `items` key produces an empty list. On failure nothing is appended to
// |target_list|.
static iree_status_t iree_trace_replay_parse_vm_list(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* value_node, iree_vm_list_t* target_list) {
  if (value_node->type != YAML_MAPPING_NODE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): expected mapping node for vm.list",
                            value_node->start_mark.line);
  }
  yaml_node_t* items_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("items"), &items_node));
  // The items are walked through the sequence union member, so reject any
  // other node kind before touching it.
  if (items_node && items_node->type != YAML_SEQUENCE_NODE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): expected sequence node for vm.list items",
                            items_node->start_mark.line);
  }
  iree_vm_list_t* list = NULL;
  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/NULL,
                                           /*initial_capacity=*/8,
                                           replay->host_allocator, &list));
  iree_status_t status = iree_ok_status();
  if (items_node) {
    status = iree_trace_replay_parse_item_sequence(replay, document, items_node,
                                                   list);
  }
  if (iree_status_is_ok(status)) {
    iree_vm_ref_t list_ref = iree_vm_list_move_ref(list);
    status = iree_vm_list_push_ref_move(target_list, &list_ref);
  }
  if (!iree_status_is_ok(status)) {
    iree_vm_list_release(list);
  }
  return status;
}
// Parses a shape sequence.
//
// ```yaml
// shape:
// - 1
// - 2
// - 3
// ```
// or
// ```yaml
// shape: 1x2x3
// ```
//
// Writes up to |shape_capacity| dimensions into |shape| and the resulting rank
// into |out_shape_rank|; a NULL |shape_node| yields rank 0.
//
// Sequence entries must be non-negative integers that are exactly
// representable as iree_hal_dim_t. Negative or overflowing dimensions are
// rejected instead of being silently converted. The scalar short-hand is
// validated by iree_hal_parse_shape.
static iree_status_t iree_trace_replay_parse_hal_shape(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* shape_node, iree_host_size_t shape_capacity,
    iree_hal_dim_t* shape, iree_host_size_t* out_shape_rank) {
  iree_host_size_t shape_rank = 0;
  *out_shape_rank = shape_rank;
  if (!shape_node) return iree_ok_status();
  if (shape_node->type == YAML_SCALAR_NODE) {
    // Short-hand using the canonical shape parser (4x8).
    return iree_hal_parse_shape(iree_yaml_node_as_string(shape_node),
                                shape_capacity, shape, out_shape_rank);
  } else if (shape_node->type != YAML_SEQUENCE_NODE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): expected scalar or sequence node for shape",
                            shape_node->start_mark.line);
  }
  // Shape dimension list:
  for (yaml_node_item_t* item = shape_node->data.sequence.items.start;
       item != shape_node->data.sequence.items.top; ++item) {
    yaml_node_t* dim_node = yaml_document_get_node(document, *item);
    if (dim_node->type != YAML_SCALAR_NODE) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "(%zu): expected integer shape dimension",
                              dim_node->start_mark.line);
    }
    int64_t dim = 0;
    if (!iree_string_view_atoi_int64(iree_yaml_node_as_string(dim_node),
                                     &dim)) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT, "(%zu): invalid shape dimension '%.*s'",
          dim_node->start_mark.line, (int)dim_node->data.scalar.length,
          dim_node->data.scalar.value);
    }
    // iree_hal_dim_t may be narrower than int64_t and/or unsigned; require
    // the value to survive the conversion unchanged.
    if (dim < 0 || (int64_t)(iree_hal_dim_t)dim != dim) {
      return iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "(%zu): shape dimension '%.*s' out of range",
          dim_node->start_mark.line, (int)dim_node->data.scalar.length,
          dim_node->data.scalar.value);
    }
    if (shape_rank >= shape_capacity) {
      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                              "(%zu): shape rank overflow (>%zu)",
                              shape_node->start_mark.line, shape_capacity);
    }
    shape[shape_rank++] = (iree_hal_dim_t)dim;
  }
  *out_shape_rank = shape_rank;
  return iree_ok_status();
}
// Parses an element type.
//
// ```yaml
// element_type: 50331680
// ```
// or
// ```yaml
// element_type: f32
// ```
//
// A value starting with a digit is parsed as the raw uint32 element type
// value. Anything else goes through iree_hal_parse_element_type.
// |out_element_type| stays IREE_HAL_ELEMENT_TYPE_NONE on failure of the
// numeric path.
static iree_status_t iree_trace_replay_parse_hal_element_type(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* element_type_node, iree_hal_element_type_t* out_element_type) {
  *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;
  iree_string_view_t element_type_str =
      iree_yaml_node_as_string(element_type_node);
  if (iree_string_view_is_empty(element_type_str)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): element type missing",
                            element_type_node->start_mark.line);
  }
  // If the first character is a digit then interpret as a %d type. The
  // unsigned char cast keeps isdigit defined for bytes >= 0x80 on platforms
  // where char is signed.
  if (isdigit((unsigned char)element_type_str.data[0])) {
    static_assert(sizeof(*out_element_type) == sizeof(uint32_t), "4 bytes");
    // Parse into a real uint32_t rather than reinterpreting the output
    // pointer, then convert.
    uint32_t element_type_value = 0;
    if (!iree_string_view_atoi_uint32(element_type_str, &element_type_value)) {
      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                              "(%zu): invalid element type",
                              element_type_node->start_mark.line);
    }
    *out_element_type = (iree_hal_element_type_t)element_type_value;
    return iree_ok_status();
  }
  // Parse as a canonical element type.
  return iree_hal_parse_element_type(element_type_str, out_element_type);
}
// Parses an encoding type.
//
// ```yaml
// encoding_type: 50331680
// ```
//
// A NULL |encoding_type_node| (key absent) selects
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR. Only numeric values (leading digit)
// are currently accepted; symbolic names return UNIMPLEMENTED.
static iree_status_t iree_trace_replay_parse_hal_encoding_type(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* encoding_type_node,
    iree_hal_encoding_type_t* out_encoding_type) {
  *out_encoding_type = IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
  if (!encoding_type_node) return iree_ok_status();
  iree_string_view_t encoding_type_str =
      iree_yaml_node_as_string(encoding_type_node);
  if (iree_string_view_is_empty(encoding_type_str)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): encoding type missing",
                            encoding_type_node->start_mark.line);
  }
  // If the first character is a digit then interpret as a %d type. The
  // unsigned char cast keeps isdigit defined for bytes >= 0x80 on platforms
  // where char is signed.
  if (isdigit((unsigned char)encoding_type_str.data[0])) {
    static_assert(sizeof(*out_encoding_type) == sizeof(uint32_t), "4 bytes");
    uint32_t encoding_type_value = 0;
    if (!iree_string_view_atoi_uint32(encoding_type_str,
                                      &encoding_type_value)) {
      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                              "(%zu): invalid encoding type",
                              encoding_type_node->start_mark.line);
    }
    *out_encoding_type = (iree_hal_encoding_type_t)encoding_type_value;
    return iree_ok_status();
  }
  // Parse as a canonical encoding type.
  // TODO(#6762): implement iree_hal_parse_encoding_type.
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "iree_hal_parse_encoding_type not implemented");
}
// Parses serialized contents into the whole of |buffer|.
//
// ```yaml
// contents: !!binary |
//   AACAPwAAAEAAAEBAAACAQA==
// ```
//
// The node's YAML tag selects the decoder. `!!binary` is base64-decoded;
// a plain string is parsed as element values of |element_type|. Other tags
// return UNIMPLEMENTED. A NULL |contents_node| leaves the buffer untouched.
// NOTE(review): nothing here zero-fills in that case; confirm the allocator
// provides zeroed memory if callers rely on it.
static iree_status_t iree_trace_replay_parse_hal_buffer(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* contents_node, iree_hal_element_type_t element_type,
    iree_hal_buffer_t* buffer) {
  if (contents_node == NULL) return iree_ok_status();
  if (contents_node->type != YAML_SCALAR_NODE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): expected scalar node for buffer contents",
                            contents_node->start_mark.line);
  }

  const char* tag = (const char*)contents_node->tag;
  iree_string_view_t text =
      iree_string_view_trim(iree_yaml_node_as_string(contents_node));

  iree_hal_buffer_mapping_t mapping;
  IREE_RETURN_IF_ERROR(
      iree_hal_buffer_map_range(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
                                IREE_WHOLE_BUFFER, &mapping));

  iree_status_t status;
  if (strcmp(tag, "tag:yaml.org,2002:binary") == 0) {
    status = iree_yaml_base64_decode(text, mapping.contents);
  } else if (strcmp(tag, "tag:yaml.org,2002:str") == 0) {
    status =
        iree_hal_parse_buffer_elements(text, element_type, mapping.contents);
  } else {
    status = iree_make_status(
        IREE_STATUS_UNIMPLEMENTED, "(%zu): unimplemented buffer encoding '%s'",
        contents_node->start_mark.line, tag);
  }

  iree_hal_buffer_unmap_range(&mapping);
  return status;
}
// Stores the integer |value| into |dst| after converting it to the C type that
// matches |element_type|. Supported types are signed/unsigned 8-64 bit
// integers and 32/64-bit floats. Any other type triggers IREE_ASSERT (which
// may be compiled out) and leaves |dst| unmodified.
static void iree_trace_replay_write_element(
    iree_hal_element_type_t element_type, int value, void* dst) {
  switch (element_type) {
    case IREE_HAL_ELEMENT_TYPE_SINT_8:
      *(int8_t*)dst = (int8_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_SINT_16:
      *(int16_t*)dst = (int16_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_SINT_32:
      *(int32_t*)dst = (int32_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_SINT_64:
      *(int64_t*)dst = (int64_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_8:
      *(uint8_t*)dst = (uint8_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_16:
      *(uint16_t*)dst = (uint16_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_32:
      *(uint32_t*)dst = (uint32_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_64:
      *(uint64_t*)dst = (uint64_t)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
      *(float*)dst = (float)value;
      break;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
      *(double*)dst = (double)value;
      break;
    default:
      IREE_ASSERT(false, "unhandled element type");
      break;
  }
}
// Writes an identity matrix, with matrix elements of the given |element_type|,
// to the destination |span|. The matrix shape is inferred from |inner_size|
// (the length of the innermost dimension) and the span's length.
//
// Here by 'identity matrix' we mean any two-dimensional array of the form
//
//   array[i, j] = ((i == j) ? 1 : 0)
//
// Technically they are only called 'identity matrix' for square shapes.
//
// These identity matrices are useful in matrix multiplication tests to
// generate testcases that are easy to debug numerically, as the identity
// matrix is the neutral element for matrix multiplication.
//
// Only whole elements are written. Trailing bytes too short to hold an
// element are left untouched. An element type whose byte count is 0 leaves the
// whole span untouched; stepping by 0 bytes would never terminate.
static void iree_trace_replay_generate_identity_matrix(
    iree_hal_element_type_t element_type, iree_byte_span_t span,
    iree_hal_dim_t inner_size) {
  const iree_host_size_t element_byte_count =
      iree_hal_element_byte_count(element_type);
  if (element_byte_count == 0) return;
  const iree_host_size_t element_count = span.data_length / element_byte_count;
  iree_host_size_t inner_index = 0;
  iree_host_size_t outer_index = 0;
  for (iree_host_size_t i = 0; i < element_count; ++i) {
    int value = inner_index == outer_index ? 1 : 0;
    iree_trace_replay_write_element(element_type, value,
                                    span.data + i * element_byte_count);
    ++inner_index;
    if (inner_index == (iree_host_size_t)inner_size) {
      inner_index = 0;
      ++outer_index;
    }
  }
}
// Simple deterministic pseudorandom generator.
// Typically in tests we want reproducible results both across runs and across
// machines.
//
// Advances |state| with the MINSTD recurrence used by C++'s std::minstd_rand,
//   state = (state * 48271) mod (2^31 - 1),
// and returns bits 8..15 of the new state.
//
// The product is computed in 64 bits. 48271 * (2^31 - 2) does not fit in 32
// bits, and a 32-bit product would silently wrap and diverge from
// std::minstd_rand after the second step.
static uint8_t iree_trace_replay_pseudorandom_uint8(uint32_t* state) {
  *state = (uint32_t)(((uint64_t)*state * 48271u) % 2147483647u);
  // Return the second-least-significant of the 4 bytes of state; it avoids
  // some mild issues with the least-significant and most-significant bytes.
  return (uint8_t)(*state >> 8);
}
// Fills the destination span with pseudorandom values of the given
// |element_type|. The given |seed| is passed to the pseudorandom generator.
// The pseudorandom values are reproducible both across runs and across
// machines.
//
// Each element receives one generated byte mapped to [0, 255] for unsigned
// integer types and to [-128, 127] otherwise.
//
// As with std::minstd_rand seeding, |seed| is reduced modulo 2^31 - 1 and a
// zero result is replaced by 1. A zero state would stay zero forever and
// produce a constant buffer.
//
// Only whole elements are written; trailing bytes too short for an element
// are left untouched. An element byte count of 0 leaves the span untouched
// rather than looping forever.
static void iree_trace_replay_generate_fully_specified_pseudorandom_buffer(
    iree_hal_element_type_t element_type, iree_byte_span_t span,
    uint32_t seed) {
  const bool is_unsigned = iree_hal_element_numerical_type(element_type) ==
                           IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED;
  const iree_host_size_t element_byte_count =
      iree_hal_element_byte_count(element_type);
  if (element_byte_count == 0) return;
  const iree_host_size_t element_count = span.data_length / element_byte_count;
  uint32_t state = seed % 2147483647u;
  if (state == 0) state = 1;
  for (iree_host_size_t i = 0; i < element_count; ++i) {
    int value_in_uint8_range = iree_trace_replay_pseudorandom_uint8(&state);
    int value = value_in_uint8_range + (is_unsigned ? 0 : -128);
    iree_trace_replay_write_element(element_type, value,
                                    span.data + i * element_byte_count);
  }
}
// Generates the destination |buffer| using the generator specified by
// |contents_generator_node|.
//
// The node's tag selects the generator:
//   !tag:iree:identity_matrix - requires a rank-2 |shape|; writes 1 where the
//     row and column indices match and 0 elsewhere.
//   !tag:iree:fully_specified_pseudorandom - the node's scalar value is the
//     uint32 seed for a reproducible pseudorandom fill.
// Other tags return UNIMPLEMENTED. A NULL |contents_generator_node| leaves
// |buffer| untouched.
static iree_status_t iree_trace_replay_generate_hal_buffer(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* contents_generator_node, iree_hal_element_type_t element_type,
    iree_hal_buffer_t* buffer, iree_hal_dim_t* shape,
    iree_host_size_t shape_size) {
  if (!contents_generator_node) {
    return iree_ok_status();
  } else if (contents_generator_node->type != YAML_SCALAR_NODE) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "(%zu): expected scalar node for buffer contents_generator",
        contents_generator_node->start_mark.line);
  }
  const char* tag = (const char*)contents_generator_node->tag;
  const size_t line = contents_generator_node->start_mark.line;
  iree_hal_buffer_mapping_t mapping;
  IREE_RETURN_IF_ERROR(
      iree_hal_buffer_map_range(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
                                IREE_WHOLE_BUFFER, &mapping));
  iree_status_t status = iree_ok_status();
  if (strcmp(tag, "!tag:iree:identity_matrix") == 0) {
    if (shape_size == 2) {
      iree_hal_dim_t inner_size = shape[shape_size - 1];
      iree_trace_replay_generate_identity_matrix(element_type, mapping.contents,
                                                 inner_size);
    } else {
      status = iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "(%zu): the identity_matrix generator is only for 2D shapes "
          "(matrices)",
          line);
    }
  } else if (strcmp(tag, "!tag:iree:fully_specified_pseudorandom") == 0) {
    // To enable pseudorandom tests that are both reproducible and invariant
    // under reordering and filtering testcases, the seed is explicitly
    // passed as argument in the contents_generator tag.
    iree_string_view_t seed_str = iree_string_view_trim(
        iree_yaml_node_as_string(contents_generator_node));
    uint32_t seed = 0;
    if (iree_string_view_atoi_uint32(seed_str, &seed)) {
      iree_trace_replay_generate_fully_specified_pseudorandom_buffer(
          element_type, mapping.contents, seed);
    } else {
      // A trimmed string view is not NUL-terminated at its logical end, so
      // print it with an explicit length.
      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "(%zu): could not parse the seed argument "
                                "('%.*s') of the fully_specified_pseudorandom "
                                "tag",
                                line, (int)seed_str.size, seed_str.data);
    }
  } else {
    status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                              "(%zu): unimplemented buffer generator '%s'",
                              line, tag);
  }
  iree_hal_buffer_unmap_range(&mapping);
  return status;
}
// Parses a !hal.buffer_view and appends it to |target_list|.
//
// ```yaml
// shape:
// - 4
// element_type: 50331680
// contents: !!binary |
//   AACAPwAAAEAAAEBAAACAQA==
// ```
//
// Keys:
//   element_type (required).
//   shape (optional; rank 0 when absent).
//   encoding_type (optional; dense row-major when absent).
//   At most one of contents / contents_generator.
//
// The buffer is allocated from the current device's allocator, sized for the
// parsed shape, element type, and encoding. NOTE(review): when neither
// contents nor contents_generator is given, the buffer is not explicitly
// filled here; confirm the allocator zero-initializes if that is relied on.
static iree_status_t iree_trace_replay_parse_hal_buffer_view(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* value_node, iree_vm_list_t* target_list) {
  yaml_node_t* shape_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("shape"), &shape_node));
  iree_hal_dim_t shape[16];
  iree_host_size_t shape_rank = 0;
  IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_shape(
      replay, document, shape_node, IREE_ARRAYSIZE(shape), shape, &shape_rank));

  yaml_node_t* element_type_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, value_node, iree_make_cstring_view("element_type"),
      &element_type_node));
  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
  IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_element_type(
      replay, document, element_type_node, &element_type));

  yaml_node_t* encoding_type_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("encoding_type"),
      &encoding_type_node));
  iree_hal_encoding_type_t encoding_type =
      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
  IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_encoding_type(
      replay, document, encoding_type_node, &encoding_type));

  yaml_node_t* contents_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("contents"),
      &contents_node));
  yaml_node_t* contents_generator_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, value_node, iree_make_cstring_view("contents_generator"),
      &contents_generator_node));
  if (contents_node && contents_generator_node) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "(%zu): cannot have both contents and contents_generator",
        contents_generator_node->start_mark.line);
  }

  // Size the allocation with the parsed encoding so it agrees with the
  // encoding the buffer view is created with below.
  iree_device_size_t allocation_size = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_size(
      shape, shape_rank, element_type, encoding_type, &allocation_size));
  iree_hal_buffer_t* buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
      iree_hal_device_allocator(replay->device),
      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
      IREE_HAL_BUFFER_USAGE_ALL, allocation_size, &buffer));

  // At most one of these writes the buffer; each is a no-op on a NULL node.
  iree_status_t status = iree_trace_replay_generate_hal_buffer(
      replay, document, contents_generator_node, element_type, buffer, shape,
      shape_rank);
  if (iree_status_is_ok(status)) {
    status = iree_trace_replay_parse_hal_buffer(replay, document, contents_node,
                                                element_type, buffer);
  }
  iree_hal_buffer_view_t* buffer_view = NULL;
  if (iree_status_is_ok(status)) {
    status = iree_hal_buffer_view_create(buffer, shape, shape_rank,
                                         element_type, encoding_type,
                                         &buffer_view);
  }
  // Drop our buffer reference; a successfully created view keeps its own.
  iree_hal_buffer_release(buffer);
  IREE_RETURN_IF_ERROR(status);

  iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
  status = iree_vm_list_push_ref_move(target_list, &buffer_view_ref);
  iree_vm_ref_release(&buffer_view_ref);
  return status;
}
// Parses a buffer view written inline in tensor form and appends it to
// |target_list|. The buffer is allocated from the replay device's allocator.
//
// ```yaml
// !hal.buffer_view 4xf32=[0 1 2 3]
// ```
static iree_status_t iree_trace_replay_parse_inline_hal_buffer_view(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* value_node, iree_vm_list_t* target_list) {
  iree_string_view_t text = iree_yaml_node_as_string(value_node);
  iree_hal_buffer_view_t* parsed_view = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_parse(
      text, iree_hal_device_allocator(replay->device), &parsed_view));

  iree_vm_ref_t view_ref = iree_hal_buffer_view_move_ref(parsed_view);
  iree_status_t push_status =
      iree_vm_list_push_ref_move(target_list, &view_ref);
  // Drops the view if the push did not take the reference.
  iree_vm_ref_release(&view_ref);
  return push_status;
}
// Parses a typed item from |value_node| and appends it to |target_list|.
//
// ```yaml
// type: vm.list
// items:
// - type: value
//   i8: 7
// ```
// or
// ```yaml
// !hal.buffer_view 4xf32=[0 1 2 3]
// ```
//
// Nodes tagged `!hal.buffer_view` are parsed as inline buffer views. Any other
// node must have a `type` field. Recognized types are `null`, `value`,
// `vm.list`, and `hal.buffer_view`; anything else returns UNIMPLEMENTED.
static iree_status_t iree_trace_replay_parse_item(iree_trace_replay_t* replay,
                                                  yaml_document_t* document,
                                                  yaml_node_t* value_node,
                                                  iree_vm_list_t* target_list) {
  const bool is_inline_buffer_view =
      strcmp((const char*)value_node->tag, "!hal.buffer_view") == 0;
  if (is_inline_buffer_view) {
    return iree_trace_replay_parse_inline_hal_buffer_view(
        replay, document, value_node, target_list);
  }

  yaml_node_t* type_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, value_node, iree_make_cstring_view("type"), &type_node));
  iree_string_view_t type_name = iree_yaml_node_as_string(type_node);

  if (iree_string_view_equal(type_name, iree_make_cstring_view("null"))) {
    iree_vm_variant_t empty_variant = iree_vm_variant_empty();
    return iree_vm_list_push_variant(target_list, &empty_variant);
  }
  if (iree_string_view_equal(type_name, iree_make_cstring_view("value"))) {
    return iree_trace_replay_parse_scalar(replay, document, value_node,
                                          target_list);
  }
  if (iree_string_view_equal(type_name, iree_make_cstring_view("vm.list"))) {
    return iree_trace_replay_parse_vm_list(replay, document, value_node,
                                           target_list);
  }
  if (iree_string_view_equal(type_name,
                             iree_make_cstring_view("hal.buffer_view"))) {
    return iree_trace_replay_parse_hal_buffer_view(replay, document, value_node,
                                                   target_list);
  }
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "unimplemented type parser: '%.*s'",
                          (int)type_name.size, type_name.data);
}
// Parses a sequence of items, appending each to |target_list| in order.
//
// A NULL |sequence_node| is treated as an empty sequence; callers pass nodes
// from optional mapping keys (e.g. a call without `args`). Any node that is not
// a YAML sequence is rejected before its sequence storage is read. Parsing
// stops at the first failing item; items appended before it remain in
// |target_list|.
static iree_status_t iree_trace_replay_parse_item_sequence(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* sequence_node, iree_vm_list_t* target_list) {
  if (!sequence_node) return iree_ok_status();
  if (sequence_node->type != YAML_SEQUENCE_NODE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): expected sequence node",
                            sequence_node->start_mark.line);
  }
  for (yaml_node_item_t* item = sequence_node->data.sequence.items.start;
       item != sequence_node->data.sequence.items.top; ++item) {
    yaml_node_t* item_node = yaml_document_get_node(document, *item);
    IREE_RETURN_IF_ERROR(
        iree_trace_replay_parse_item(replay, document, item_node, target_list));
  }
  return iree_ok_status();
}
// Forward declaration: list printing recurses back into item printing.
static iree_status_t iree_trace_replay_print_item(iree_vm_variant_t* value);
// Prints the primitive held by |value| to stdout as `<type>=<value>`, with no
// trailing newline. Value types without a printer are shown as `?`. Always
// returns OK.
static iree_status_t iree_trace_replay_print_scalar(iree_vm_variant_t* value) {
  FILE* out = stdout;
  switch (value->type.value_type) {
    case IREE_VM_VALUE_TYPE_I8:
      fprintf(out, "i8=%" PRIi8, value->i8);
      break;
    case IREE_VM_VALUE_TYPE_I16:
      fprintf(out, "i16=%" PRIi16, value->i16);
      break;
    case IREE_VM_VALUE_TYPE_I32:
      fprintf(out, "i32=%" PRIi32, value->i32);
      break;
    case IREE_VM_VALUE_TYPE_I64:
      fprintf(out, "i64=%" PRIi64, value->i64);
      break;
    case IREE_VM_VALUE_TYPE_F32:
      fprintf(out, "f32=%G", value->f32);
      break;
    case IREE_VM_VALUE_TYPE_F64:
      fprintf(out, "f64=%G", value->f64);
      break;
    default:
      fputs("?", out);
      break;
  }
  return iree_ok_status();
}
// Prints every element of |list| to stdout, one per line. Stops at the first
// element that cannot be read or printed.
static iree_status_t iree_trace_replay_print_vm_list(iree_vm_list_t* list) {
  iree_host_size_t i = 0;
  while (i < iree_vm_list_size(list)) {
    iree_vm_variant_t variant = iree_vm_variant_empty();
    IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(list, i, &variant),
                         "variant %zu not present", i);
    IREE_RETURN_IF_ERROR(iree_trace_replay_print_item(&variant));
    fputc('\n', stdout);
    ++i;
  }
  return iree_ok_status();
}
// Prints |buffer_view| to stdout via iree_hal_buffer_view_fprint, passing a
// max_element_count of 1024.
static iree_status_t iree_trace_replay_print_hal_buffer_view(
    iree_hal_buffer_view_t* buffer_view) {
  FILE* out = stdout;
  return iree_hal_buffer_view_fprint(out, buffer_view,
                                     /*max_element_count=*/1024);
}
// Prints one variant to stdout, without a trailing newline.
//
// Primitive values print as scalars; buffer views and lists use their
// dedicated printers. Other ref types print `(no printer)`, and empty variants
// print `(null)`.
static iree_status_t iree_trace_replay_print_item(iree_vm_variant_t* value) {
  if (iree_vm_variant_is_value(*value)) {
    return iree_trace_replay_print_scalar(value);
  }
  if (!iree_vm_variant_is_ref(*value)) {
    fprintf(stdout, "(null)");
    return iree_ok_status();
  }
  if (iree_hal_buffer_view_isa(value->ref)) {
    return iree_trace_replay_print_hal_buffer_view(
        iree_hal_buffer_view_deref(value->ref));
  }
  if (iree_vm_list_isa(value->ref)) {
    return iree_trace_replay_print_vm_list(iree_vm_list_deref(value->ref));
  }
  // TODO(benvanik): a way for ref types to describe themselves.
  fprintf(stdout, "(no printer)");
  return iree_ok_status();
}
// Prepares a `call` event. Resolves the `function` field ('module.function')
// in the current context and parses the optional `args` sequence into a new
// input list.
//
// On success, |out_function| receives the function and |out_input_list| a new
// list that the caller must release. On failure, |out_function| is zeroed and
// |*out_input_list| is NULL. A call without `args` gets an empty input list.
iree_status_t iree_trace_replay_event_call_prepare(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* event_node, iree_vm_function_t* out_function,
    iree_vm_list_t** out_input_list) {
  memset(out_function, 0, sizeof(*out_function));
  *out_input_list = NULL;
  // Resolve the function ('module.function') within the context.
  yaml_node_t* function_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, event_node, iree_make_cstring_view("function"),
      &function_node));
  iree_vm_function_t function;
  IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
      replay->context, iree_yaml_node_as_string(function_node), &function));
  // Parse function inputs. `args` is optional, so |args_node| may be NULL and
  // must not be handed to the sequence parser in that case.
  yaml_node_t* args_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
      document, event_node, iree_make_cstring_view("args"), &args_node));
  iree_vm_list_t* input_list = NULL;
  IREE_RETURN_IF_ERROR(
      iree_vm_list_create(/*element_type=*/NULL, /*initial_capacity=*/8,
                          replay->host_allocator, &input_list));
  iree_status_t status = iree_ok_status();
  if (args_node) {
    status = iree_trace_replay_parse_item_sequence(replay, document, args_node,
                                                   input_list);
  }
  if (iree_status_is_ok(status)) {
    *out_function = function;
    *out_input_list = input_list;
  } else {
    iree_vm_list_release(input_list);
  }
  return status;
}
// Replays a `call` event: prepares the function and inputs, invokes the
// function in the current context, and releases the inputs.
//
// If |out_output_list| is non-NULL it is cleared first. On success it receives
// the output list, which the caller must release. When the call fails, or when
// |out_output_list| is NULL, the outputs are released here.
iree_status_t iree_trace_replay_event_call(iree_trace_replay_t* replay,
                                           yaml_document_t* document,
                                           yaml_node_t* event_node,
                                           iree_vm_list_t** out_output_list) {
  if (out_output_list != NULL) *out_output_list = NULL;

  iree_vm_function_t function;
  iree_vm_list_t* inputs = NULL;
  IREE_RETURN_IF_ERROR(iree_trace_replay_event_call_prepare(
      replay, document, event_node, &function, &inputs));

  iree_vm_list_t* outputs = NULL;
  iree_status_t status =
      iree_vm_list_create(/*element_type=*/NULL, /*initial_capacity=*/8,
                          replay->host_allocator, &outputs);
  if (iree_status_is_ok(status)) {
    status = iree_vm_invoke(replay->context, function,
                            IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL,
                            inputs, outputs, replay->host_allocator);
  }
  iree_vm_list_release(inputs);

  if (!iree_status_is_ok(status) || out_output_list == NULL) {
    iree_vm_list_release(outputs);
    return status;
  }
  *out_output_list = outputs;
  return status;
}
// Handles a `call` event and prints its results.
//
// Output format:
//   - a `--- CALL[<function>] ---` header line, printed before the call;
//   - one line per output, printed only if the call succeeds.
//
// The prepare/invoke/release sequence is delegated to
// iree_trace_replay_event_call instead of being duplicated, so the two entry
// points cannot drift apart.
static iree_status_t iree_trace_replay_event_call_stdout(
    iree_trace_replay_t* replay, yaml_document_t* document,
    yaml_node_t* event_node) {
  yaml_node_t* function_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, event_node, iree_make_cstring_view("function"),
      &function_node));
  iree_string_view_t function_name = iree_yaml_node_as_string(function_node);
  fprintf(stdout, "--- CALL[%.*s] ---\n", (int)function_name.size,
          function_name.data);

  // |output_list| is only set (and owned by us) when the call succeeds.
  iree_vm_list_t* output_list = NULL;
  iree_status_t status = iree_trace_replay_event_call(replay, document,
                                                      event_node, &output_list);
  if (iree_status_is_ok(status)) {
    status = iree_trace_replay_print_vm_list(output_list);
  }
  iree_vm_list_release(output_list);
  return status;
}
// Replays a single trace event.
//
// |event_node| must be a mapping whose `type` field selects the handler:
// `context_load`, `module_load`, or `call` (which prints the call's outputs to
// stdout). Unknown types return UNIMPLEMENTED.
iree_status_t iree_trace_replay_event(iree_trace_replay_t* replay,
                                      yaml_document_t* document,
                                      yaml_node_t* event_node) {
  const bool is_mapping = event_node->type == YAML_MAPPING_NODE;
  if (!is_mapping) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "(%zu): expected mapping node",
                            event_node->start_mark.line);
  }

  yaml_node_t* type_node = NULL;
  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
      document, event_node, iree_make_cstring_view("type"), &type_node));

  if (iree_yaml_string_equal(type_node,
                             iree_make_cstring_view("context_load"))) {
    return iree_trace_replay_event_context_load(replay, document, event_node);
  }
  if (iree_yaml_string_equal(type_node,
                             iree_make_cstring_view("module_load"))) {
    return iree_trace_replay_event_module_load(replay, document, event_node);
  }
  if (iree_yaml_string_equal(type_node, iree_make_cstring_view("call"))) {
    return iree_trace_replay_event_call_stdout(replay, document, event_node);
  }

  return iree_make_status(
      IREE_STATUS_UNIMPLEMENTED, "(%zu): unhandled type '%.*s'",
      event_node->start_mark.line, (int)type_node->data.scalar.length,
      type_node->data.scalar.value);
}