Adding state functionality to iree-run-trace and improving ergonomics. (#12534)
To support executing pipelines where the outputs of one call are passed
into another, the trace replay functionality has grown slightly closer
to Turing-complete (and loops are definitely coming :) by gaining
input/output control, NumPy .npy file access, and a blackboard for
temporary values. A test demonstrating the file format and some
`--help` info have been added to `iree-run-trace` to at least have a
reference not generated by Python and to ensure it mostly works.
---
The new `!input.get`/`!input.take`/`!output.set`/`!output.push` macros
can be used in any source sequence such as function call arguments.
These will either get (assign semantics) or take (move semantics) a
value from the input list and set or push a value to the output list.
`iree-run-trace` now supports the same `--input=`/`--output=` flags as
`iree-run-module` and they define the input/output handling for the
whole trace pipeline as if calling a single function.
```yaml
type: call
function: module.fn
# pass the first two `--input=` flag values and a constant
args:
- !input.take 0
- !input.take 1
- !hal.buffer_view 4xf32=0,1,2,3
# store the first result at `--output=` 0 and push the second after it
results:
- !output.set 0
- !output.push
```
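The difference between get and take can be illustrated outside of IREE
with a minimal C sketch. Everything here (`sketch_value_t`,
`sketch_list_t`, `sketch_list_get`, `sketch_list_take`) is a hypothetical
stand-in and not the IREE VM list API. It only assumes, as the replay
code in this change does, that get retains the value and leaves it in
the list while take moves it out and resets the slot.
```c
#include <assert.h>
#include <stddef.h>

// Hypothetical stand-in for a ref-counted value; not an IREE type.
typedef struct {
  int ref_count;
  int payload;
} sketch_value_t;

// Hypothetical fixed-capacity slot list standing in for the input list.
#define SKETCH_LIST_CAPACITY 8
typedef struct {
  sketch_value_t* slots[SKETCH_LIST_CAPACITY];
  size_t size;
} sketch_list_t;

// get: the caller receives a retained reference and the slot keeps its own,
// so the same ordinal can be read again later in the trace.
static sketch_value_t* sketch_list_get(sketch_list_t* list, size_t i) {
  assert(list != NULL);
  if (i >= list->size || list->slots[i] == NULL) return NULL;
  list->slots[i]->ref_count++;
  return list->slots[i];
}

// take: ownership moves to the caller and the slot is reset to NULL,
// so a later get/take of the same ordinal finds nothing.
static sketch_value_t* sketch_list_take(sketch_list_t* list, size_t i) {
  assert(list != NULL);
  if (i >= list->size) return NULL;
  sketch_value_t* value = list->slots[i];
  list->slots[i] = NULL;
  return value;
}
```
In this sketch a take avoids the extra reference, which is why it suits
inputs that are consumed exactly once by a call.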
---
In addition to the input/output lists there's also a user-defined
blackboard that provides storage for the duration of the trace. Slots
can be set by using `!blackboard.set`/`!blackboard.push` on any target
sequence such as function call results and later retrieved in any source
sequence with `!blackboard.get`/`!blackboard.take`.
```yaml
# save call results to the blackboard
type: call
function: module.return_two_things
results:
- !blackboard.push
- !blackboard.push
---
# load prior results from the blackboard
type: call
function: module.consume_three_things
args:
- !input.take 0
- !blackboard.take 0
- !blackboard.take 1
```
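How set and push place values into slots can be sketched in plain C.
The names below (`sketch_board_t`, `sketch_board_set`,
`sketch_board_push`) are hypothetical and the slots hold plain ints
rather than VM variants. The one behavior taken from this change is that
a set past the end of the list first resizes the list to `ordinal + 1`.
```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

// Hypothetical growable slot list; 0 marks an empty slot in this sketch.
typedef struct {
  int* slots;
  size_t size;
} sketch_board_t;

// Grows the board to |new_size| slots, zero-filling the new ones.
static int sketch_board_grow(sketch_board_t* board, size_t new_size) {
  assert(new_size > board->size);
  int* slots = (int*)realloc(board->slots, new_size * sizeof(int));
  if (!slots) return -1;
  memset(slots + board->size, 0, (new_size - board->size) * sizeof(int));
  board->slots = slots;
  board->size = new_size;
  return 0;
}

// set: stores |value| at |ordinal|, growing the board if needed.
static int sketch_board_set(sketch_board_t* board, size_t ordinal, int value) {
  if (board->size <= ordinal && sketch_board_grow(board, ordinal + 1) != 0) {
    return -1;
  }
  board->slots[ordinal] = value;
  return 0;
}

// push: stores |value| in a new slot at the end of the board.
static int sketch_board_push(sketch_board_t* board, int value) {
  return sketch_board_set(board, board->size, value);
}
```
Under these assumptions, two pushes onto an empty blackboard land in
slots 0 and 1, which is why the second call in the YAML above can take
ordinals 0 and 1.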
---
The `--input=`/`--output=` flags work for pipeline-style traces, while
larger traces may need programmatic control over I/O and the
blackboard. The `numpy_load` and `numpy_save` events have been added to
load or save one or more `arrays` from or to the .npy file at `path`.
Outputs can be streamed during processing either by setting
`append: true` when saving or by sharding to different files.
```yaml
# load blackboard slots 3 and 4 from a .npy file
type: numpy_load
path: input.npy
arrays:
- !blackboard.set 3
- !blackboard.set 4
---
# save a few arrays to a .npy file
type: numpy_save
path: output.npy
append: false
arrays:
- !blackboard.get 3
- !input.get 0
- !hal.buffer_view 4xf32=0,1,2,3
```
---
There are some helpers that'd be useful to add (enqueue/dequeue, pop,
etc.) that could make it easier to write more complex pipelines. The
blackboard could also be changed to use a hash table so that string
keys could be used instead of just ordinals.
Fixes #12525.
Fixes #12526.
diff --git a/build_tools/scripts/run_yamllint.sh b/build_tools/scripts/run_yamllint.sh
index a465af8..979fe0d 100755
--- a/build_tools/scripts/run_yamllint.sh
+++ b/build_tools/scripts/run_yamllint.sh
@@ -21,6 +21,8 @@
declare -a excluded_files_patterns=(
"/third_party/"
"^third_party/"
+ "/tools/test/"
+ "^tools/test/"
)
# Join on |
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index d14198a..7ba471a 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -145,7 +145,8 @@
<< " got " << iree_vm_list_size(outputs);
}
iree_vm_variant_t variant = iree_vm_variant_empty();
- IREE_CHECK_OK(iree_vm_list_get_variant(outputs, 0, &variant));
+ IREE_CHECK_OK(
+ iree_vm_list_get_variant_assign(outputs, 0, &variant));
result = convertVariantToAttribute(loc, variant);
return success(result != nullptr);
}))) {
diff --git a/experimental/web/sample_dynamic/main.c b/experimental/web/sample_dynamic/main.c
index 17bb4be..8f75a40 100644
--- a/experimental/web/sample_dynamic/main.c
+++ b/experimental/web/sample_dynamic/main.c
@@ -283,8 +283,9 @@
iree_vm_list_t* variants_list = iree_runtime_call_outputs(call);
for (iree_host_size_t i = 0; i < iree_vm_list_size(variants_list); ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
- IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(variants_list, i, &variant),
- "variant %" PRIhsz " not present", i);
+ IREE_RETURN_IF_ERROR(
+ iree_vm_list_get_variant_assign(variants_list, i, &variant),
+ "variant %" PRIhsz " not present", i);
if (iree_vm_variant_is_value(variant)) {
switch (variant.type.value_type) {
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 3e32df9..70e5f53 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -255,7 +255,7 @@
py::object VmVariantList::GetVariant(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
- CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+ CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &v),
"Could not access list element");
if (iree_vm_type_def_is_value(&v.type)) {
// Convert a value type.
@@ -288,7 +288,7 @@
py::object VmVariantList::GetAsSerializedTraceValue(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
- CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+ CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &v),
"Could not access list element");
if (iree_vm_type_def_is_value(&v.type)) {
// Convert a value type.
@@ -367,8 +367,14 @@
// Element type.
iree_hal_element_type_t element_type =
iree_hal_buffer_view_element_type(buffer_view);
- // TODO: Would be nice to output as hex.
- record["element_type"] = element_type;
+ char element_type_str[64] = {0};
+ iree_host_size_t element_type_length = 0;
+ CheckApiStatus(
+ iree_hal_format_element_type(element_type, sizeof(element_type_str),
+ element_type_str, &element_type_length),
+ "Formatting element type");
+ record["element_type"] =
+ std::string(element_type_str, element_type_length);
// Map memory.
iree_device_size_t byte_length = iree_hal_buffer_byte_length(raw_buffer);
@@ -392,7 +398,7 @@
py::object VmVariantList::GetAsRef(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
- CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+ CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &v),
"Could not access list element");
if (!iree_vm_variant_is_ref(v)) {
throw std::invalid_argument("list element is not a ref");
@@ -426,7 +432,7 @@
std::unordered_set<iree_vm_list_t*>& visited) {
for (iree_host_size_t i = 0, e = iree_vm_list_size(list); i < e; ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
- iree_status_t status = iree_vm_list_get_variant(list, i, &variant);
+ iree_status_t status = iree_vm_list_get_variant_assign(list, i, &variant);
if (!iree_status_is_ok(status)) {
iree_status_ignore(status);
out.append("Error");
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index 18182a8..9fc50c8 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -193,6 +193,8 @@
srcs = ["trace_replay.c"],
hdrs = ["trace_replay.h"],
deps = [
+ ":device_util",
+ ":numpy_io",
":yaml_util",
"//runtime/src/iree/base",
"//runtime/src/iree/base:tracing",
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index 5e41789..7e4c87e 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -214,6 +214,8 @@
SRCS
"trace_replay.c"
DEPS
+ ::device_util
+ ::numpy_io
::yaml_util
iree::base
iree::base::internal
diff --git a/runtime/src/iree/tooling/comparison.cc b/runtime/src/iree/tooling/comparison.cc
index 7cc3c76..b01e3c9 100644
--- a/runtime/src/iree/tooling/comparison.cc
+++ b/runtime/src/iree/tooling/comparison.cc
@@ -244,9 +244,10 @@
for (iree_host_size_t i = 0; i < iree_vm_list_size(expected_list); ++i) {
iree_vm_variant_t expected_variant = iree_vm_variant_empty();
IREE_CHECK_OK(
- iree_vm_list_get_variant(expected_list, i, &expected_variant));
+ iree_vm_list_get_variant_assign(expected_list, i, &expected_variant));
iree_vm_variant_t actual_variant = iree_vm_variant_empty();
- IREE_CHECK_OK(iree_vm_list_get_variant(actual_list, i, &actual_variant));
+ IREE_CHECK_OK(
+ iree_vm_list_get_variant_assign(actual_list, i, &actual_variant));
bool did_match = iree_tooling_compare_variants(
(int)i, expected_variant, actual_variant, host_allocator,
/*max_element_count=*/1024, builder);
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index 301819e..9bc9398 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -111,8 +111,9 @@
}
iree_hal_device_t* device = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_create_device_from_flags(default_device_uri, host_allocator,
- &device));
+ z0, iree_hal_create_device_from_flags(
+ iree_hal_available_driver_registry(), default_device_uri,
+ host_allocator, &device));
// Fetch the allocator from the device to pass back to the caller.
iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index 4d0017f..f738dcd 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -321,6 +321,7 @@
}
iree_status_t iree_hal_create_device_from_flags(
+ iree_hal_driver_registry_t* driver_registry,
iree_string_view_t default_device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
iree_string_view_t device_uri = default_device;
diff --git a/runtime/src/iree/tooling/device_util.h b/runtime/src/iree/tooling/device_util.h
index 915bb4b..d290fad 100644
--- a/runtime/src/iree/tooling/device_util.h
+++ b/runtime/src/iree/tooling/device_util.h
@@ -30,6 +30,7 @@
// Uses the |default_device| if no flags were specified.
// Fails if more than one device was specified.
iree_status_t iree_hal_create_device_from_flags(
+ iree_hal_driver_registry_t* driver_registry,
iree_string_view_t default_device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index f05cac8..930ffbe 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -17,11 +17,13 @@
#include "iree/base/internal/path.h"
#include "iree/base/tracing.h"
#include "iree/modules/hal/module.h"
+#include "iree/tooling/device_util.h"
+#include "iree/tooling/numpy_io.h"
#include "iree/vm/bytecode/module.h"
-// Parameter for locally defined lcg similar to std::minstd_rand.
-#define IREE_PRNG_MULTIPLIER 48271
-#define IREE_PRNG_MODULUS 2147483647
+//===----------------------------------------------------------------------===//
+// iree_trace_replay_t
+//===----------------------------------------------------------------------===//
iree_status_t iree_trace_replay_initialize(
iree_string_view_t root_path, iree_vm_instance_t* instance,
@@ -41,11 +43,31 @@
out_replay->driver_registry = driver_registry;
- return iree_ok_status();
+ iree_status_t status = iree_ok_status();
+ if (iree_status_is_ok(status)) {
+ status = iree_vm_list_create(NULL, 8u, host_allocator, &out_replay->inputs);
+ }
+ if (iree_status_is_ok(status)) {
+ status =
+ iree_vm_list_create(NULL, 8u, host_allocator, &out_replay->outputs);
+ }
+ if (iree_status_is_ok(status)) {
+ status =
+ iree_vm_list_create(NULL, 8u, host_allocator, &out_replay->blackboard);
+ }
+
+ if (!iree_status_is_ok(status)) {
+ iree_trace_replay_deinitialize(out_replay,
+ IREE_TRACE_REPLAY_SHUTDOWN_QUIET);
+ }
+ return status;
}
void iree_trace_replay_deinitialize(iree_trace_replay_t* replay,
iree_trace_replay_shutdown_flags_t flags) {
+ iree_vm_list_release(replay->inputs);
+ iree_vm_list_release(replay->outputs);
+ iree_vm_list_release(replay->blackboard);
iree_vm_context_release(replay->context);
iree_vm_instance_release(replay->instance);
@@ -65,6 +87,10 @@
replay->device_uris = device_uris;
}
+//===----------------------------------------------------------------------===//
+// type: context_load
+//===----------------------------------------------------------------------===//
+
iree_status_t iree_trace_replay_event_context_load(iree_trace_replay_t* replay,
yaml_document_t* document,
yaml_node_t* event_node) {
@@ -80,6 +106,10 @@
replay->host_allocator, &replay->context);
}
+//===----------------------------------------------------------------------===//
+// type: module_load
+//===----------------------------------------------------------------------===//
+
// TODO(benvanik): rework this to allow for multiple devices from a device set.
static iree_status_t iree_trace_replay_create_device(
iree_trace_replay_t* replay, yaml_node_t* device_node,
@@ -96,8 +126,8 @@
}
// Try to create the device.
- return iree_hal_create_device(replay->driver_registry, device_uri,
- host_allocator, out_device);
+ return iree_hal_create_device_from_flags(replay->driver_registry, device_uri,
+ host_allocator, out_device);
}
static iree_status_t iree_trace_replay_load_builtin_module(
@@ -106,14 +136,14 @@
iree_vm_module_t* module = NULL;
yaml_node_t* name_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, module_node, iree_make_cstring_view("name"), &name_node));
- if (iree_yaml_string_equal(name_node, iree_make_cstring_view("hal"))) {
- yaml_node_t* driver_node = NULL;
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(document, module_node,
+ IREE_SV("name"), &name_node));
+ if (iree_yaml_string_equal(name_node, IREE_SV("hal"))) {
+ yaml_node_t* device_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, module_node, iree_make_cstring_view("driver"), &driver_node));
+ document, module_node, IREE_SV("device"), &device_node));
IREE_RETURN_IF_ERROR(iree_trace_replay_create_device(
- replay, driver_node, replay->host_allocator, &replay->device));
+ replay, device_node, replay->host_allocator, &replay->device));
IREE_RETURN_IF_ERROR(iree_hal_module_create(
replay->instance, replay->device, IREE_HAL_MODULE_FLAG_NONE,
replay->host_allocator, &module));
@@ -134,13 +164,13 @@
iree_trace_replay_t* replay, yaml_document_t* document,
yaml_node_t* module_node) {
yaml_node_t* path_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, module_node, iree_make_cstring_view("path"), &path_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(document, module_node,
+ IREE_SV("path"), &path_node));
// Load bytecode file (or stdin) contents into memory.
iree_file_contents_t* flatbuffer_contents = NULL;
iree_status_t status = iree_ok_status();
- if (iree_yaml_string_equal(path_node, iree_make_cstring_view("<stdin>"))) {
+ if (iree_yaml_string_equal(path_node, IREE_SV("<stdin>"))) {
fprintf(stdout, "Reading bytecode contents from stdin...\n");
status =
iree_stdin_read_contents(replay->host_allocator, &flatbuffer_contents);
@@ -181,17 +211,17 @@
yaml_document_t* document,
yaml_node_t* event_node) {
yaml_node_t* module_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, event_node, iree_make_cstring_view("module"), &module_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(document, event_node,
+ IREE_SV("module"), &module_node));
yaml_node_t* type_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, module_node, iree_make_cstring_view("type"), &type_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(document, module_node,
+ IREE_SV("type"), &type_node));
iree_string_view_t type = iree_yaml_node_as_string(type_node);
- if (iree_string_view_equal(type, iree_make_cstring_view("builtin"))) {
+ if (iree_string_view_equal(type, IREE_SV("builtin"))) {
return iree_trace_replay_load_builtin_module(replay, document, module_node);
- } else if (iree_string_view_equal(type, iree_make_cstring_view("bytecode"))) {
+ } else if (iree_string_view_equal(type, IREE_SV("bytecode"))) {
return iree_trace_replay_load_bytecode_module(replay, document,
module_node);
}
@@ -201,10 +231,268 @@
type.data);
}
-static iree_status_t iree_trace_replay_parse_item(iree_trace_replay_t* replay,
- yaml_document_t* document,
- yaml_node_t* value_node,
- iree_vm_list_t* target_list);
+//===----------------------------------------------------------------------===//
+// RNG utilities
+//===----------------------------------------------------------------------===//
+// TODO(benvanik): move these out to another file.
+
+// Parameter for locally defined lcg similar to std::minstd_rand.
+#define IREE_PRNG_MULTIPLIER 48271
+#define IREE_PRNG_MODULUS 2147483647
+
+// Writes an element of the given |element_type| with the given integral |value|
+// to |dst|.
+static void iree_trace_replay_write_element(
+ iree_hal_element_type_t element_type, int32_t value, void* dst) {
+#define IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(ETYPE, CTYPE) \
+ case IREE_HAL_ELEMENT_TYPE_##ETYPE: \
+ *(CTYPE*)dst = (CTYPE)value; \
+ break;
+
+ switch (element_type) {
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_8, int8_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_16, int16_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_32, int32_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_64, int64_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_8, int8_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_16, int16_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_32, int32_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_64, int64_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_8, uint8_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_16, uint16_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_32, uint32_t)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_64, uint64_t)
+ // clang-format off
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+ *(uint16_t*)dst = iree_math_f32_to_f16((float)value);
+ break;
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_32, float)
+ IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_64, double)
+ // clang-format on
+ default:
+ IREE_ASSERT(false, "unhandled element type");
+ break;
+ }
+
+#undef IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE
+}
+
+// Simple deterministic pseudorandom generator.
+// This function is the same as C++'s std::minstd_rand.
+static uint32_t iree_tree_replay_pseudorandom_uint32(uint32_t* state) {
+ *state = (*state * IREE_PRNG_MULTIPLIER) % IREE_PRNG_MODULUS;
+ return *state;
+}
+
+// Returns a random uint8_t in the range of [0, UCHAR_MAX].
+static uint8_t iree_trace_replay_pseudorandom_uint8(uint32_t* state) {
+ // Return the second-least-significant of the 4 bytes of state. It avoids
+ // some mild issues with the least-significant and most-significant bytes.
+ return iree_tree_replay_pseudorandom_uint32(state) >> 8;
+}
+
+// Returns a random uint32_t in the range [0, range).
+static inline uint32_t iree_trace_replay_pseudorandom_range(uint32_t* state,
+ uint32_t range) {
+ return iree_tree_replay_pseudorandom_uint32(state) % range;
+}
+
+// Returns a random double in the range of [0, 1.0).
+static double iree_trace_replay_pseudorandom_double(uint32_t* state) {
+ const double inv_modulus = 1.0 / IREE_PRNG_MODULUS;
+ return iree_tree_replay_pseudorandom_uint32(state) * inv_modulus;
+}
+
+// Get minimum and maximum for integer-valued uniform distribution.
+static void iree_trace_replay_get_min_max_for_element_type(
+ iree_hal_element_type_t element_type, int32_t* min, int32_t* max) {
+ switch (element_type) {
+ case IREE_HAL_ELEMENT_TYPE_INT_8:
+ case IREE_HAL_ELEMENT_TYPE_SINT_8:
+ *min = -2;
+ *max = +2;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_UINT_8:
+ *min = 0;
+ *max = +2;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_INT_16:
+ case IREE_HAL_ELEMENT_TYPE_SINT_16:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
+ *min = -4;
+ *max = +4;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_UINT_16:
+ *min = 0;
+ *max = +4;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_INT_32:
+ case IREE_HAL_ELEMENT_TYPE_SINT_32:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
+ *min = -8;
+ *max = +8;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_UINT_32:
+ *min = 0;
+ *max = +8;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_INT_64:
+ case IREE_HAL_ELEMENT_TYPE_SINT_64:
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
+ *min = -16;
+ *max = +16;
+ break;
+ case IREE_HAL_ELEMENT_TYPE_UINT_64:
+ *min = 0;
+ *max = +16;
+ break;
+ default:
+ IREE_ASSERT(false, "unhandled element type");
+ break;
+ }
+}
+
+// Fills the destination span with pseudorandom values of the given
+// |element_type|. The given |seed| is passed to the pseudorandom generator.
+// The pseudorandom values are reproducible both across runs and across
+// machines.
+static void iree_trace_replay_generate_fully_specified_pseudorandom_buffer(
+ iree_hal_element_type_t element_type, iree_byte_span_t span,
+ uint32_t seed) {
+ iree_host_size_t element_byte_count =
+ iree_hal_element_dense_byte_count(element_type);
+ uint8_t* data_end = span.data + span.data_length;
+ uint32_t state = seed;
+ uint32_t range;
+ int32_t min, max;
+ iree_trace_replay_get_min_max_for_element_type(element_type, &min, &max);
+ range = (max - min + 1);
+ for (uint8_t* data = span.data; data < data_end; data += element_byte_count) {
+ // Generate "uniform" integer-valued numbers in the range [min, max].
+ int32_t value =
+ (int32_t)iree_trace_replay_pseudorandom_range(&state, range) + min;
+ iree_trace_replay_write_element(element_type, value, data);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// List I/O macros
+//===----------------------------------------------------------------------===//
+
+// Parses an I/O macro referencing the replay-global inputs/outputs.
+// If |is_move| is true then the value is consumed from |list| and the original
+// value in the list is reset to NULL.
+//
+// ```yaml
+// !input.get 0
+// !input.take 1
+// !output.get 2
+// ```
+static iree_status_t iree_trace_replay_parse_list_get_macro(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* value_node, iree_vm_list_t* list, bool is_move,
+ iree_vm_variant_t* out_result) {
+ iree_string_view_t value_str = iree_yaml_node_as_string(value_node);
+ int32_t ordinal = 0;
+ if (!iree_string_view_atoi_int32(value_str, &ordinal)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "failed to parse I/O ordinal from `%.*s`",
+ (int)value_str.size, value_str.data);
+ }
+ iree_vm_variant_t variant = iree_vm_variant_empty();
+ if (is_move) {
+ IREE_RETURN_IF_ERROR(
+ iree_vm_list_get_variant_move(list, ordinal, &variant));
+ } else {
+ IREE_RETURN_IF_ERROR(
+ iree_vm_list_get_variant_retain(list, ordinal, &variant));
+ }
+ *out_result = variant;
+ return iree_ok_status();
+}
+
+// Parses a list load macro referencing a replay-global |list|.
+//
+// ```yaml
+// # gets |variant| at index 2, leaving it in the list for future use
+// [!input.]get 2
+// # takes |variant| at index 2, clearing the entry in the list
+// [!input.]take 2
+// ```
+static iree_status_t iree_trace_replay_parse_list_load_macro(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* value_node, iree_string_view_t name, iree_vm_list_t* list,
+ iree_vm_variant_t* out_result) {
+ if (iree_string_view_equal(name, IREE_SV("get"))) {
+ return iree_trace_replay_parse_list_get_macro(
+ replay, document, value_node, list,
+ /*is_move=*/false, out_result);
+ } else if (iree_string_view_equal(name, IREE_SV("take"))) {
+ return iree_trace_replay_parse_list_get_macro(replay, document, value_node,
+ list,
+ /*is_move=*/true, out_result);
+ } else if (iree_string_view_equal(name, IREE_SV("pop"))) {
+ iree_host_size_t i = iree_vm_list_size(list) - 1;
+ IREE_RETURN_IF_ERROR(iree_vm_list_get_variant_move(list, i, out_result));
+ return iree_vm_list_resize(list, i);
+ }
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "unsupported list load macro: `%.*s`", (int)name.size,
+ name.data);
+}
+
+// Parses an output-set macro referencing the replay-global outputs.
+// The provided |variant| value is set at the specified index.
+//
+// ```yaml
+// !output.set 2
+// ```
+static iree_status_t iree_trace_replay_parse_list_set_macro(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* value_node, iree_vm_list_t* list, iree_vm_variant_t variant) {
+ iree_string_view_t value_str = iree_yaml_node_as_string(value_node);
+ int32_t ordinal = 0;
+ if (!iree_string_view_atoi_int32(value_str, &ordinal)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "failed to parse I/O ordinal from `%.*s`",
+ (int)value_str.size, value_str.data);
+ }
+ if (iree_vm_list_size(list) <= ordinal) {
+ IREE_RETURN_IF_ERROR(iree_vm_list_resize(list, ordinal + 1));
+ }
+ return iree_vm_list_set_variant_retain(list, ordinal, &variant);
+}
+
+// Parses a list store macro referencing a replay-global |list|.
+//
+// ```yaml
+// # sets |variant| at index 2 in the output list
+// [!output.]set 2
+// # pushes |variant| to the end of the output list
+// [!output.]push
+// ```
+static iree_status_t iree_trace_replay_parse_list_store_macro(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* value_node, iree_string_view_t name, iree_vm_list_t* list,
+ iree_vm_variant_t variant) {
+ if (iree_string_view_equal(name, IREE_SV("set"))) {
+ return iree_trace_replay_parse_list_set_macro(replay, document, value_node,
+ list, variant);
+ } else if (iree_string_view_equal(name, IREE_SV("push"))) {
+ return iree_vm_list_push_variant_retain(list, &variant);
+ }
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "unsupported list store macro: `%.*s`",
+ (int)name.size, name.data);
+}
+
+//===----------------------------------------------------------------------===//
+// YAML value parsing
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_trace_replay_parse_item(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* value_node, iree_vm_variant_t* out_result);
static iree_status_t iree_trace_replay_parse_item_sequence(
iree_trace_replay_t* replay, yaml_document_t* document,
yaml_node_t* sequence_node, iree_vm_list_t* target_list);
@@ -216,10 +504,10 @@
// ```
static iree_status_t iree_trace_replay_parse_scalar(
iree_trace_replay_t* replay, yaml_document_t* document,
- yaml_node_t* value_node, iree_vm_list_t* target_list) {
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
yaml_node_t* data_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("i8"), &data_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(document, value_node,
+ IREE_SV("i8"), &data_node));
if (data_node) {
int32_t value = 0;
if (!iree_string_view_atoi_int32(iree_yaml_node_as_string(data_node),
@@ -228,13 +516,12 @@
IREE_STATUS_INVALID_ARGUMENT, "failed to parse i8 value: '%.*s'",
(int)data_node->data.scalar.length, data_node->data.scalar.value);
}
- iree_vm_variant_t variant = iree_vm_variant_empty();
- variant.type.value_type = IREE_VM_VALUE_TYPE_I8;
- variant.i8 = (int8_t)value;
- return iree_vm_list_push_variant(target_list, &variant);
+ *out_result =
+ iree_vm_make_variant_value(iree_vm_value_make_i8((int8_t)value));
+ return iree_ok_status();
}
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("i16"), &data_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(document, value_node,
+ IREE_SV("i16"), &data_node));
if (data_node) {
int32_t value = 0;
if (!iree_string_view_atoi_int32(iree_yaml_node_as_string(data_node),
@@ -243,69 +530,66 @@
IREE_STATUS_INVALID_ARGUMENT, "failed to parse i16 value: '%.*s'",
(int)data_node->data.scalar.length, data_node->data.scalar.value);
}
- iree_vm_variant_t variant = iree_vm_variant_empty();
- variant.type.value_type = IREE_VM_VALUE_TYPE_I16;
- variant.i16 = (int16_t)value;
- return iree_vm_list_push_variant(target_list, &variant);
+ *out_result =
+ iree_vm_make_variant_value(iree_vm_value_make_i16((int16_t)value));
+ return iree_ok_status();
}
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("i32"), &data_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(document, value_node,
+ IREE_SV("i32"), &data_node));
if (data_node) {
- iree_vm_variant_t variant = iree_vm_variant_empty();
- variant.type.value_type = IREE_VM_VALUE_TYPE_I32;
+ int32_t value = 0;
if (!iree_string_view_atoi_int32(iree_yaml_node_as_string(data_node),
- &variant.i32)) {
+ &value)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT, "failed to parse i32 value: '%.*s'",
(int)data_node->data.scalar.length, data_node->data.scalar.value);
}
- return iree_vm_list_push_variant(target_list, &variant);
+ *out_result = iree_vm_make_variant_value(iree_vm_value_make_i32(value));
+ return iree_ok_status();
}
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("i64"), &data_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(document, value_node,
+ IREE_SV("i64"), &data_node));
if (data_node) {
- iree_vm_variant_t variant = iree_vm_variant_empty();
- variant.type.value_type = IREE_VM_VALUE_TYPE_I64;
+ int64_t value = 0;
if (!iree_string_view_atoi_int64(iree_yaml_node_as_string(data_node),
- &variant.i64)) {
+ &value)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT, "failed to parse i64 value: '%.*s'",
(int)data_node->data.scalar.length, data_node->data.scalar.value);
}
- return iree_vm_list_push_variant(target_list, &variant);
+ *out_result = iree_vm_make_variant_value(iree_vm_value_make_i64(value));
+ return iree_ok_status();
}
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("f32"), &data_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(document, value_node,
+ IREE_SV("f32"), &data_node));
if (data_node) {
- iree_vm_variant_t variant = iree_vm_variant_empty();
- variant.type.value_type = IREE_VM_VALUE_TYPE_F32;
- if (!iree_string_view_atof(iree_yaml_node_as_string(data_node),
- &variant.f32)) {
+ float value = 0.0f;
+ if (!iree_string_view_atof(iree_yaml_node_as_string(data_node), &value)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT, "failed to parse f32 value: '%.*s'",
(int)data_node->data.scalar.length, data_node->data.scalar.value);
}
- return iree_vm_list_push_variant(target_list, &variant);
+ *out_result = iree_vm_make_variant_value(iree_vm_value_make_f32(value));
+ return iree_ok_status();
}
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("f64"), &data_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(document, value_node,
+ IREE_SV("f64"), &data_node));
if (data_node) {
- iree_vm_variant_t variant = iree_vm_variant_empty();
- variant.type.value_type = IREE_VM_VALUE_TYPE_F64;
- if (!iree_string_view_atod(iree_yaml_node_as_string(data_node),
- &variant.f64)) {
+ double value = 0.0;
+ if (!iree_string_view_atod(iree_yaml_node_as_string(data_node), &value)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT, "failed to parse f64 value: '%.*s'",
(int)data_node->data.scalar.length, data_node->data.scalar.value);
}
- return iree_vm_list_push_variant(target_list, &variant);
+ *out_result = iree_vm_make_variant_value(iree_vm_value_make_f64(value));
+ return iree_ok_status();
}
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"(%zu): unimplemented scalar type parser",
value_node->start_mark.line);
}
-// Parses a !vm.list and appends it to |target_list|.
+// Parses a !vm.list into |out_result|.
//
// ```yaml
// items:
@@ -316,7 +600,7 @@
// ```
static iree_status_t iree_trace_replay_parse_vm_list(
iree_trace_replay_t* replay, yaml_document_t* document,
- yaml_node_t* value_node, iree_vm_list_t* target_list) {
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
if (value_node->type != YAML_MAPPING_NODE) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"(%zu): expected sequence node for type",
@@ -324,7 +608,7 @@
}
yaml_node_t* items_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("items"), &items_node));
+ document, value_node, IREE_SV("items"), &items_node));
iree_vm_list_t* list = NULL;
IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/NULL,
@@ -338,10 +622,8 @@
}
if (iree_status_is_ok(status)) {
- iree_vm_ref_t list_ref = iree_vm_list_move_ref(list);
- status = iree_vm_list_push_ref_move(target_list, &list_ref);
- }
- if (!iree_status_is_ok(status)) {
+ *out_result = iree_vm_make_variant_ref_assign(iree_vm_list_move_ref(list));
+ } else {
iree_vm_list_release(list);
}
return status;
@@ -515,141 +797,6 @@
}
}
-// Writes an element of the given |element_type| with the given integral |value|
-// to |dst|.
-static void iree_trace_replay_write_element(
- iree_hal_element_type_t element_type, int32_t value, void* dst) {
-#define IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(ETYPE, CTYPE) \
- case IREE_HAL_ELEMENT_TYPE_##ETYPE: \
- *(CTYPE*)dst = (CTYPE)value; \
- break;
-
- switch (element_type) {
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_8, int8_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_16, int16_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_32, int32_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(INT_64, int64_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_8, int8_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_16, int16_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_32, int32_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(SINT_64, int64_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_8, uint8_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_16, uint16_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_32, uint32_t)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(UINT_64, uint64_t)
- // clang-format off
- case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- *(uint16_t*)dst = iree_math_f32_to_f16((float)value);
- break;
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_32, float)
- IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE(FLOAT_64, double)
- // clang-format on
- default:
- IREE_ASSERT(false, "unhandled element type");
- break;
- }
-
-#undef IREE_TRACE_REPLAY_WRITE_ELEMENT_CASE
-}
-
-// Simple deterministic pseudorandom generator.
-// This function is same as C++'s std::minstd_rand.
-static uint32_t iree_tree_replay_pseudorandom_uint32(uint32_t* state) {
- *state = (*state * IREE_PRNG_MULTIPLIER) % IREE_PRNG_MODULUS;
- return *state;
-}
-
-// Returns a random uint8_t in the range of [0, UCHAR_MAX].
-static uint8_t iree_trace_replay_pseudorandom_uint8(uint32_t* state) {
- // return the second-least-signicant out of the 4 bytes of state. it avoids
- // some mild issues with the least-significant and most-significant bytes.
- return iree_tree_replay_pseudorandom_uint32(state) >> 8;
-}
-
-// Returns a random uint32_t in the range [0, range).
-static inline uint32_t iree_trace_replay_pseudorandom_range(uint32_t* state,
- uint32_t range) {
- return iree_tree_replay_pseudorandom_uint32(state) % range;
-}
-
-// Returns a random double in the range of [0, 1.0).
-static double iree_trace_replay_pseudorandom_double(uint32_t* state) {
- const double inv_modulus = 1.0 / IREE_PRNG_MODULUS;
- return iree_tree_replay_pseudorandom_uint32(state) * inv_modulus;
-}
-
-// Get minimum and maximum for integer-valued uniform distribution.
-static void iree_trace_replay_get_min_max_for_element_type(
- iree_hal_element_type_t element_type, int32_t* min, int32_t* max) {
- switch (element_type) {
- case IREE_HAL_ELEMENT_TYPE_INT_8:
- case IREE_HAL_ELEMENT_TYPE_SINT_8:
- *min = -2;
- *max = +2;
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_8:
- *min = 0;
- *max = +2;
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_16:
- case IREE_HAL_ELEMENT_TYPE_SINT_16:
- case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- *min = -4;
- *max = +4;
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_16:
- *min = 0;
- *max = +4;
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_32:
- case IREE_HAL_ELEMENT_TYPE_SINT_32:
- case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
- *min = -8;
- *max = +8;
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_32:
- *min = 0;
- *max = +8;
- break;
- case IREE_HAL_ELEMENT_TYPE_INT_64:
- case IREE_HAL_ELEMENT_TYPE_SINT_64:
- case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
- *min = -16;
- *min = +16;
- break;
- case IREE_HAL_ELEMENT_TYPE_UINT_64:
- *min = 0;
- *max = +16;
- break;
- default:
- IREE_ASSERT(false, "unhandled element type");
- break;
- }
-}
-
-// Fills the destination span with pseudorandom values of the given
-// |element_type|. The given |seed| is passed to the pseudorandom generator.
-// The pseudorandom values are reproducible both across runs and across
-// machines.
-static void iree_trace_replay_generate_fully_specified_pseudorandom_buffer(
- iree_hal_element_type_t element_type, iree_byte_span_t span,
- uint32_t seed) {
- iree_host_size_t element_byte_count =
- iree_hal_element_dense_byte_count(element_type);
- uint8_t* data_end = span.data + span.data_length;
- uint32_t state = seed;
- uint32_t range;
- int32_t min, max;
- iree_trace_replay_get_min_max_for_element_type(element_type, &min, &max);
- range = (max - min + 1);
- for (uint8_t* data = span.data; data < data_end; data += element_byte_count) {
- // Generate "uniform" integer-valued numbers in the range [min, max].
- int32_t value =
- (int32_t)iree_trace_replay_pseudorandom_range(&state, range) + min;
- iree_trace_replay_write_element(element_type, value, data);
- }
-}
-
// Generates the destination |buffer| using the generator specified by
// |generator_node|.
static iree_status_t iree_trace_replay_generate_hal_buffer(
@@ -712,7 +859,7 @@
}
}
-// Parses a !hal.buffer and appends it to |target_list|.
+// Parses a !hal.buffer into |out_result|.
//
// ```yaml
// shape:
@@ -721,10 +868,10 @@
// ```
static iree_status_t iree_trace_replay_parse_hal_buffer(
iree_trace_replay_t* replay, yaml_document_t* document,
- yaml_node_t* value_node, iree_vm_list_t* target_list) {
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
yaml_node_t* shape_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("shape"), &shape_node));
+ document, value_node, IREE_SV("shape"), &shape_node));
iree_hal_dim_t shape[16];
iree_host_size_t shape_rank = 0;
IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_shape(
@@ -732,16 +879,14 @@
yaml_node_t* element_type_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, value_node, iree_make_cstring_view("element_type"),
- &element_type_node));
+ document, value_node, IREE_SV("element_type"), &element_type_node));
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_element_type(
replay, document, element_type_node, &element_type));
yaml_node_t* encoding_type_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("encoding_type"),
- &encoding_type_node));
+ document, value_node, IREE_SV("encoding_type"), &encoding_type_node));
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_encoding_type(
@@ -760,13 +905,12 @@
},
allocation_size, iree_const_byte_span_empty(), &buffer));
- iree_vm_ref_t buffer_ref = iree_hal_buffer_move_ref(buffer);
- iree_status_t status = iree_vm_list_push_ref_move(target_list, &buffer_ref);
- iree_vm_ref_release(&buffer_ref);
- return status;
+ *out_result =
+ iree_vm_make_variant_ref_assign(iree_hal_buffer_move_ref(buffer));
+ return iree_ok_status();
}
-// Parses a !hal.buffer_view and appends it to |target_list|.
+// Parses a !hal.buffer_view into |out_result|.
//
// ```yaml
// shape:
@@ -777,10 +921,10 @@
// ```
static iree_status_t iree_trace_replay_parse_hal_buffer_view(
iree_trace_replay_t* replay, yaml_document_t* document,
- yaml_node_t* value_node, iree_vm_list_t* target_list) {
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
yaml_node_t* shape_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("shape"), &shape_node));
+ document, value_node, IREE_SV("shape"), &shape_node));
iree_hal_dim_t shape[16];
iree_host_size_t shape_rank = 0;
IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_shape(
@@ -788,16 +932,14 @@
yaml_node_t* element_type_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, value_node, iree_make_cstring_view("element_type"),
- &element_type_node));
+ document, value_node, IREE_SV("element_type"), &element_type_node));
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_element_type(
replay, document, element_type_node, &element_type));
yaml_node_t* encoding_type_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("encoding_type"),
- &encoding_type_node));
+ document, value_node, IREE_SV("encoding_type"), &encoding_type_node));
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
IREE_RETURN_IF_ERROR(iree_trace_replay_parse_hal_encoding_type(
@@ -805,13 +947,11 @@
yaml_node_t* contents_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("contents"),
- &contents_node));
+ document, value_node, IREE_SV("contents"), &contents_node));
yaml_node_t* generator_node = NULL;
IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
- document, value_node, iree_make_cstring_view("contents_generator"),
- &generator_node));
+ document, value_node, IREE_SV("contents_generator"), &generator_node));
iree_hal_buffer_view_t* buffer_view = NULL;
if (contents_node && generator_node) {
@@ -848,51 +988,49 @@
iree_const_byte_span_empty(), &buffer_view));
}
- iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
- iree_status_t status =
- iree_vm_list_push_ref_move(target_list, &buffer_view_ref);
- iree_vm_ref_release(&buffer_view_ref);
- return status;
+ *out_result = iree_vm_make_variant_ref_assign(
+ iree_hal_buffer_view_move_ref(buffer_view));
+ return iree_ok_status();
}
-// Parses a !hal.buffer in tensor form and appends it to |target_list|.
+// Parses a !hal.buffer in tensor form into |out_result|.
// The tensor form is used to size and initialize the buffer but then the
// metadata is thrown away.
//
// ```yaml
-// !!hal.buffer 4xf32=[0 1 2 3]
+// !hal.buffer 4xf32=[0 1 2 3]
// ```
static iree_status_t iree_trace_replay_parse_inline_hal_buffer(
iree_trace_replay_t* replay, yaml_document_t* document,
- yaml_node_t* value_node, iree_vm_list_t* target_list) {
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_parse(
iree_yaml_node_as_string(value_node),
iree_hal_device_allocator(replay->device), &buffer_view));
- iree_vm_ref_t buffer_ref =
- iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(buffer_view));
- iree_status_t status = iree_vm_list_push_ref_move(target_list, &buffer_ref);
+ *out_result = iree_vm_make_variant_ref_assign(
+ iree_hal_buffer_retain_ref(iree_hal_buffer_view_buffer(buffer_view)));
iree_hal_buffer_view_release(buffer_view);
- return status;
+ return iree_ok_status();
}
-// Parses a !hal.buffer_view in tensor form and appends it to |target_list|.
+// Parses a !hal.buffer_view in tensor form into |out_result|.
//
// ```yaml
// !hal.buffer_view 4xf32=[0 1 2 3]
// ```
static iree_status_t iree_trace_replay_parse_inline_hal_buffer_view(
iree_trace_replay_t* replay, yaml_document_t* document,
- yaml_node_t* value_node, iree_vm_list_t* target_list) {
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
iree_hal_buffer_view_t* buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_parse(
iree_yaml_node_as_string(value_node),
iree_hal_device_allocator(replay->device), &buffer_view));
- iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
- return iree_vm_list_push_ref_move(target_list, &buffer_view_ref);
+ *out_result = iree_vm_make_variant_ref_assign(
+ iree_hal_buffer_view_move_ref(buffer_view));
+ return iree_ok_status();
}
-// Parses a typed item from |value_node| and appends it to |target_list|.
+// Parses a typed item from |value_node| into |out_result|.
//
// ```yaml
// type: vm.list
@@ -904,39 +1042,46 @@
// ```yaml
// !hal.buffer_view 4xf32=[0 1 2 3]
// ```
-static iree_status_t iree_trace_replay_parse_item(iree_trace_replay_t* replay,
- yaml_document_t* document,
- yaml_node_t* value_node,
- iree_vm_list_t* target_list) {
- if (strcmp(value_node->tag, "!hal.buffer") == 0) {
+static iree_status_t iree_trace_replay_parse_item(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* value_node, iree_vm_variant_t* out_result) {
+ iree_string_view_t tag = iree_make_cstring_view(value_node->tag);
+ if (iree_string_view_consume_prefix(&tag, IREE_SV("!input."))) {
+ return iree_trace_replay_parse_list_load_macro(
+ replay, document, value_node, tag, replay->inputs, out_result);
+ } else if (iree_string_view_consume_prefix(&tag, IREE_SV("!output."))) {
+ return iree_trace_replay_parse_list_load_macro(
+ replay, document, value_node, tag, replay->outputs, out_result);
+ } else if (iree_string_view_consume_prefix(&tag, IREE_SV("!blackboard."))) {
+ return iree_trace_replay_parse_list_load_macro(
+ replay, document, value_node, tag, replay->blackboard, out_result);
+ } else if (strcmp(value_node->tag, "!hal.buffer") == 0) {
return iree_trace_replay_parse_inline_hal_buffer(replay, document,
- value_node, target_list);
+ value_node, out_result);
} else if (strcmp(value_node->tag, "!hal.buffer_view") == 0) {
return iree_trace_replay_parse_inline_hal_buffer_view(
- replay, document, value_node, target_list);
+ replay, document, value_node, out_result);
}
yaml_node_t* type_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, value_node, iree_make_cstring_view("type"), &type_node));
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(document, value_node,
+ IREE_SV("type"), &type_node));
iree_string_view_t type = iree_yaml_node_as_string(type_node);
- if (iree_string_view_equal(type, iree_make_cstring_view("null"))) {
- iree_vm_variant_t null_value = iree_vm_variant_empty();
- return iree_vm_list_push_variant(target_list, &null_value);
- } else if (iree_string_view_equal(type, iree_make_cstring_view("value"))) {
+ if (iree_string_view_equal(type, IREE_SV("null"))) {
+ *out_result = iree_vm_variant_empty();
+ return iree_ok_status();
+ } else if (iree_string_view_equal(type, IREE_SV("value"))) {
return iree_trace_replay_parse_scalar(replay, document, value_node,
- target_list);
- } else if (iree_string_view_equal(type, iree_make_cstring_view("vm.list"))) {
+ out_result);
+ } else if (iree_string_view_equal(type, IREE_SV("vm.list"))) {
return iree_trace_replay_parse_vm_list(replay, document, value_node,
- target_list);
- } else if (iree_string_view_equal(type,
- iree_make_cstring_view("hal.buffer"))) {
+ out_result);
+ } else if (iree_string_view_equal(type, IREE_SV("hal.buffer"))) {
return iree_trace_replay_parse_hal_buffer(replay, document, value_node,
- target_list);
- } else if (iree_string_view_equal(
- type, iree_make_cstring_view("hal.buffer_view"))) {
+ out_result);
+ } else if (iree_string_view_equal(type, IREE_SV("hal.buffer_view"))) {
return iree_trace_replay_parse_hal_buffer_view(replay, document, value_node,
- target_list);
+ out_result);
}
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"unimplemented type parser: '%.*s'", (int)type.size,
@@ -950,80 +1095,59 @@
for (yaml_node_item_t* item = sequence_node->data.sequence.items.start;
item != sequence_node->data.sequence.items.top; ++item) {
yaml_node_t* item_node = yaml_document_get_node(document, *item);
- IREE_RETURN_IF_ERROR(
- iree_trace_replay_parse_item(replay, document, item_node, target_list));
- }
- return iree_ok_status();
-}
-
-static iree_status_t iree_trace_replay_print_item(
- iree_vm_variant_t* value, iree_allocator_t host_allocator);
-
-static iree_status_t iree_trace_replay_print_scalar(iree_vm_variant_t* value) {
- switch (value->type.value_type) {
- case IREE_VM_VALUE_TYPE_I8:
- fprintf(stdout, "i8=%" PRIi8, value->i8);
- break;
- case IREE_VM_VALUE_TYPE_I16:
- fprintf(stdout, "i16=%" PRIi16, value->i16);
- break;
- case IREE_VM_VALUE_TYPE_I32:
- fprintf(stdout, "i32=%" PRIi32, value->i32);
- break;
- case IREE_VM_VALUE_TYPE_I64:
- fprintf(stdout, "i64=%" PRIi64, value->i64);
- break;
- case IREE_VM_VALUE_TYPE_F32:
- fprintf(stdout, "f32=%G", value->f32);
- break;
- case IREE_VM_VALUE_TYPE_F64:
- fprintf(stdout, "f64=%G", value->f64);
- break;
- default:
- fprintf(stdout, "?");
- break;
- }
- return iree_ok_status();
-}
-
-static iree_status_t iree_trace_replay_print_vm_list(
- iree_vm_list_t* list, iree_allocator_t host_allocator) {
- for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
- IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(list, i, &variant),
- "variant %zu not present", i);
IREE_RETURN_IF_ERROR(
- iree_trace_replay_print_item(&variant, host_allocator));
- fprintf(stdout, "\n");
+ iree_trace_replay_parse_item(replay, document, item_node, &variant));
+ iree_status_t status =
+ iree_vm_list_push_variant_move(target_list, &variant);
+ iree_vm_variant_reset(&variant);
+ IREE_RETURN_IF_ERROR(status);
}
return iree_ok_status();
}
-static iree_status_t iree_trace_replay_print_item(
- iree_vm_variant_t* value, iree_allocator_t host_allocator) {
- if (iree_vm_variant_is_value(*value)) {
- IREE_RETURN_IF_ERROR(iree_trace_replay_print_scalar(value));
- } else if (iree_vm_variant_is_ref(*value)) {
- if (iree_hal_buffer_view_isa(value->ref)) {
- iree_hal_buffer_view_t* buffer_view =
- iree_hal_buffer_view_deref(value->ref);
- IREE_RETURN_IF_ERROR(iree_hal_buffer_view_fprint(
- stdout, buffer_view,
- /*max_element_count=*/1024, host_allocator));
- } else if (iree_vm_list_isa(value->ref)) {
- iree_vm_list_t* list = iree_vm_list_deref(value->ref);
- IREE_RETURN_IF_ERROR(
- iree_trace_replay_print_vm_list(list, host_allocator));
- } else {
- // TODO(benvanik): a way for ref types to describe themselves.
- fprintf(stdout, "(no printer)");
- }
- } else {
- fprintf(stdout, "(null)");
+//===----------------------------------------------------------------------===//
+// Output
+//===----------------------------------------------------------------------===//
+
+// Stores |variant| into the list targeted by the item's store macro tag.
+static iree_status_t iree_trace_replay_parse_result_item(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* item_node, iree_vm_variant_t variant) {
+ iree_string_view_t tag = iree_make_cstring_view(item_node->tag);
+ if (iree_string_view_consume_prefix(&tag, IREE_SV("!output."))) {
+ return iree_trace_replay_parse_list_store_macro(
+ replay, document, item_node, tag, replay->outputs, variant);
+ } else if (iree_string_view_consume_prefix(&tag, IREE_SV("!blackboard."))) {
+ return iree_trace_replay_parse_list_store_macro(
+ replay, document, item_node, tag, replay->blackboard, variant);
+ }
+ // NOTE: items without a store macro are currently ignored; we could parse
+ // them and compare against |variant| instead.
+ return iree_ok_status();
+}
+
+// Parses a sequence of result items, pairing each with its |source_list| value.
+static iree_status_t iree_trace_replay_parse_result_item_sequence(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* sequence_node, iree_vm_list_t* source_list) {
+ iree_host_size_t i = 0;
+ for (yaml_node_item_t* item = sequence_node->data.sequence.items.start;
+ item != sequence_node->data.sequence.items.top; ++item, ++i) {
+ iree_vm_variant_t variant = iree_vm_variant_empty();
+ IREE_RETURN_IF_ERROR(
+ iree_vm_list_get_variant_assign(source_list, i, &variant));
+ yaml_node_t* item_node = yaml_document_get_node(document, *item);
+ IREE_RETURN_IF_ERROR(iree_trace_replay_parse_result_item(
+ replay, document, item_node, variant));
}
return iree_ok_status();
}
+//===----------------------------------------------------------------------===//
+// type: call
+//===----------------------------------------------------------------------===//
+
iree_status_t iree_trace_replay_event_call_prepare(
iree_trace_replay_t* replay, yaml_document_t* document,
yaml_node_t* event_node, iree_vm_function_t* out_function,
@@ -1035,8 +1159,7 @@
// Resolve the function ('module.function') within the context.
yaml_node_t* function_node = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_yaml_mapping_find(document, event_node,
- iree_make_cstring_view("function"),
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("function"),
&function_node));
iree_vm_function_t function;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
@@ -1047,9 +1170,8 @@
// Parse function inputs.
yaml_node_t* args_node = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0,
- iree_yaml_mapping_try_find(document, event_node,
- iree_make_cstring_view("args"), &args_node));
+ z0, iree_yaml_mapping_try_find(document, event_node, IREE_SV("args"),
+ &args_node));
iree_vm_list_t* input_list = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_vm_list_create(/*element_type=*/NULL, /*initial_capacity=*/8,
@@ -1066,12 +1188,31 @@
return status;
}
-iree_status_t iree_trace_replay_event_call(iree_trace_replay_t* replay,
- yaml_document_t* document,
- yaml_node_t* event_node,
- iree_vm_list_t** out_output_list) {
+iree_status_t iree_trace_replay_event_call_finish(iree_trace_replay_t* replay,
+ yaml_document_t* document,
+ yaml_node_t* event_node,
+ iree_vm_function_t function,
+ iree_vm_list_t* output_list) {
IREE_TRACE_ZONE_BEGIN(z0);
- if (out_output_list) *out_output_list = NULL;
+
+ yaml_node_t* results_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_try_find(document, event_node, IREE_SV("results"),
+ &results_node));
+ if (results_node) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_trace_replay_parse_result_item_sequence(
+ replay, document, results_node, output_list));
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+iree_status_t iree_trace_replay_event_call(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node, const iree_trace_replay_call_hooks_t* hooks) {
+ IREE_TRACE_ZONE_BEGIN(z0);
iree_vm_function_t function;
iree_vm_list_t* input_list = NULL;
@@ -1079,66 +1220,269 @@
z0, iree_trace_replay_event_call_prepare(replay, document, event_node,
&function, &input_list));
- // Invoke the function to produce outputs.
+ iree_status_t status = iree_ok_status();
+ if (hooks && hooks->before) {
+ status = hooks->before(hooks->user_data, replay, document, event_node,
+ function, input_list);
+ }
+
iree_vm_list_t* output_list = NULL;
- iree_status_t status =
- iree_vm_list_create(/*element_type=*/NULL, /*initial_capacity=*/8,
- replay->host_allocator, &output_list);
if (iree_status_is_ok(status)) {
- status = iree_vm_invoke(replay->context, function,
- IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL,
- input_list, output_list, replay->host_allocator);
+ status = iree_vm_list_create(/*element_type=*/NULL, /*initial_capacity=*/8,
+ replay->host_allocator, &output_list);
+ }
+
+ // Invoke the function to produce outputs.
+ iree_status_t call_status = iree_ok_status();
+ if (iree_status_is_ok(status)) {
+ call_status = iree_vm_invoke(
+ replay->context, function, IREE_VM_INVOCATION_FLAG_NONE,
+ /*policy=*/NULL, input_list, output_list, replay->host_allocator);
}
iree_vm_list_release(input_list);
- if (iree_status_is_ok(status) && out_output_list) {
- *out_output_list = output_list;
- } else {
- iree_vm_list_release(output_list);
+ if (!iree_status_is_ok(call_status)) {
+ if (hooks && hooks->error) {
+ status = hooks->error(hooks->user_data, replay, document, event_node,
+ function, call_status);
+ } else {
+ status = call_status;
+ }
+ } else if (iree_status_is_ok(status) && hooks && hooks->after) {
+ status = hooks->after(hooks->user_data, replay, document, event_node,
+ function, output_list);
}
+
+ if (iree_status_is_ok(status)) {
+ status = iree_trace_replay_event_call_finish(replay, document, event_node,
+ function, output_list);
+ }
+ iree_vm_list_release(output_list);
+
IREE_TRACE_ZONE_END(z0);
return status;
}
-static iree_status_t iree_trace_replay_event_call_stdout(
+//===----------------------------------------------------------------------===//
+// Blackboard management
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_trace_replay_event_blackboard_clear(
iree_trace_replay_t* replay, yaml_document_t* document,
yaml_node_t* event_node) {
- yaml_node_t* function_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, event_node, iree_make_cstring_view("function"),
- &function_node));
- iree_string_view_t function_name = iree_yaml_node_as_string(function_node);
- fprintf(stdout, "--- CALL[%.*s] ---\n", (int)function_name.size,
- function_name.data);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_vm_list_clear(replay->blackboard);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
- // Prepare to call the function.
- iree_vm_function_t function;
- iree_vm_list_t* input_list = NULL;
- IREE_RETURN_IF_ERROR(iree_trace_replay_event_call_prepare(
- replay, document, event_node, &function, &input_list));
+static iree_status_t iree_trace_replay_event_blackboard_assign(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node) {
+ IREE_TRACE_ZONE_BEGIN(z0);
- // Invoke the function to produce outputs.
- iree_vm_list_t* output_list = NULL;
+ yaml_node_t* from_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("from"),
+ &from_node));
+ yaml_node_t* to_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_yaml_mapping_find(document, event_node, IREE_SV("to"), &to_node));
+
+ iree_vm_list_t* list = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_vm_list_create(/*element_type=*/NULL, 8u, replay->host_allocator,
+ &list));
+
iree_status_t status =
- iree_vm_list_create(/*element_type=*/NULL, /*initial_capacity=*/8,
- replay->host_allocator, &output_list);
+ iree_trace_replay_parse_item_sequence(replay, document, from_node, list);
if (iree_status_is_ok(status)) {
- status = iree_vm_invoke(replay->context, function,
- IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL,
- input_list, output_list, replay->host_allocator);
+ status = iree_trace_replay_parse_result_item_sequence(replay, document,
+ to_node, list);
}
- iree_vm_list_release(input_list);
- // Print the outputs.
- if (iree_status_is_ok(status)) {
- status =
- iree_trace_replay_print_vm_list(output_list, replay->host_allocator);
- }
- iree_vm_list_release(output_list);
+ iree_vm_list_release(list);
+ IREE_TRACE_ZONE_END(z0);
return status;
}
+//===----------------------------------------------------------------------===//
+// Numpy ndarray management
+//===----------------------------------------------------------------------===//
+
+// Loads one or more ndarrays from a .npy file.
+//
+// Example:
+// ```yaml
+// type: numpy_load
+// path: three_ndarrays.npy
+// arrays:
+// - !blackboard.set 2
+// - !blackboard.set 3
+// - !output.set 4
+// ```
+//
+// NOTE: this currently reads things into new buffers; we could try mapping from
+// disk and other fancy things where possible.
+static iree_status_t iree_trace_replay_event_numpy_load(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node) {
+ if (!replay->device) {
+ return iree_make_status(
+ IREE_STATUS_FAILED_PRECONDITION,
+ "HAL module must be loaded before loading numpy arrays");
+ }
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ yaml_node_t* path_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("path"),
+ &path_node));
+ iree_string_view_t path_str = iree_yaml_node_as_string(path_node);
+
+ yaml_node_t* arrays_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("arrays"),
+ &arrays_node));
+
+ char* full_path = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_file_path_join(replay->root_path, path_str,
+ replay->host_allocator, &full_path));
+ FILE* file = fopen(full_path, "rb");
+ iree_allocator_free(replay->host_allocator, full_path);
+ if (!file) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "failed to open file `%.*s` for read",
+ (int)path_str.size, path_str.data);
+ }
+
+ uint64_t file_length = 0;
+ iree_status_t status = iree_file_query_length(file, &file_length);
+
+ iree_hal_buffer_params_t buffer_params = {0};
+ buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
+ buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ;
+ buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+ iree_hal_allocator_t* device_allocator =
+ iree_hal_device_allocator(replay->device);
+
+ if (iree_status_is_ok(status)) {
+ for (yaml_node_item_t* item = arrays_node->data.sequence.items.start;
+ item != arrays_node->data.sequence.items.top; ++item) {
+ // Parse the next array in the file. Note that we may be at the end!
+ if (iree_file_is_at(file, file_length)) {
+ status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "file ended before all arrays were decoded");
+ break;
+ }
+ iree_hal_buffer_view_t* buffer_view = NULL;
+ status = iree_numpy_npy_load_ndarray(
+ file, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, buffer_params,
+ device_allocator, &buffer_view);
+ if (!iree_status_is_ok(status)) break;
+
+ // Route the loaded value to its destination.
+ iree_vm_variant_t variant = iree_vm_make_variant_ref_assign(
+ iree_hal_buffer_view_move_ref(buffer_view));
+ yaml_node_t* item_node = yaml_document_get_node(document, *item);
+ status = iree_trace_replay_parse_result_item(replay, document, item_node,
+ variant);
+ iree_vm_variant_reset(&variant);
+ if (!iree_status_is_ok(status)) break;
+ }
+ }
+
+ fclose(file);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Saves one or more ndarrays to a .npy file.
+//
+// Example:
+// ```yaml
+// type: numpy_save
+// path: three_ndarrays.npy
+// append: true
+// arrays:
+// - !blackboard.get 2
+// - !blackboard.take 3
+// - !input.get 4
+// - !hal.buffer_view 4xf32=0,1,2,3
+// ```
+static iree_status_t iree_trace_replay_event_numpy_save(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ yaml_node_t* path_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("path"),
+ &path_node));
+ iree_string_view_t path_str = iree_yaml_node_as_string(path_node);
+
+ yaml_node_t* append_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_try_find(document, event_node, IREE_SV("append"),
+ &append_node));
+
+ yaml_node_t* arrays_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("arrays"),
+ &arrays_node));
+
+ char* full_path = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_file_path_join(replay->root_path, path_str,
+ replay->host_allocator, &full_path));
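+ // Append only when `append` is present and set to `true`.
+ const bool append =
+ append_node && iree_yaml_string_equal(append_node, IREE_SV("true"));
+ const char* mode = append ? "ab" : "wb";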
+ FILE* file = fopen(full_path, mode);
+ iree_allocator_free(replay->host_allocator, full_path);
+ if (!file) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "failed to open file `%.*s` for write",
+ (int)path_str.size, path_str.data);
+ }
+
+ iree_status_t status = iree_ok_status();
+ for (yaml_node_item_t* item = arrays_node->data.sequence.items.start;
+ item != arrays_node->data.sequence.items.top; ++item) {
+ yaml_node_t* item_node = yaml_document_get_node(document, *item);
+ iree_vm_variant_t variant = iree_vm_variant_empty();
+ status =
+ iree_trace_replay_parse_item(replay, document, item_node, &variant);
+ if (!iree_status_is_ok(status)) break;
+ if (!iree_vm_variant_is_ref(variant) ||
+ !iree_hal_buffer_view_isa(variant.ref)) {
+ iree_vm_variant_reset(&variant);
+ status =
+ iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "only buffer views can be saved to numpy files");
+ break;
+ }
+ iree_hal_buffer_view_t* buffer_view =
+ iree_hal_buffer_view_deref(variant.ref);
+ status =
+ iree_numpy_npy_save_ndarray(file, IREE_NUMPY_NPY_SAVE_OPTION_DEFAULT,
+ buffer_view, replay->host_allocator);
+ iree_vm_variant_reset(&variant);
+ if (!iree_status_is_ok(status)) break;
+ }
+
+ fclose(file);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Event dispatch
+//===----------------------------------------------------------------------===//
+
iree_status_t iree_trace_replay_event(iree_trace_replay_t* replay,
yaml_document_t* document,
yaml_node_t* event_node) {
@@ -1148,17 +1492,25 @@
event_node->start_mark.line);
}
yaml_node_t* type_node = NULL;
- IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
- document, event_node, iree_make_cstring_view("type"), &type_node));
- if (iree_yaml_string_equal(type_node,
- iree_make_cstring_view("context_load"))) {
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(document, event_node,
+ IREE_SV("type"), &type_node));
+ if (iree_yaml_string_equal(type_node, IREE_SV("context_load"))) {
return iree_trace_replay_event_context_load(replay, document, event_node);
- } else if (iree_yaml_string_equal(type_node,
- iree_make_cstring_view("module_load"))) {
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("module_load"))) {
return iree_trace_replay_event_module_load(replay, document, event_node);
- } else if (iree_yaml_string_equal(type_node,
- iree_make_cstring_view("call"))) {
- return iree_trace_replay_event_call_stdout(replay, document, event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("blackboard_clear"))) {
+ return iree_trace_replay_event_blackboard_clear(replay, document,
+ event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("assign"))) {
+ return iree_trace_replay_event_blackboard_assign(replay, document,
+ event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("numpy_load"))) {
+ return iree_trace_replay_event_numpy_load(replay, document, event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("numpy_save"))) {
+ return iree_trace_replay_event_numpy_save(replay, document, event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("call"))) {
+ return iree_trace_replay_event_call(replay, document, event_node,
+ &replay->call_hooks);
}
return iree_make_status(
IREE_STATUS_UNIMPLEMENTED, "(%zu): unhandled type '%.*s'",
diff --git a/runtime/src/iree/tooling/trace_replay.h b/runtime/src/iree/tooling/trace_replay.h
index db2de14..ddc240c 100644
--- a/runtime/src/iree/tooling/trace_replay.h
+++ b/runtime/src/iree/tooling/trace_replay.h
@@ -16,12 +16,37 @@
extern "C" {
#endif // __cplusplus
+typedef struct iree_trace_replay_t iree_trace_replay_t;
+
enum iree_trace_replay_shutdown_flag_bits_e {
IREE_TRACE_REPLAY_SHUTDOWN_QUIET = 0u,
IREE_TRACE_REPLAY_SHUTDOWN_PRINT_STATISTICS = 1 << 0u,
};
typedef uint32_t iree_trace_replay_shutdown_flags_t;
+// Optional set of callbacks around a replay event function call.
+// Functions not required by the caller may be omitted.
+typedef struct iree_trace_replay_call_hooks_t {
+ // User context passed to each callback.
+ void* user_data;
+ // Issued before the call begins with the call inputs.
+ iree_status_t (*before)(void* user_data, iree_trace_replay_t* replay,
+ yaml_document_t* document, yaml_node_t* event_node,
+ iree_vm_function_t function,
+ iree_vm_list_t* input_list);
+ // Issued after the call completes successfully with the call outputs.
+ iree_status_t (*after)(void* user_data, iree_trace_replay_t* replay,
+ yaml_document_t* document, yaml_node_t* event_node,
+ iree_vm_function_t function,
+ iree_vm_list_t* output_list);
+ // Issued only when the call itself fails, not when the replay operation fails.
+ // |status| is as returned from the call and ownership is transferred to the
+ // hook.
+ iree_status_t (*error)(void* user_data, iree_trace_replay_t* replay,
+ yaml_document_t* document, yaml_node_t* event_node,
+ iree_vm_function_t function, iree_status_t status);
+} iree_trace_replay_call_hooks_t;
+
typedef struct iree_trace_replay_t {
iree_allocator_t host_allocator;
iree_string_view_t root_path;
@@ -33,8 +58,21 @@
iree_host_size_t device_uri_count;
const iree_string_view_t* device_uris;
+ // Context used within the replay, modules registered on-demand.
iree_vm_context_t* context;
+
+ // Active HAL device if any. Will be initialized on the first HAL module load.
iree_hal_device_t* device;
+
+ // Optional inputs available via `!input.get`/`!input.take`.
+ iree_vm_list_t* inputs;
+ // Optional outputs populated via `!output.set`/`!output.push`.
+ iree_vm_list_t* outputs;
+ // Blackboard used to track state within the trace.
+ iree_vm_list_t* blackboard;
+
+ // Optional call hooks allowing reflection of calls and their I/O.
+ iree_trace_replay_call_hooks_t call_hooks;
} iree_trace_replay_t;
// Initializes a trace replay context.
@@ -82,12 +120,11 @@
iree_vm_list_t** out_input_list);
// Replays a `call` event against the replay context.
-// Optionally |out_output_list| can be populated with a caller-owned set of
-// outputs from the call.
-iree_status_t iree_trace_replay_event_call(iree_trace_replay_t* replay,
- yaml_document_t* document,
- yaml_node_t* event_node,
- iree_vm_list_t** out_output_list);
+// Optionally |hooks| may be specified to inspect the inputs and outputs of the
+// call operation.
+iree_status_t iree_trace_replay_event_call(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node, const iree_trace_replay_call_hooks_t* hooks);
#ifdef __cplusplus
} // extern "C"
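
The optional call hooks declared above can be pictured with a much smaller model. The following is a minimal sketch under stated assumptions, not the implementation in `trace_replay.c`: statuses are plain `int` values (0 means OK), the call I/O is a single integer, and every name (`call_hooks_t`, `call_with_hooks`, `double_it`, `count_call`) is hypothetical. The order in which hooks fire is also an assumption drawn from the header comments.

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical, simplified stand-in for the hook struct: statuses are ints
// (0 == ok) and the call inputs/outputs are reduced to a single integer.
typedef struct call_hooks_t {
  void* user_data;
  int (*before)(void* user_data, int input);
  int (*after)(void* user_data, int output);
  int (*error)(void* user_data, int status);
} call_hooks_t;

// Invokes |fn| with |input| and issues whichever hooks are set.
// NULL hooks (or a NULL |hooks| pointer) are skipped.
static int call_with_hooks(const call_hooks_t* hooks, int (*fn)(int, int*),
                           int input, int* out_output) {
  assert(fn && out_output);
  if (hooks && hooks->before) {
    int before_status = hooks->before(hooks->user_data, input);
    if (before_status != 0) return before_status;
  }
  int status = fn(input, out_output);
  if (status != 0) {
    // The error hook takes over the failing status and decides what to return.
    if (hooks && hooks->error) return hooks->error(hooks->user_data, status);
    return status;
  }
  if (hooks && hooks->after) return hooks->after(hooks->user_data, *out_output);
  return 0;
}

// Example callee: doubles non-negative inputs and fails on negative ones.
static int double_it(int x, int* out) {
  if (x < 0) return 1;
  *out = x * 2;
  return 0;
}

// Example hook: counts how many times it was issued via |user_data|.
static int count_call(void* user_data, int value) {
  (void)value;
  ++*(int*)user_data;
  return 0;
}
```

Leaving a callback NULL is how a caller opts out of that stage, which matches the header comment that functions not required by the caller may be omitted.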
diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c
index 3052b9b..c3b6fa3 100644
--- a/runtime/src/iree/tooling/vm_util.c
+++ b/runtime/src/iree/tooling/vm_util.c
@@ -159,6 +159,31 @@
z0,
iree_vm_list_create(
/*element_type=*/NULL, input_strings_count, host_allocator, &list));
+
+ iree_status_t status = iree_tooling_parse_into_variant_list(
+ device_allocator, input_strings, input_strings_count, host_allocator,
+ list);
+ if (iree_status_is_ok(status)) {
+ *out_list = list;
+ } else {
+ iree_vm_list_release(list);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+iree_status_t iree_tooling_parse_into_variant_list(
+ iree_hal_allocator_t* device_allocator,
+ const iree_string_view_t* input_strings,
+ iree_host_size_t input_strings_count, iree_allocator_t host_allocator,
+ iree_vm_list_t* list) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Reset the list and prepare for pushing items.
+ iree_vm_list_clear(list);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_vm_list_reserve(list, input_strings_count));
+
iree_status_t status = iree_ok_status();
for (size_t i = 0; i < input_strings_count; ++i) {
if (!iree_status_is_ok(status)) break;
@@ -250,11 +275,7 @@
if (!iree_status_is_ok(status)) break;
}
}
- if (iree_status_is_ok(status)) {
- *out_list = list;
- } else {
- iree_vm_list_release(list);
- }
+
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -334,7 +355,12 @@
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(builder, type_name));
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV("\n")));
- if (iree_hal_buffer_view_isa(variant.ref)) {
+ if (iree_vm_list_isa(variant.ref)) {
+ iree_vm_list_t* child_list = iree_vm_list_deref(variant.ref);
+ IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines(
+ IREE_SV("child_list"), child_list, max_element_count, builder));
+ return iree_string_builder_append_string(builder, IREE_SV("\n"));
+ } else if (iree_hal_buffer_view_isa(variant.ref)) {
iree_hal_buffer_view_t* buffer_view =
iree_hal_buffer_view_deref(variant.ref);
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_append_to_builder(
@@ -370,15 +396,16 @@
}
iree_status_t iree_tooling_append_variant_list_lines(
- iree_vm_list_t* list, iree_host_size_t max_element_count,
- iree_string_builder_t* builder) {
+ iree_string_view_t list_name, iree_vm_list_t* list,
+ iree_host_size_t max_element_count, iree_string_builder_t* builder) {
IREE_TRACE_ZONE_BEGIN(z0);
for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_vm_list_get_variant(list, i, &variant),
+ z0, iree_vm_list_get_variant_assign(list, i, &variant),
"variant %zu not present", i);
- iree_string_builder_append_format(builder, "result[%zu]: ", i);
+ iree_string_builder_append_format(
+ builder, "%.*s[%" PRIhsz "]: ", (int)list_name.size, list_name.data, i);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_variant_format(variant, max_element_count, builder));
}
@@ -387,11 +414,12 @@
}
iree_status_t iree_tooling_variant_list_fprint(
- iree_vm_list_t* list, iree_host_size_t max_element_count, FILE* file) {
+ iree_string_view_t list_name, iree_vm_list_t* list,
+ iree_host_size_t max_element_count, FILE* file) {
iree_string_builder_t builder;
iree_string_builder_initialize(iree_allocator_system(), &builder);
- iree_status_t status =
- iree_tooling_append_variant_list_lines(list, max_element_count, &builder);
+ iree_status_t status = iree_tooling_append_variant_list_lines(
+ list_name, list, max_element_count, &builder);
if (iree_status_is_ok(status)) {
size_t written = fwrite(iree_string_builder_buffer(&builder), 1,
iree_string_builder_size(&builder), file);
@@ -482,7 +510,7 @@
for (iree_host_size_t i = 0; i < output_strings_count; ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_vm_list_get_variant(list, i, &variant));
+ z0, iree_vm_list_get_variant_assign(list, i, &variant));
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_tooling_output_variant(variant, output_strings[i],
max_element_count, file));
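
The refactoring in this file splits list creation from list population: the creating function now delegates to `iree_tooling_parse_into_variant_list` and releases the list only if population fails. A minimal sketch of that pattern, assuming a toy integer list and hypothetical names (`toy_list_t`, `toy_parse_into_list`, `toy_parse_to_list`) rather than the IREE types:

```c
#include <assert.h>
#include <stdlib.h>

// Hypothetical toy list of ints standing in for a VM variant list.
typedef struct {
  int* items;
  size_t count;
  size_t capacity;
} toy_list_t;

// Clears |list| and refills it by parsing each string as a base-10 integer.
// Returns 0 on success, 1 on allocation failure, 2 on malformed input.
static int toy_parse_into_list(const char* const* strings, size_t count,
                               toy_list_t* list) {
  assert(list);
  list->count = 0;  // reset before pushing, like the clear+reserve above
  if (count > list->capacity) {
    int* items = realloc(list->items, count * sizeof(int));
    if (!items) return 1;
    list->items = items;
    list->capacity = count;
  }
  for (size_t i = 0; i < count; ++i) {
    char* end = NULL;
    long value = strtol(strings[i], &end, 10);
    if (end == strings[i] || *end != '\0') return 2;
    list->items[list->count++] = (int)value;
  }
  return 0;
}

// Allocates a new list and populates it; on failure nothing is returned and
// the partially filled list is freed.
static int toy_parse_to_list(const char* const* strings, size_t count,
                             toy_list_t** out_list) {
  assert(out_list);
  *out_list = NULL;
  toy_list_t* list = calloc(1, sizeof(*list));
  if (!list) return 1;
  int status = toy_parse_into_list(strings, count, list);
  if (status == 0) {
    *out_list = list;
  } else {
    free(list->items);
    free(list);
  }
  return status;
}
```

Keeping ownership decisions in the wrapper lets callers that already hold a list (such as a long-lived input list) reuse the populate step without reallocating.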
diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h
index b72d4ed..0375099 100644
--- a/runtime/src/iree/tooling/vm_util.h
+++ b/runtime/src/iree/tooling/vm_util.h
@@ -32,6 +32,19 @@
iree_host_size_t input_strings_count, iree_allocator_t host_allocator,
iree_vm_list_t** out_list);
+// Parses |input_strings| into a variant list of VM scalars and buffers.
+// Scalars should be in the format:
+// type=value
+// Buffers should be in the IREE standard shaped buffer format:
+// [shape]xtype=[value]
+// described in iree/hal/api.h
+// Uses |device_allocator| to allocate the buffers.
+iree_status_t iree_tooling_parse_into_variant_list(
+ iree_hal_allocator_t* device_allocator,
+ const iree_string_view_t* input_strings,
+ iree_host_size_t input_strings_count, iree_allocator_t host_allocator,
+ iree_vm_list_t* list);
+
// Appends fences to |list| if the invocation model of |function| requires them.
// If no |wait_fence| is provided then the invocation will begin immediately.
// The caller must wait on the returned |out_signal_fence| before accessing the
@@ -42,6 +55,8 @@
iree_hal_fence_t** out_signal_fence);
// Appends a variant list of VM scalars and buffers to |builder|.
+// |list_name| will be printed alongside each element ordinal.
+//
// Prints scalars in the format:
// value
// Prints buffers in the IREE standard shaped buffer format:
@@ -49,12 +64,14 @@
// described in
// https://github.com/openxla/iree/tree/main/iree/hal/api.h
iree_status_t iree_tooling_append_variant_list_lines(
- iree_vm_list_t* list, iree_host_size_t max_element_count,
- iree_string_builder_t* builder);
+ iree_string_view_t list_name, iree_vm_list_t* list,
+ iree_host_size_t max_element_count, iree_string_builder_t* builder);
-// Prints a variant list to a file.
+// Prints a variant list to |file|.
+// |list_name| will be printed alongside each element ordinal.
iree_status_t iree_tooling_variant_list_fprint(
- iree_vm_list_t* list, iree_host_size_t max_element_count, FILE* file);
+ iree_string_view_t list_name, iree_vm_list_t* list,
+ iree_host_size_t max_element_count, FILE* file);
// Prints a variant |list| to targets based on the provided |output_strings|.
//
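
The new `list_name` parameter replaces the hard-coded `result` prefix on each printed element. As a rough illustration of the `"%.*s[...]: "` formatting used for a length-delimited name, here is a sketch with a hypothetical helper (`format_list_prefix`) built on standard `snprintf`; it uses `%zu` in place of the IREE `PRIhsz` macro.

```c
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

// Hypothetical helper: writes "<name>[<i>]: " into |buffer|. |name| does not
// need to be NUL-terminated at |name_length|; "%.*s" limits how much is read.
// Returns the snprintf result (characters needed, excluding the terminator).
static int format_list_prefix(char* buffer, size_t capacity, const char* name,
                              size_t name_length, size_t i) {
  assert(buffer && name);
  return snprintf(buffer, capacity, "%.*s[%zu]: ", (int)name_length, name, i);
}
```

Passing the length separately is what allows a string view into a larger buffer to be printed without copying it first.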
diff --git a/runtime/src/iree/tooling/vm_util_test.cc b/runtime/src/iree/tooling/vm_util_test.cc
index 5d33a68..a42cfa7 100644
--- a/runtime/src/iree/tooling/vm_util_test.cc
+++ b/runtime/src/iree/tooling/vm_util_test.cc
@@ -37,7 +37,7 @@
iree_string_builder_t builder;
iree_string_builder_initialize(iree_allocator_system(), &builder);
IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines(
- variant_list, /*max_element_count=*/1024, &builder));
+ IREE_SV("result"), variant_list, /*max_element_count=*/1024, &builder));
out_string->assign(iree_string_builder_buffer(&builder),
iree_string_builder_size(&builder));
iree_string_builder_deinitialize(&builder);
diff --git a/runtime/src/iree/vm/BUILD b/runtime/src/iree/vm/BUILD
index 1fac929..3969e8d 100644
--- a/runtime/src/iree/vm/BUILD
+++ b/runtime/src/iree/vm/BUILD
@@ -75,6 +75,7 @@
"stack.h",
"type_def.h",
"value.h",
+ "variant.h",
],
deps = [
"//runtime/src/iree/base",
diff --git a/runtime/src/iree/vm/CMakeLists.txt b/runtime/src/iree/vm/CMakeLists.txt
index b0447c3..48013b5 100644
--- a/runtime/src/iree/vm/CMakeLists.txt
+++ b/runtime/src/iree/vm/CMakeLists.txt
@@ -51,6 +51,7 @@
"stack.h"
"type_def.h"
"value.h"
+ "variant.h"
SRCS
"buffer.c"
"context.c"
diff --git a/runtime/src/iree/vm/api.h b/runtime/src/iree/vm/api.h
index 87e5750..45a0d39 100644
--- a/runtime/src/iree/vm/api.h
+++ b/runtime/src/iree/vm/api.h
@@ -20,5 +20,6 @@
#include "iree/vm/stack.h" // IWYU pragma: export
#include "iree/vm/type_def.h" // IWYU pragma: export
#include "iree/vm/value.h" // IWYU pragma: export
+#include "iree/vm/variant.h" // IWYU pragma: export
#endif // IREE_VM_API_H_
diff --git a/runtime/src/iree/vm/list.c b/runtime/src/iree/vm/list.c
index 0dafa55..cc4b703 100644
--- a/runtime/src/iree/vm/list.c
+++ b/runtime/src/iree/vm/list.c
@@ -925,34 +925,60 @@
return iree_ok_status();
}
-IREE_API_EXPORT iree_status_t
-iree_vm_list_get_variant(const iree_vm_list_t* list, iree_host_size_t i,
- iree_vm_variant_t* out_value) {
+typedef enum {
+ IREE_VM_LIST_REF_ASSIGN = 0,
+ IREE_VM_LIST_REF_RETAIN,
+ IREE_VM_LIST_REF_MOVE,
+} iree_vm_list_ref_mode_t;
+
+static void iree_vm_list_ref_op(iree_vm_list_ref_mode_t mode,
+ iree_vm_ref_t* ref, iree_vm_ref_t* out_ref) {
+ switch (mode) {
+ case IREE_VM_LIST_REF_ASSIGN:
+ iree_vm_ref_assign(ref, out_ref);
+ break;
+ case IREE_VM_LIST_REF_RETAIN:
+ iree_vm_ref_retain(ref, out_ref);
+ break;
+ case IREE_VM_LIST_REF_MOVE:
+ iree_vm_ref_move(ref, out_ref);
+ break;
+ }
+}
+
+static iree_status_t iree_vm_list_get_variant(const iree_vm_list_t* list,
+ iree_host_size_t i,
+ iree_vm_list_ref_mode_t ref_mode,
+ iree_vm_variant_t* out_variant) {
+ IREE_ASSERT_ARGUMENT(list);
+ IREE_ASSERT_ARGUMENT(out_variant);
if (i >= list->count) {
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"index %zu out of bounds (%zu)", i, list->count);
}
+ iree_vm_variant_reset(out_variant);
uintptr_t element_ptr = (uintptr_t)list->storage + i * list->element_size;
switch (list->storage_mode) {
case IREE_VM_LIST_STORAGE_MODE_VALUE: {
- out_value->type = list->element_type;
- memcpy(out_value->value_storage, (void*)element_ptr, list->element_size);
+ out_variant->type = list->element_type;
+ memcpy(out_variant->value_storage, (void*)element_ptr,
+ list->element_size);
break;
}
case IREE_VM_LIST_STORAGE_MODE_REF: {
iree_vm_ref_t* element_ref = (iree_vm_ref_t*)element_ptr;
- out_value->type.ref_type = element_ref->type;
- out_value->type.value_type = IREE_VM_VALUE_TYPE_NONE;
- iree_vm_ref_assign(element_ref, &out_value->ref);
+ out_variant->type.ref_type = element_ref->type;
+ out_variant->type.value_type = IREE_VM_VALUE_TYPE_NONE;
+ iree_vm_list_ref_op(ref_mode, element_ref, &out_variant->ref);
break;
}
case IREE_VM_LIST_STORAGE_MODE_VARIANT: {
iree_vm_variant_t* variant = (iree_vm_variant_t*)element_ptr;
- out_value->type = variant->type;
+ out_variant->type = variant->type;
if (iree_vm_type_def_is_ref(&variant->type)) {
- iree_vm_ref_assign(&variant->ref, &out_value->ref);
+ iree_vm_list_ref_op(ref_mode, &variant->ref, &out_variant->ref);
} else {
- memcpy(out_value->value_storage, variant->value_storage,
+ memcpy(out_variant->value_storage, variant->value_storage,
sizeof(variant->value_storage));
}
break;
@@ -963,17 +989,80 @@
return iree_ok_status();
}
-IREE_API_EXPORT iree_status_t iree_vm_list_set_variant(
- iree_vm_list_t* list, iree_host_size_t i, const iree_vm_variant_t* value) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "iree_vm_list_set_variant unimplemented");
+IREE_API_EXPORT iree_status_t
+iree_vm_list_get_variant_assign(const iree_vm_list_t* list, iree_host_size_t i,
+ iree_vm_variant_t* out_variant) {
+ return iree_vm_list_get_variant(list, i, IREE_VM_LIST_REF_ASSIGN,
+ out_variant);
}
-IREE_API_EXPORT iree_status_t iree_vm_list_push_variant(
- iree_vm_list_t* list, const iree_vm_variant_t* value) {
+IREE_API_EXPORT iree_status_t
+iree_vm_list_get_variant_retain(const iree_vm_list_t* list, iree_host_size_t i,
+ iree_vm_variant_t* out_variant) {
+ return iree_vm_list_get_variant(list, i, IREE_VM_LIST_REF_RETAIN,
+ out_variant);
+}
+
+IREE_API_EXPORT iree_status_t
+iree_vm_list_get_variant_move(const iree_vm_list_t* list, iree_host_size_t i,
+ iree_vm_variant_t* out_variant) {
+ return iree_vm_list_get_variant(list, i, IREE_VM_LIST_REF_MOVE, out_variant);
+}
+
+static iree_status_t iree_vm_list_set_variant(iree_vm_list_t* list,
+ iree_host_size_t i, bool is_move,
+ iree_vm_variant_t* variant) {
+ if (iree_vm_type_def_is_variant(&variant->type)) {
+ iree_vm_value_t value = iree_vm_variant_value(*variant);
+ return iree_vm_list_set_value(list, i, &value);
+ } else if (iree_vm_type_def_is_value(&variant->type)) {
+ iree_vm_value_t value = {
+ .type = variant->type.value_type,
+ };
+ memcpy(value.value_storage, variant->value_storage,
+ sizeof(value.value_storage));
+ return iree_vm_list_set_value(list, i, &value);
+ } else if (iree_vm_type_def_is_ref(&variant->type)) {
+ iree_status_t status =
+ iree_vm_list_set_ref(list, i, is_move, &variant->ref);
+ if (iree_status_is_ok(status) && is_move) {
+ variant->type.ref_type = IREE_VM_REF_TYPE_NULL;
+ }
+ return status;
+ } else {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "unhandled variant value type");
+ }
+}
+
+IREE_API_EXPORT iree_status_t
+iree_vm_list_set_variant_retain(iree_vm_list_t* list, iree_host_size_t i,
+ const iree_vm_variant_t* variant) {
+ return iree_vm_list_set_variant(list, i, /*is_move=*/false,
+ (iree_vm_variant_t*)variant);
+}
+
+IREE_API_EXPORT iree_status_t iree_vm_list_set_variant_move(
+ iree_vm_list_t* list, iree_host_size_t i, iree_vm_variant_t* variant) {
+ return iree_vm_list_set_variant(list, i, /*is_move=*/true, variant);
+}
+
+static iree_status_t iree_vm_list_push_variant(
+ iree_vm_list_t* list, bool is_move, const iree_vm_variant_t* variant) {
iree_host_size_t i = iree_vm_list_size(list);
IREE_RETURN_IF_ERROR(iree_vm_list_resize(list, i + 1));
- return iree_vm_list_set_variant(list, i, value);
+ return iree_vm_list_set_variant(list, i, is_move,
+ (iree_vm_variant_t*)variant);
+}
+
+IREE_API_EXPORT iree_status_t iree_vm_list_push_variant_retain(
+ iree_vm_list_t* list, const iree_vm_variant_t* variant) {
+ return iree_vm_list_push_variant(list, /*is_move=*/false, variant);
+}
+
+IREE_API_EXPORT iree_status_t iree_vm_list_push_variant_move(
+ iree_vm_list_t* list, iree_vm_variant_t* variant) {
+ return iree_vm_list_push_variant(list, /*is_move=*/true, variant);
}
iree_status_t iree_vm_list_register_types(iree_vm_instance_t* instance) {
diff --git a/runtime/src/iree/vm/list.h b/runtime/src/iree/vm/list.h
index b7d2ba3..3b9b176 100644
--- a/runtime/src/iree/vm/list.h
+++ b/runtime/src/iree/vm/list.h
@@ -13,6 +13,7 @@
#include "iree/vm/ref.h"
#include "iree/vm/type_def.h"
#include "iree/vm/value.h"
+#include "iree/vm/variant.h"
#ifdef __cplusplus
extern "C" {
@@ -202,22 +203,48 @@
// a ref it will *not* be retained and the caller must retain it to extend its
// lifetime.
IREE_API_EXPORT iree_status_t
-iree_vm_list_get_variant(const iree_vm_list_t* list, iree_host_size_t i,
- iree_vm_variant_t* out_value);
+iree_vm_list_get_variant_assign(const iree_vm_list_t* list, iree_host_size_t i,
+ iree_vm_variant_t* out_variant);
+
+// Returns the value of the element at the given index.
+// If the variant is a ref then it will be retained.
+IREE_API_EXPORT iree_status_t
+iree_vm_list_get_variant_retain(const iree_vm_list_t* list, iree_host_size_t i,
+ iree_vm_variant_t* out_variant);
+
+// Returns the value of the element at the given index.
+// If the variant is a ref then it will be moved.
+IREE_API_EXPORT iree_status_t
+iree_vm_list_get_variant_move(const iree_vm_list_t* list, iree_host_size_t i,
+ iree_vm_variant_t* out_variant);
// Sets the value of the element at the given index. If the specified |value|
// type differs from the list storage type the value will be converted using the
// value type semantics (such as sign/zero extend, etc). If the variant is a ref
// then it will be retained.
-IREE_API_EXPORT iree_status_t iree_vm_list_set_variant(
- iree_vm_list_t* list, iree_host_size_t i, const iree_vm_variant_t* value);
+IREE_API_EXPORT iree_status_t iree_vm_list_set_variant_retain(
+ iree_vm_list_t* list, iree_host_size_t i, const iree_vm_variant_t* variant);
+
+// Sets the value of the element at the given index. If the specified
+// |variant| type differs from the list storage type the value will be
+// converted using the value type semantics (such as sign/zero extend, etc). If
+// the variant is a ref then it will be moved.
+IREE_API_EXPORT iree_status_t iree_vm_list_set_variant_move(
+ iree_vm_list_t* list, iree_host_size_t i, iree_vm_variant_t* variant);
// Pushes the value of the element to the end of the list. If the specified
-// |value| type differs from the list storage type the value will be converted
+// |variant| type differs from the list storage type the value will be converted
// using the value type semantics (such as sign/zero extend, etc). If the
// variant is a ref then it will be retained.
-IREE_API_EXPORT iree_status_t
-iree_vm_list_push_variant(iree_vm_list_t* list, const iree_vm_variant_t* value);
+IREE_API_EXPORT iree_status_t iree_vm_list_push_variant_retain(
+ iree_vm_list_t* list, const iree_vm_variant_t* variant);
+
+// Pushes the value of the element to the end of the list. If the specified
+// |variant| type differs from the list storage type the value will be converted
+// using the value type semantics (such as sign/zero extend, etc). If the
+// variant is a ref then it will be moved.
+IREE_API_EXPORT iree_status_t iree_vm_list_push_variant_move(
+ iree_vm_list_t* list, iree_vm_variant_t* variant);
#ifdef __cplusplus
} // extern "C"
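
The `_assign`, `_retain`, and `_move` suffixes above differ only in what happens to reference ownership. A minimal sketch of the three modes, assuming a toy reference-counted object and hypothetical names (`toy_object_t`, `toy_ref_t`, `toy_ref_op`); unlike a real implementation it does not release whatever `out` held before:

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical toy ref-counted object and ref slot; not the IREE types.
typedef struct {
  int ref_count;
} toy_object_t;

typedef struct {
  toy_object_t* ptr;
} toy_ref_t;

typedef enum {
  TOY_REF_ASSIGN = 0,
  TOY_REF_RETAIN,
  TOY_REF_MOVE,
} toy_ref_mode_t;

// Copies |src| into |out| using the requested ownership |mode|.
static void toy_ref_op(toy_ref_mode_t mode, toy_ref_t* src, toy_ref_t* out) {
  assert(src && out);
  out->ptr = src->ptr;
  switch (mode) {
    case TOY_REF_ASSIGN:
      // Borrow: no count change; |out| must not outlive the source reference.
      break;
    case TOY_REF_RETAIN:
      // Share: |out| now owns its own reference.
      if (out->ptr) ++out->ptr->ref_count;
      break;
    case TOY_REF_MOVE:
      // Transfer: the source slot gives up its reference.
      src->ptr = NULL;
      break;
  }
}
```

In this model assign is the cheapest but leaves lifetime management to the caller, retain is the safe default, and move avoids a count round-trip when the source slot is no longer needed.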
diff --git a/runtime/src/iree/vm/list_test.cc b/runtime/src/iree/vm/list_test.cc
index f20bc06..a18c4f5 100644
--- a/runtime/src/iree/vm/list_test.cc
+++ b/runtime/src/iree/vm/list_test.cc
@@ -101,7 +101,7 @@
result.resize(iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < result.size(); ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
- IREE_CHECK_OK(iree_vm_list_get_variant(list, i, &variant));
+ IREE_CHECK_OK(iree_vm_list_get_variant_assign(list, i, &variant));
if (iree_vm_type_def_is_value(&variant.type)) {
result[i].type = variant.type.value_type;
memcpy(result[i].value_storage, variant.value_storage,
@@ -622,7 +622,7 @@
IREE_ASSERT_OK(iree_vm_list_resize(list, 5));
for (iree_host_size_t i = 0; i < 5; ++i) {
iree_vm_variant_t value = iree_vm_variant_empty();
- IREE_ASSERT_OK(iree_vm_list_get_variant(list, i, &value));
+ IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, i, &value));
EXPECT_TRUE(iree_vm_variant_is_empty(value));
}
@@ -655,7 +655,7 @@
}
for (iree_host_size_t i = 2; i < 5; ++i) {
iree_vm_variant_t value = iree_vm_variant_empty();
- IREE_ASSERT_OK(iree_vm_list_get_variant(list, i, &value));
+ IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, i, &value));
EXPECT_TRUE(iree_vm_variant_is_empty(value));
}
diff --git a/runtime/src/iree/vm/type_def.h b/runtime/src/iree/vm/type_def.h
index d8cc8b6..3de996c 100644
--- a/runtime/src/iree/vm/type_def.h
+++ b/runtime/src/iree/vm/type_def.h
@@ -58,32 +58,6 @@
((v)->value_type == IREE_VM_VALUE_TYPE_NONE && \
(v)->ref_type == IREE_VM_REF_TYPE_NULL)
-// An variant value that can be either a primitive value type or a ref type.
-// Each variant value stores its type but users are required to check the type
-// prior to accessing any of the data.
-typedef struct iree_vm_variant_t {
- iree_vm_type_def_t type;
- union {
- // TODO(benvanik): replace with iree_vm_value_t.
- int8_t i8;
- int16_t i16;
- int32_t i32;
- int64_t i64;
- float f32;
- double f64;
- iree_vm_ref_t ref;
-
- uint8_t value_storage[IREE_VM_VALUE_STORAGE_SIZE]; // max size of all value
- // types
- };
-} iree_vm_variant_t;
-
-#define iree_vm_variant_empty() \
- { {IREE_VM_VALUE_TYPE_NONE, IREE_VM_REF_TYPE_NULL}, {0}, }
-#define iree_vm_variant_is_value(v) iree_vm_type_def_is_value(&(v).type)
-#define iree_vm_variant_is_ref(v) iree_vm_type_def_is_ref(&(v).type)
-#define iree_vm_variant_is_empty(v) iree_vm_type_def_is_variant(&(v).type)
-
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/runtime/src/iree/vm/variant.h b/runtime/src/iree/vm/variant.h
new file mode 100644
index 0000000..dfff9ed
--- /dev/null
+++ b/runtime/src/iree/vm/variant.h
@@ -0,0 +1,103 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_VM_VARIANT_H_
+#define IREE_VM_VARIANT_H_
+
+#include "iree/vm/ref.h"
+#include "iree/vm/value.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// A variant value that can be either a primitive value type or a ref type.
+// Each variant value stores its type but users are required to check the type
+// prior to accessing any of the data.
+typedef struct iree_vm_variant_t {
+ iree_vm_type_def_t type;
+ union {
+ // TODO(benvanik): replace with iree_vm_value_t. Don't want to pay for 2x
+ // the type storage, though.
+ int8_t i8;
+ int16_t i16;
+ int32_t i32;
+ int64_t i64;
+ float f32;
+ double f64;
+ iree_vm_ref_t ref;
+ uint8_t value_storage[IREE_VM_VALUE_STORAGE_SIZE]; // max size of all value
+ // types
+ };
+} iree_vm_variant_t;
+
+// Returns an empty variant.
+static inline iree_vm_variant_t iree_vm_variant_empty(void) {
+ iree_vm_variant_t result;
+ result.type = iree_vm_type_def_make_variant_type();
+ result.ref = iree_vm_ref_null();
+ return result;
+}
+
+// Returns true if |variant| is empty (no value/NULL ref).
+static inline bool iree_vm_variant_is_empty(iree_vm_variant_t variant) {
+ return iree_vm_type_def_is_variant(&variant.type);
+}
+
+// Returns true if |variant| represents a primitive value.
+static inline bool iree_vm_variant_is_value(iree_vm_variant_t variant) {
+ return iree_vm_type_def_is_value(&variant.type);
+}
+
+// Returns true if |variant| represents a non-NULL ref type.
+static inline bool iree_vm_variant_is_ref(iree_vm_variant_t variant) {
+ return iree_vm_type_def_is_ref(&variant.type);
+}
+
+// Makes a variant containing the given primitive |value|.
+static inline iree_vm_variant_t iree_vm_make_variant_value(
+ iree_vm_value_t value) {
+ iree_vm_variant_t result = iree_vm_variant_empty();
+ result.type.value_type = value.type;
+ memcpy(result.value_storage, value.value_storage,
+ sizeof(result.value_storage));
+ return result;
+}
+
+// Makes a variant containing the given |ref| type with assignment semantics.
+static inline iree_vm_variant_t iree_vm_make_variant_ref_assign(
+ iree_vm_ref_t ref) {
+ iree_vm_variant_t result = iree_vm_variant_empty();
+ result.type.ref_type = ref.type;
+ result.ref = ref;
+ return result;
+}
+
+// Returns the primitive value contained within |variant|, if any.
+// If the variant is not a value type the return will be the same as
+// iree_vm_value_make_none.
+static inline iree_vm_value_t iree_vm_variant_value(iree_vm_variant_t variant) {
+ iree_vm_value_t value;
+ value.type = variant.type.value_type;
+ memcpy(value.value_storage, variant.value_storage,
+ sizeof(value.value_storage));
+ return value;
+}
+
+// Resets |variant| to empty in-place and releases the contained ref, if set.
+static inline void iree_vm_variant_reset(iree_vm_variant_t* variant) {
+ if (!variant) return;
+ if (iree_vm_variant_is_ref(*variant)) {
+ iree_vm_ref_release(&variant->ref);
+ }
+ *variant = iree_vm_variant_empty();
+}
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_VM_VARIANT_H_
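
The header above requires users to check a variant's type before touching its storage, and `iree_vm_variant_reset` releases a held ref before clearing. A minimal sketch of both rules on a toy tagged union, with hypothetical names throughout (`demo_variant_t`, `demo_variant_reset`, `demo_variant_get_i32`) and a bare counter standing in for real reference counting:

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

// Hypothetical toy types; the real variant stores a type def and more cases.
typedef enum { DEMO_TYPE_NONE = 0, DEMO_TYPE_I32, DEMO_TYPE_REF } demo_type_t;

typedef struct {
  int ref_count;
} demo_object_t;

typedef struct {
  demo_type_t type;
  union {
    int32_t i32;
    demo_object_t* ref;
  };
} demo_variant_t;

// Returns an empty variant (no value, NULL ref).
static demo_variant_t demo_variant_empty(void) {
  demo_variant_t v;
  memset(&v, 0, sizeof(v));
  return v;
}

static bool demo_variant_is_ref(demo_variant_t v) {
  return v.type == DEMO_TYPE_REF;
}

// Releases the contained ref, if any, and resets |v| to empty in place.
static void demo_variant_reset(demo_variant_t* v) {
  if (!v) return;
  if (demo_variant_is_ref(*v) && v->ref) --v->ref->ref_count;
  *v = demo_variant_empty();
}

// Reads an i32 only after checking the stored type; returns false otherwise.
static bool demo_variant_get_i32(demo_variant_t v, int32_t* out) {
  assert(out);
  if (v.type != DEMO_TYPE_I32) return false;
  *out = v.i32;
  return true;
}
```

Because the union members share storage, reading `i32` from a variant that holds a ref would reinterpret pointer bytes, which is why the type check comes first.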
diff --git a/tools/BUILD b/tools/BUILD
index c21ee68..3a37662 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -183,6 +183,7 @@
"//runtime/src/iree/hal",
"//runtime/src/iree/tooling:device_util",
"//runtime/src/iree/tooling:trace_replay",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/tooling:yaml_util",
"//runtime/src/iree/vm",
"@com_github_yaml_libyaml//:yaml",
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index aac88a3..50e56d1 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -161,6 +161,7 @@
iree::modules::hal
iree::tooling::device_util
iree::tooling::trace_replay
+ iree::tooling::vm_util
iree::tooling::yaml_util
iree::vm
yaml
diff --git a/tools/android/run_module_app/src/main.cc b/tools/android/run_module_app/src/main.cc
index be97e89..739c3fe 100644
--- a/tools/android/run_module_app/src/main.cc
+++ b/tools/android/run_module_app/src/main.cc
@@ -157,10 +157,10 @@
iree_string_builder_t result_str;
iree_string_builder_initialize(iree_allocator_system(), &result_str);
- IREE_RETURN_IF_ERROR(
- iree_tooling_append_variant_list_lines(
- outputs.get(), /*max_element_count=*/1024, &result_str),
- "printing results");
+ IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines(
+ IREE_SV("result"), outputs.get(),
+ /*max_element_count=*/1024, &result_str),
+ "printing results");
LOGI("Execution Result:");
LOGI("%.*s", (int)iree_string_builder_size(&result_str),
iree_string_builder_buffer(&result_str));
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index e77a89c..3615249 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -166,7 +166,7 @@
iree_vm_list_t* list, iree_host_size_t i,
iree_hal_buffer_view_t** out_value) {
iree_vm_variant_t variant = iree_vm_variant_empty();
- IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(list, i, &variant));
+ IREE_RETURN_IF_ERROR(iree_vm_list_get_variant_assign(list, i, &variant));
if (!iree_vm_variant_is_ref(variant)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"expected list item %zu to be a ref", i);
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index d7830fb..9aa2e37 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -385,8 +385,8 @@
if (FLAG_output_list().count == 0) {
IREE_RETURN_IF_ERROR(
iree_tooling_variant_list_fprint(
- outputs.get(), (iree_host_size_t)FLAG_output_max_element_count,
- stdout),
+ IREE_SV("result"), outputs.get(),
+ (iree_host_size_t)FLAG_output_max_element_count, stdout),
"printing results");
} else {
IREE_RETURN_IF_ERROR(
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 6063e86..7f5a439 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -158,8 +158,8 @@
if (FLAG_output_list().count == 0) {
IREE_RETURN_IF_ERROR(
iree_tooling_variant_list_fprint(
- outputs.get(), (iree_host_size_t)FLAG_output_max_element_count,
- stdout),
+ IREE_SV("result"), outputs.get(),
+ (iree_host_size_t)FLAG_output_max_element_count, stdout),
"printing results");
} else {
IREE_RETURN_IF_ERROR(
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index 57a2053..9c01f06 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -14,6 +14,7 @@
#include "iree/hal/api.h"
#include "iree/tooling/device_util.h"
#include "iree/tooling/trace_replay.h"
+#include "iree/tooling/vm_util.h"
#include "iree/tooling/yaml_util.h"
#include "iree/vm/api.h"
@@ -22,6 +23,102 @@
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
+IREE_FLAG(bool, print_calls, false, "Prints all I/O for each call to stdout.");
+IREE_FLAG(bool, print_call_inputs, false,
+ "Prints all inputs to stdout before each call is made.");
+IREE_FLAG(bool, print_call_outputs, false,
+ "Prints all outputs to stdout after each call is made.");
+
+IREE_FLAG_LIST(
+ string, input,
+ "An input (a) value or (b) buffer of the format:\n"
+ " (a) scalar value\n"
+ " value\n"
+ " e.g.: --input=\"3.14\"\n"
+ " (b) buffer:\n"
+ " [shape]xtype=[value]\n"
+ " e.g.: --input=\"2x2xi32=1 2 3 4\"\n"
+ "Optionally, brackets may be used to separate the element values:\n"
+ " 2x2xi32=[[1 2][3 4]]\n"
+ "Raw binary files can be read to provide buffer contents:\n"
+ " 2x2xi32=@some/file.bin\n"
+ "\n"
+ "Numpy npy files from numpy.save can be read to provide 1+ values:\n"
+ " @some.npy\n"
+ "\n"
+ "Each occurrence of the flag indicates an input in the order they were\n"
+ "specified on the command line.");
+
+IREE_FLAG_LIST(
+ string, output,
+ "Specifies how to handle an output from the invocation:\n"
+ " `` (empty): ignore output\n"
+ " e.g.: --output=\n"
+ " `-`: print textual form to stdout\n"
+ " e.g.: --output=-\n"
+ " `@file.npy`: create/overwrite a numpy npy file and write buffer view\n"
+ " e.g.: --output=@file.npy\n"
+ " `+file.npy`: create/append a numpy npy file and write buffer view\n"
+ " e.g.: --output=+file.npy\n"
+ "\n"
+ "Numpy npy files can be read in Python using numpy.load, for example an\n"
+ "invocation producing two outputs can be concatenated as:\n"
+ " --output=@file.npy --output=+file.npy\n"
+ "And then loaded in Python by reading from the same file:\n"
+ " with open('file.npy', 'rb') as f:\n"
+ " print(numpy.load(f))\n"
+ " print(numpy.load(f))\n"
+ "\n"
+ "Each occurrence of the flag indicates an output in the order they were\n"
+ "specified on the command line.");
+
+IREE_FLAG_LIST(string, expected_output,
+ "An expected function output following the same format as "
+ "--input. When present the results of the "
+ "invocation will be compared against these values and the "
+ "tool will return non-zero if any differ. If the value of a "
+ "particular output is not of interest provide `(ignored)`.");
+
+IREE_FLAG(int32_t, output_max_element_count, 1024,
+ "Prints up to the maximum number of elements of output tensors, "
+ "eliding the remainder.");
+
+static iree_status_t iree_trace_replay_call_before(void* user_data,
+ iree_trace_replay_t* replay,
+ yaml_document_t* document,
+ yaml_node_t* event_node,
+ iree_vm_function_t function,
+ iree_vm_list_t* input_list) {
+ if (FLAG_print_calls || FLAG_print_call_inputs) {
+ iree_string_view_t function_name = iree_vm_function_name(&function);
+ fprintf(stdout, "--- CALL[%.*s] ---\n", (int)function_name.size,
+ function_name.data);
+ IREE_RETURN_IF_ERROR(iree_tooling_variant_list_fprint(
+ IREE_SV("arg"), input_list,
+ (iree_host_size_t)FLAG_output_max_element_count, stdout));
+ }
+ return iree_ok_status();
+}
+
+static iree_status_t iree_trace_replay_call_after(void* user_data,
+ iree_trace_replay_t* replay,
+ yaml_document_t* document,
+ yaml_node_t* event_node,
+ iree_vm_function_t function,
+ iree_vm_list_t* output_list) {
+ if (FLAG_print_calls || FLAG_print_call_outputs) {
+ if (!FLAG_print_calls && !FLAG_print_call_inputs) {
+ iree_string_view_t function_name = iree_vm_function_name(&function);
+ fprintf(stdout, "--- CALL[%.*s] ---\n", (int)function_name.size,
+ function_name.data);
+ }
+ IREE_RETURN_IF_ERROR(iree_tooling_variant_list_fprint(
+ IREE_SV("result"), output_list,
+ (iree_host_size_t)FLAG_output_max_element_count, stdout));
+ }
+ return iree_ok_status();
+}
+
// Runs the trace in |file| using |root_path| as the base for any path lookups
// required for external files referenced in |file|.
static iree_status_t iree_run_trace_file(iree_string_view_t root_path,
@@ -34,6 +131,11 @@
: IREE_VM_CONTEXT_FLAG_NONE,
iree_hal_available_driver_registry(), iree_allocator_system(), &replay));
+ // Hook into all calls processed during the trace.
+ replay.call_hooks.user_data = NULL;
+ replay.call_hooks.before = iree_trace_replay_call_before;
+ replay.call_hooks.after = iree_trace_replay_call_after;
+
// Query device overrides, if any. When omitted the devices from the trace
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
@@ -51,24 +153,59 @@
}
yaml_parser_set_input_file(&parser, file);
+ bool have_parsed_inputs = false;
iree_status_t status = iree_ok_status();
for (bool document_eof = false; !document_eof;) {
+ // Parse the subdocument event.
yaml_document_t document;
if (!yaml_parser_load(&parser, &document)) {
status = iree_status_from_yaml_parser_error(&parser);
break;
}
+
+ // Execute the event or handle EOF (empty document).
yaml_node_t* event_node = yaml_document_get_root_node(&document);
if (event_node) {
status = iree_trace_replay_event(&replay, &document, event_node);
} else {
document_eof = true;
}
+
+ // Reclaim subdocument resources before moving on to the next.
yaml_document_delete(&document);
if (!iree_status_is_ok(status)) break;
+
+ // If the event created a device and we haven't yet performed our input
+ // loading we can do that now before processing subsequent events.
+ if (!have_parsed_inputs && replay.device) {
+ status = iree_tooling_parse_into_variant_list(
+ iree_hal_device_allocator(replay.device), FLAG_input_list().values,
+ FLAG_input_list().count, replay.host_allocator, replay.inputs);
+ have_parsed_inputs = true;
+ }
+ if (!iree_status_is_ok(status)) break;
}
yaml_parser_delete(&parser);
+
+ // Optionally process outputs from the replay session.
+ if (iree_status_is_ok(status)) {
+ if (FLAG_output_list().count == 0) {
+ IREE_RETURN_IF_ERROR(
+ iree_tooling_variant_list_fprint(
+ IREE_SV("output"), replay.outputs,
+ (iree_host_size_t)FLAG_output_max_element_count, stdout),
+ "printing results");
+ } else {
+ IREE_RETURN_IF_ERROR(
+ iree_tooling_output_variant_list(
+ replay.outputs, FLAG_output_list().values,
+ FLAG_output_list().count,
+ (iree_host_size_t)FLAG_output_max_element_count, stdout),
+ "outputting results");
+ }
+ }
+
iree_trace_replay_deinitialize(
&replay, FLAG_print_statistics
? IREE_TRACE_REPLAY_SHUTDOWN_PRINT_STATISTICS
@@ -97,6 +234,122 @@
}
int main(int argc, char** argv) {
+ iree_flags_set_usage(
+ "iree-run-trace",
+ "Executes a YAML trace file containing a sequence of context operations\n"
+ "and calls represented as subdocuments.\n"
+ "\n"
+ "Example loading a bytecode module and calling a function:\n"
+ "\n"
+ "```yaml\n"
+ "type: context_load\n"
+ "---\n"
+ "type: module_load\n"
+ "module:\n"
+ "  type: builtin\n"
+ " name: hal\n"
+ "---\n"
+ "type: module_load\n"
+ "module:\n"
+ " type: bytecode\n"
+ " path: ../build/some_module.vmfb\n"
+ "---\n"
+ "type: call\n"
+ "function: module.mul\n"
+ "args:\n"
+ "- !input.take 0\n"
+ "- !input.take 1\n"
+ "results:\n"
+ "- !output.push\n"
+ "- !output.push\n"
+ "```\n"
+ "\n"
+ "This can be invoked like iree-run-module specifying inputs/outputs:\n"
+ " iree-run-trace trace.yml \\\n"
+ " --device=local-sync \\\n"
+ "    --input=4xf32=0,1,2,3 \\\n"
+ " --input=@input1.npy \\\n"
+ " --output=@outputs.npy \\\n"
+ " --output=+outputs.npy\n"
+ "\n"
+ "In addition to `--input=`/`--output=` flag access a user-defined\n"
+ "blackboard exists for preserving temporary values used within the\n"
+ "trace. Blackboard slots are defined by ordinal and can be used in any\n"
+ "context that input/output can be, e.g. `!blackboard.get` instead of\n"
+ "`!input.get` and `!blackboard.set` instead of `!output.set`.\n"
+ "\n"
+ "--- Events ---\n"
+ "\n"
+ "`type: context_load`\n"
+ "Loads an empty VM context with no modules registered.\n"
+ "\n"
+ "`type: module_load`\n"
+ "Loads a module into the current context. Modules may either be\n"
+ "`builtin` (compiled into the binary) or dynamically-loaded `bytecode`.\n"
+ "\n"
+ "`type: blackboard_clear`\n"
+ "Clears the contents of the blackboard and resets it to 0 elements.\n"
+ "\n"
+ "`type: assign`\n"
+ "Assigns sources from a `from` sequence to targets in a `to` sequence.\n"
+ "Equivalent to an identity function call and can be used to move\n"
+ "between inputs, outputs, and the blackboard.\n"
+ "\n"
+ "`type: numpy_load`\n"
+ "Loads one or more ndarrays from a .npy value. Each array has a target\n"
+ "where the array will be retained such as `!blackboard.set 2`.\n"
+ "\n"
+ "`type: numpy_save`\n"
+ "Saves one or more ndarrays to a .npy value. Each array has a source\n"
+ "where the array will be taken from such as `!blackboard.get 2`.\n"
+ "\n"
+ "`type: call`\n"
+ "Invokes a function in the context by fully-qualified `function` name.\n"
+ "Uses arguments from an `args` sequence and produces results into a\n"
+ "`results` sequence.\n"
+ "\n"
+ "--- Sources ---\n"
+ "\n"
+ "`type: null`\n"
+ "A null ref value.\n"
+ "\n"
+ "`!hal.buffer_view 4xf32=0,1,2,3`\n"
+ "A constant iree_hal_buffer_view_t/!hal.buffer_view value using the\n"
+ "same formatting as iree-run-module's `--input=` flag.\n"
+ "\n"
+ "`!hal.buffer 4xf32=0,1,2,3`\n"
+ "An initialized iree_hal_buffer_t/!hal.buffer without the wrapping view\n"
+ "metadata.\n"
+ "\n"
+ "`!input.get ORDINAL` / `!input.take ORDINAL`\n"
+ "Returns a reference to the `--input=` flag at ORDINAL. Note that a single\n"
+ "npy file may expand to multiple inputs. The `take` variant transfers\n"
+ "ownership and clears the slot in the list and is recommended to avoid\n"
+ "keeping unneeded inputs around for the duration of the trace.\n"
+ "\n"
+ "`!output.get ORDINAL` / `!output.take ORDINAL`\n"
+ "Returns a reference to the `--output=` flag at ORDINAL. These are\n"
+ "initially empty until assigned by the trace.\n"
+ "\n"
+ "`!blackboard.get ORDINAL` / `!blackboard.take ORDINAL`\n"
+ "Returns a reference to the blackboard slot ORDINAL. The blackboard is\n"
+ "initially empty and slots must be assigned in order to define them.\n"
+ "The `take` variant transfers ownership and clears the slot in the\n"
+ "blackboard and is recommended to avoid keeping large resources live\n"
+ "in the blackboard longer than they need to be.\n"
+ "\n"
+ "--- Targets ---\n"
+ "\n"
+ "`!output.set ORDINAL` / `!output.push`\n"
+ "Sets the `--output=` flag result value at ORDINAL or pushes it to the\n"
+ "back of the output list. Outputs can either be dumped to files or by\n"
+ "default printed to stdout.\n"
+ "\n"
+ "`!blackboard.set ORDINAL` / `!blackboard.push`\n"
+ "Sets the value of the blackboard slot ORDINAL or pushes it to the back\n"
+ "of the blackboard list. Blackboard values will be retained until they\n"
+ "are consumed via `!blackboard.take` or the blackboard is cleared.\n"
+ "\n");
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
if (argc <= 1) {
fprintf(stderr,
diff --git a/tools/test/BUILD b/tools/test/BUILD
index f7bf4f8..af0a150 100644
--- a/tools/test/BUILD
+++ b/tools/test/BUILD
@@ -28,6 +28,7 @@
"iree-run-module.mlir",
"iree-run-module-expected.mlir",
"iree-run-module-outputs.mlir",
+ "iree-run-trace.mlir",
"multiple_args.mlir",
"multiple_exported_functions.mlir",
"null_values.mlir",
@@ -39,6 +40,7 @@
cfg = "//tools:lit.cfg.py",
data = [
"echo_npy.py",
+ "iree-run-trace.yml",
],
tags = [
"driver=local-task",
@@ -51,6 +53,7 @@
"//tools:iree-opt",
"//tools:iree-run-mlir",
"//tools:iree-run-module",
+ "//tools:iree-run-trace",
"@llvm-project//lld",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
diff --git a/tools/test/CMakeLists.txt b/tools/test/CMakeLists.txt
index 7c92d9d..305345a 100644
--- a/tools/test/CMakeLists.txt
+++ b/tools/test/CMakeLists.txt
@@ -24,6 +24,7 @@
"iree-run-module-expected.mlir"
"iree-run-module-outputs.mlir"
"iree-run-module.mlir"
+ "iree-run-trace.mlir"
"multiple_args.mlir"
"multiple_exported_functions.mlir"
"null_values.mlir"
@@ -37,9 +38,11 @@
iree-opt
iree-run-mlir
iree-run-module
+ iree-run-trace
not
DATA
echo_npy.py
+ iree-run-trace.yml
LABELS
"driver=local-task"
"driver=vulkan"
diff --git a/tools/test/iree-run-trace.mlir b/tools/test/iree-run-trace.mlir
new file mode 100644
index 0000000..ad898e7
--- /dev/null
+++ b/tools/test/iree-run-trace.mlir
@@ -0,0 +1,22 @@
+// Tests iree-run-trace usage by running two calls of @mul and passing the
+// result between them. The outputs of both calls are produced as outputs from
+// the trace and both are written to a .npy file for processing. Inputs can
+// also come from a .npy file. See iree-run-module usage for more information
+// on the `--input=` and `--output=` flags.
+
+// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | \
+// RUN: iree-run-trace %S/iree-run-trace.yml \
+// RUN: --device=local-sync \
+// RUN: --input=4xf32=4,4,4,4 \
+// RUN: --output=@%t \
+// RUN: --output=+%t) && \
+// RUN: python3 %S/echo_npy.py %t | \
+// RUN: FileCheck %s
+
+// CHECK{LITERAL}: [ 0. 4. 8. 12.]
+// CHECK-NEXT{LITERAL}: [ 0. 12. 24. 36.]
+
+func.func @mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
diff --git a/tools/test/iree-run-trace.yml b/tools/test/iree-run-trace.yml
new file mode 100644
index 0000000..b267dc2
--- /dev/null
+++ b/tools/test/iree-run-trace.yml
@@ -0,0 +1,80 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests loading and executing a bytecode module and issuing a few calls showing
+# how to take input, produce output, and support temporary values within the
+# trace session. See iree-run-trace.mlir for how to compile the module and
+# invoke the iree-run-trace tool.
+
+# Prepare the VM context for use; effectively a reset.
+# API: iree_vm_context_create
+type: context_load
+
+---
+
+# Load the builtin HAL module used to execute the program.
+# API: iree_hal_module_create
+type: module_load
+module:
+ type: builtin
+ name: hal
+
+---
+
+# Load the compiled bytecode module.
+# API: iree_vm_bytecode_module_create
+type: module_load
+module:
+ type: bytecode
+ name: module
+ # The test pulls the .vmfb from stdin but you can also reference relative or
+ # absolute file paths:
+ # path: ../iree-tmp/iree-run-trace.vmfb
+ path: <stdin>
+
+---
+
+# Call #0 of @mul.
+# API: iree_vm_invoke
+type: call
+function: module.mul
+args:
+# arg[0]: the first `--input=` buffer. !input.get would retain the input for
+# other calls to use but otherwise prefer taking ownership.
+- !input.take 0
+# arg[1]: constant value defined inline.
+- !hal.buffer_view 4xf32=0,1,2,3
+results:
+# result[0]: store in blackboard slot 4 for later use.
+- !blackboard.set 4
+
+---
+
+# Assigns one or more source values to a set of target values.
+# Effectively: outputs.push(retain(blackboard[4]))
+type: assign
+from:
+# from[0]: retain blackboard slot 4, leaving it for later use.
+- !blackboard.get 4
+to:
+# to[0]: push on to the trace output list. --output= can save off the results
+# and otherwise they are printed to stdout.
+- !output.push
+
+---
+
+# Call #1 of @mul.
+# API: iree_vm_invoke
+type: call
+function: module.mul
+args:
+# arg[0]: take the previously-stored value in blackboard slot 4.
+- !blackboard.take 4
+# arg[1]: another constant.
+- !hal.buffer_view 4xf32=3,3,3,3
+results:
+# result[0]: push on to the trace output list.
+- !output.push
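
The get/take and set/push semantics exercised by the trace above can be modeled in a few lines of Python. This is only a sketch for reasoning about the data flow, not IREE code: `SlotList`, `mul`, and `run_trace` are hypothetical names, and growing a list when setting an ordinal past its end is an assumption inferred from the trace storing into blackboard slot 4 of an initially empty blackboard.

```python
from typing import Any, List, Optional


class SlotList:
    """Toy stand-in for the trace input/output/blackboard lists (hypothetical)."""

    def __init__(self, values=None):
        self.slots: List[Optional[Any]] = list(values or [])

    def get(self, ordinal):
        # Assign semantics: the slot keeps its value for later use.
        return self.slots[ordinal]

    def take(self, ordinal):
        # Move semantics: the slot is cleared once the value is read.
        value = self.slots[ordinal]
        self.slots[ordinal] = None
        return value

    def set(self, ordinal, value):
        # Assumption: setting past the end grows the list with empty slots.
        if ordinal >= len(self.slots):
            self.slots.extend([None] * (ordinal + 1 - len(self.slots)))
        self.slots[ordinal] = value

    def push(self, value):
        # Appends to the back of the list.
        self.slots.append(value)


def mul(lhs, rhs):
    # Element-wise multiply, standing in for module.mul.
    return [a * b for a, b in zip(lhs, rhs)]


def run_trace(input_values):
    inputs = SlotList(input_values)
    outputs = SlotList()
    blackboard = SlotList()
    # Call #0: !input.take 0 and a constant, result to !blackboard.set 4.
    blackboard.set(4, mul(inputs.take(0), [0.0, 1.0, 2.0, 3.0]))
    # assign: !blackboard.get 4 -> !output.push (slot 4 is retained).
    outputs.push(blackboard.get(4))
    # Call #1: !blackboard.take 4 and a constant, result to !output.push.
    outputs.push(mul(blackboard.take(4), [3.0, 3.0, 3.0, 3.0]))
    return inputs, outputs, blackboard


# Mirrors --input=4xf32=4,4,4,4 from iree-run-trace.mlir.
inputs, outputs, blackboard = run_trace([[4.0, 4.0, 4.0, 4.0]])
print(outputs.slots)  # [[0.0, 4.0, 8.0, 12.0], [0.0, 12.0, 24.0, 36.0]]
```

The two output rows match the CHECK lines in iree-run-trace.mlir, and after the run both the taken input slot and blackboard slot 4 are empty, which is the behavior the `take` variants are recommended for.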