Adding `numpy_load`/`numpy_save` trace events.
These allow for easily loading/storing from/to the blackboard/outputs
within the trace as opposed to the command line flags.
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index 18182a8..84e1a12 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -193,6 +193,7 @@
srcs = ["trace_replay.c"],
hdrs = ["trace_replay.h"],
deps = [
+ ":numpy_io",
":yaml_util",
"//runtime/src/iree/base",
"//runtime/src/iree/base:tracing",
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index 5e41789..b3f7284 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -214,6 +214,7 @@
SRCS
"trace_replay.c"
DEPS
+ ::numpy_io
::yaml_util
iree::base
iree::base::internal
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index 93eadba..ea5a921 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -17,6 +17,7 @@
#include "iree/base/internal/path.h"
#include "iree/base/tracing.h"
#include "iree/modules/hal/module.h"
+#include "iree/tooling/numpy_io.h"
#include "iree/vm/bytecode/module.h"
//===----------------------------------------------------------------------===//
@@ -1307,6 +1308,177 @@
}
//===----------------------------------------------------------------------===//
+// Numpy ndarray management
+//===----------------------------------------------------------------------===//
+
+// Loads one or more ndarrays from a .npy file.
+//
+// Example:
+// ```yaml
+// type: numpy_load
+// path: three_ndarrays.npy
+// arrays:
+// - !blackboard.set 2
+// - !blackboard.set 3
+// - !output.set 4
+// ```
+//
+// NOTE: this currently reads things into new buffers; we could try mapping from
+// disk and other fancy things where possible.
+static iree_status_t iree_trace_replay_event_numpy_load(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node) {
+ if (!replay->device) {
+ return iree_make_status(
+ IREE_STATUS_FAILED_PRECONDITION,
+ "HAL module must be loaded before loading numpy arrays");
+ }
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ yaml_node_t* path_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("path"),
+ &path_node));
+ iree_string_view_t path_str = iree_yaml_node_as_string(path_node);
+
+ yaml_node_t* arrays_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("arrays"),
+ &arrays_node));
+
+ char* full_path = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_file_path_join(replay->root_path, path_str,
+ replay->host_allocator, &full_path));
+ FILE* file = fopen(full_path, "rb");
+ iree_allocator_free(replay->host_allocator, full_path);
+ if (!file) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "failed to open file `%.*s` for read",
+ (int)path_str.size, path_str.data);
+ }
+
+ uint64_t file_length = 0;
+ iree_status_t status = iree_file_query_length(file, &file_length);
+
+ iree_hal_buffer_params_t buffer_params = {0};
+ buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
+ buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ;
+ buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+ iree_hal_allocator_t* device_allocator =
+ iree_hal_device_allocator(replay->device);
+
+ if (iree_status_is_ok(status)) {
+ for (yaml_node_item_t* item = arrays_node->data.sequence.items.start;
+ item != arrays_node->data.sequence.items.top; ++item) {
+ // Parse the next array in the file. Note that we may be at the end!
+ if (iree_file_is_at(file, file_length)) {
+ status = iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "file ended before all arrays were decoded");
+ break;
+ }
+ iree_hal_buffer_view_t* buffer_view = NULL;
+ status = iree_numpy_npy_load_ndarray(
+ file, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, buffer_params,
+ device_allocator, &buffer_view);
+ if (!iree_status_is_ok(status)) break;
+
+ // Route the loaded value to its destination.
+ iree_vm_variant_t variant = iree_vm_make_variant_ref_assign(
+ iree_hal_buffer_view_move_ref(buffer_view));
+ yaml_node_t* item_node = yaml_document_get_node(document, *item);
+ status = iree_trace_replay_parse_result_item(replay, document, item_node,
+ variant);
+ iree_vm_variant_reset(&variant);
+ if (!iree_status_is_ok(status)) break;
+ }
+ }
+
+ fclose(file);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Saves one or more ndarrays to a .npy file.
+//
+// Example:
+// ```yaml
+// type: numpy_save
+// path: three_ndarrays.npy
+// append: true
+// arrays:
+// - !blackboard.get 2
+// - !blackboard.take 3
+// - !input.get 4
+// - !hal.buffer_view 4xf32=0,1,2,3
+// ```
+static iree_status_t iree_trace_replay_event_numpy_save(
+ iree_trace_replay_t* replay, yaml_document_t* document,
+ yaml_node_t* event_node) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ yaml_node_t* path_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("path"),
+ &path_node));
+ iree_string_view_t path_str = iree_yaml_node_as_string(path_node);
+
+ yaml_node_t* append_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_try_find(document, event_node, IREE_SV("append"),
+ &append_node));
+
+ yaml_node_t* arrays_node = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_yaml_mapping_find(document, event_node, IREE_SV("arrays"),
+ &arrays_node));
+
+ char* full_path = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_file_path_join(replay->root_path, path_str,
+ replay->host_allocator, &full_path));
+ const char* mode = append_node ? "ab" : "wb";
+ FILE* file = fopen(full_path, mode);
+ iree_allocator_free(replay->host_allocator, full_path);
+ if (!file) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "failed to open file `%.*s` for write",
+ (int)path_str.size, path_str.data);
+ }
+
+ iree_status_t status = iree_ok_status();
+ for (yaml_node_item_t* item = arrays_node->data.sequence.items.start;
+ item != arrays_node->data.sequence.items.top; ++item) {
+ yaml_node_t* item_node = yaml_document_get_node(document, *item);
+ iree_vm_variant_t variant = iree_vm_variant_empty();
+ status =
+ iree_trace_replay_parse_item(replay, document, item_node, &variant);
+ if (!iree_status_is_ok(status)) break;
+ if (!iree_vm_variant_is_ref(variant) ||
+ !iree_hal_buffer_view_isa(variant.ref)) {
+ status =
+ iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "only buffer views can be saved to numpy files");
+ break;
+ }
+ iree_hal_buffer_view_t* buffer_view =
+ iree_hal_buffer_view_deref(variant.ref);
+ status =
+ iree_numpy_npy_save_ndarray(file, IREE_NUMPY_NPY_SAVE_OPTION_DEFAULT,
+ buffer_view, replay->host_allocator);
+ iree_vm_variant_reset(&variant);
+ }
+
+ fclose(file);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+//===----------------------------------------------------------------------===//
// Event dispatch
//===----------------------------------------------------------------------===//
@@ -1331,6 +1503,10 @@
} else if (iree_yaml_string_equal(type_node, IREE_SV("assign"))) {
return iree_trace_replay_event_blackboard_assign(replay, document,
event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("numpy_load"))) {
+ return iree_trace_replay_event_numpy_load(replay, document, event_node);
+ } else if (iree_yaml_string_equal(type_node, IREE_SV("numpy_save"))) {
+ return iree_trace_replay_event_numpy_save(replay, document, event_node);
} else if (iree_yaml_string_equal(type_node, IREE_SV("call"))) {
return iree_trace_replay_event_call(replay, document, event_node,
&replay->call_hooks);
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index 56c66ca..9c01f06 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -295,6 +295,14 @@
"Equivalent to an identity function call and can be used to move\n"
"between inputs, outputs, and the blackboard.\n"
"\n"
+ "`type: numpy_load`\n"
+ "Loads one or more ndarrays from a .npy value. Each array has a target\n"
+ "where the array will be retained such as `!blackboard.set 2`.\n"
+ "\n"
+ "`type: numpy_save\n"
+ "Saves one or more ndarrays to a .npy value. Each array has a source\n"
+ "where the array will be taken from such as `!blackboard.get 2`.\n"
+ "\n"
"`type: call`\n"
"Invokes a function in the context by fully-qualified `function` name.\n"
"Uses arguments from an `args` sequence and produces results into a\n"