Adding IREE_FLAG_LIST utility for repeated string flags. (#11806)
It currently supports only strings, but we can extend it to other types
in the future with some macro magic.
Many of the tools were using C++ to implement repeated flags for things
like `--function_input=`; they can now just use the common base C APIs
instead of needing extra C++ goo. Though passing giant inputs in flags
is never going to be great, we now avoid the per-flag alloc+copy that
was previously required to do this.
diff --git a/tools/BUILD b/tools/BUILD
index 2fd3a71..e4af67c 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -33,7 +33,7 @@
"//runtime/src/iree/modules/hal:types",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
"@com_google_benchmark//:benchmark",
],
@@ -73,7 +73,7 @@
"//runtime/src/iree/testing:gtest",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
"//runtime/src/iree/vm:bytecode_module",
],
@@ -129,7 +129,7 @@
"//runtime/src/iree/modules/hal:types",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
"//runtime/src/iree/vm:bytecode_module",
"@llvm-project//llvm:Support",
@@ -154,7 +154,7 @@
"//runtime/src/iree/tooling:comparison",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
],
)
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index e94fa4e..82f188e 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -53,7 +53,7 @@
iree::modules::hal::types
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
)
@@ -94,7 +94,7 @@
iree::testing::gtest
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
iree::vm::bytecode_module
TESTONLY
@@ -128,7 +128,7 @@
iree::tooling::comparison
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
)
@@ -254,7 +254,7 @@
iree::modules::hal::types
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
iree::vm::bytecode_module
DATA
diff --git a/tools/android/run_module_app/CMakeLists.txt b/tools/android/run_module_app/CMakeLists.txt
index dfb229d..4d36e80 100644
--- a/tools/android/run_module_app/CMakeLists.txt
+++ b/tools/android/run_module_app/CMakeLists.txt
@@ -34,7 +34,7 @@
iree::base
iree::modules::hal
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
LINKOPTS
"-landroid"
diff --git a/tools/android/run_module_app/src/main.cc b/tools/android/run_module_app/src/main.cc
index dfdb135..193595e 100644
--- a/tools/android/run_module_app/src/main.cc
+++ b/tools/android/run_module_app/src/main.cc
@@ -12,11 +12,12 @@
#include <sstream>
#include <string>
#include <thread>
+#include <vector>
#include "iree/base/api.h"
#include "iree/modules/hal/module.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -130,17 +131,18 @@
&function),
"looking up function '%s'", function_name.c_str());
- size_t pos = 0;
- std::string inputs_str = invocation.inputs;
- std::vector<std::string> input_views;
- while ((pos = inputs_str.find('\n')) != std::string::npos) {
- input_views.push_back(inputs_str.substr(0, pos));
- inputs_str.erase(0, pos + 1);
+ std::vector<iree_string_view_t> input_views;
+ iree_string_view_t inputs_view =
+ iree_make_string_view(invocation.inputs.data(), invocation.inputs.size());
+ while (!iree_string_view_is_empty(inputs_view)) {
+ iree_string_view_t input_view = iree_string_view_empty();
+ iree_string_view_split(inputs_view, '\n', &input_view, &inputs_view);
+ input_views.push_back(input_view);
}
vm::ref<iree_vm_list_t> inputs;
- IREE_RETURN_IF_ERROR(ParseToVariantList(iree_hal_device_allocator(device),
- input_views, iree_allocator_system(),
- &inputs));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ iree_hal_device_allocator(device), input_views.data(), input_views.size(),
+ iree_allocator_system(), &inputs));
vm::ref<iree_vm_list_t> outputs;
IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr, 16,
@@ -153,11 +155,16 @@
iree_allocator_system()),
"invoking function '%s'", function_name.c_str());
- std::string result;
- IREE_RETURN_IF_ERROR(PrintVariantList(outputs.get(), &result),
- "printing results");
+ iree_string_builder_t result_str;
+ iree_string_builder_initialize(iree_allocator_system(), &result_str);
+ IREE_RETURN_IF_ERROR(
+ iree_tooling_append_variant_list_lines(
+ outputs.get(), /*max_element_count=*/1024, &result_str),
+ "printing results");
LOGI("Execution Result:");
- LOGI("%s", result.c_str());
+ LOGI("%.*s", (int)iree_string_builder_size(&result_str),
+ iree_string_builder_buffer(&result_str));
+ iree_string_builder_deinitialize(&result_str);
inputs.reset();
outputs.reset();
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index a7fc39b..e67bd5c 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -68,7 +68,7 @@
#include "iree/modules/hal/types.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
constexpr char kNanosecondsUnitString[] = "ns";
@@ -92,30 +92,8 @@
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
-// TODO(benvanik): move --function_input= flag into a util.
-static iree_status_t parse_function_input(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- auto* list = (std::vector<std::string>*)storage;
- list->push_back(std::string(value.data, value.size));
- return iree_ok_status();
-}
-static void print_function_input(iree_string_view_t flag_name, void* storage,
- FILE* file) {
- auto* list = (std::vector<std::string>*)storage;
- if (list->empty()) {
- fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
- } else {
- for (size_t i = 0; i < list->size(); ++i) {
- fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
- list->at(i).c_str());
- }
- }
-}
-static std::vector<std::string> FLAG_function_inputs;
-IREE_FLAG_CALLBACK(
- parse_function_input, print_function_input, &FLAG_function_inputs,
- function_input,
+IREE_FLAG_LIST(
+ string, function_input,
"An input value or buffer of the format:\n"
" [shape]xtype=[value]\n"
" 2x2xi32=1 2 3 4\n"
@@ -499,10 +477,9 @@
iree_string_view_t{function_name.data(), function_name.size()},
&function));
- IREE_CHECK_OK(ParseToVariantList(
- device_allocator_.get(),
- iree::span<const std::string>{FLAG_function_inputs.data(),
- FLAG_function_inputs.size()},
+ IREE_CHECK_OK(iree_tooling_parse_to_variant_list(
+ device_allocator_.get(), FLAG_function_input_list().values,
+ FLAG_function_input_list().count,
iree_vm_instance_allocator(instance_.get()), &inputs_));
iree_string_view_t invocation_model = iree_vm_function_lookup_attr_by_name(
diff --git a/tools/iree-benchmark-trace-main.c b/tools/iree-benchmark-trace-main.c
index 266df43..1a34580 100644
--- a/tools/iree-benchmark-trace-main.c
+++ b/tools/iree-benchmark-trace-main.c
@@ -206,7 +206,7 @@
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
iree_host_size_t device_uri_count = 0;
- iree_string_view_t* device_uris = NULL;
+ const iree_string_view_t* device_uris = NULL;
iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
device_uris);
diff --git a/tools/iree-check-module-main.cc b/tools/iree-check-module-main.cc
index 38f53f6..2d5ba6a 100644
--- a/tools/iree-check-module-main.cc
+++ b/tools/iree-check-module-main.cc
@@ -22,7 +22,7 @@
#include "iree/testing/status_matchers.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index a00f1ef..217d4e6 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -1165,7 +1165,7 @@
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
iree_host_size_t device_uri_count = 0;
- iree_string_view_t* device_uris = NULL;
+ const iree_string_view_t* device_uris = NULL;
iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
device_uris);
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index ba7ffee..aaee8ca 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -52,7 +52,7 @@
#include "iree/modules/hal/types.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "llvm/ADT/STLExtras.h"
@@ -138,30 +138,8 @@
llvm::cl::ConsumeAfter,
};
-// TODO(benvanik): move --function_input= flag into a util.
-static iree_status_t parse_function_input(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- auto* list = (std::vector<std::string>*)storage;
- list->push_back(std::string(value.data, value.size));
- return iree_ok_status();
-}
-static void print_function_input(iree_string_view_t flag_name, void* storage,
- FILE* file) {
- auto* list = (std::vector<std::string>*)storage;
- if (list->empty()) {
- fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
- } else {
- for (size_t i = 0; i < list->size(); ++i) {
- fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
- list->at(i).c_str());
- }
- }
-}
-static std::vector<std::string> FLAG_function_inputs;
-IREE_FLAG_CALLBACK(
- parse_function_input, print_function_input, &FLAG_function_inputs,
- function_input,
+IREE_FLAG_LIST(
+ string, function_input,
"An input value or buffer of the format:\n"
" [shape]xtype=[value]\n"
" 2x2xi32=1 2 3 4\n"
@@ -342,11 +320,9 @@
// Parse input values from the flags.
vm::ref<iree_vm_list_t> inputs;
- IREE_RETURN_IF_ERROR(ParseToVariantList(
- device_allocator,
- iree::span<const std::string>{FLAG_function_inputs.data(),
- FLAG_function_inputs.size()},
- host_allocator, &inputs));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ device_allocator, FLAG_function_input_list().values,
+ FLAG_function_input_list().count, host_allocator, &inputs));
// If the function is async add fences so we can invoke it synchronously.
vm::ref<iree_hal_fence_t> finish_fence;
@@ -370,7 +346,8 @@
}
// Print outputs.
- IREE_RETURN_IF_ERROR(PrintVariantList(outputs.get()));
+ IREE_RETURN_IF_ERROR(iree_tooling_variant_list_fprint(
+ outputs.get(), /*max_element_count=*/1024, stdout));
return OkStatus();
}
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index ccbca90..46cadda 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -20,7 +20,7 @@
#include "iree/tooling/comparison.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
IREE_FLAG(string, entry_function, "",
@@ -34,29 +34,8 @@
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
-// TODO(benvanik): move --function_input= flag into a util.
-static iree_status_t parse_function_io(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- auto* list = (std::vector<std::string>*)storage;
- list->push_back(std::string(value.data, value.size));
- return iree_ok_status();
-}
-static void print_function_io(iree_string_view_t flag_name, void* storage,
- FILE* file) {
- auto* list = (std::vector<std::string>*)storage;
- if (list->empty()) {
- fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
- } else {
- for (size_t i = 0; i < list->size(); ++i) {
- fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
- list->at(i).c_str());
- }
- }
-}
-static std::vector<std::string> FLAG_function_inputs;
-IREE_FLAG_CALLBACK(
- parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
+IREE_FLAG_LIST(
+ string, function_input,
"An input (a) value or (b) buffer of the format:\n"
" (a) scalar value\n"
" value\n"
@@ -73,14 +52,12 @@
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
-static std::vector<std::string> FLAG_expected_outputs;
-IREE_FLAG_CALLBACK(parse_function_io, print_function_io, &FLAG_expected_outputs,
- expected_output,
- "An expected function output following the same format as "
- "--function_input. When present the results of the "
- "invocation will be compared against these values and the "
- "tool will return non-zero if any differ. If the value of a "
- "particular output is not of interest provide `(ignored)`.");
+IREE_FLAG_LIST(string, expected_output,
+ "An expected function output following the same format as "
+ "--function_input. When present the results of the "
+ "invocation will be compared against these values and the "
+ "tool will return non-zero if any differ. If the value of a "
+ "particular output is not of interest provide `(ignored)`.");
namespace iree {
namespace {
@@ -122,11 +99,9 @@
IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
vm::ref<iree_vm_list_t> inputs;
- IREE_RETURN_IF_ERROR(ParseToVariantList(
- device_allocator.get(),
- iree::span<const std::string>{FLAG_function_inputs.data(),
- FLAG_function_inputs.size()},
- host_allocator, &inputs));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ device_allocator.get(), FLAG_function_input_list().values,
+ FLAG_function_input_list().count, host_allocator, &inputs));
// If the function is async add fences so we can invoke it synchronously.
vm::ref<iree_hal_fence_t> finish_fence;
@@ -153,9 +128,10 @@
IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
- if (FLAG_expected_outputs.empty()) {
+ if (FLAG_expected_output_list().count == 0) {
IREE_RETURN_IF_ERROR(
- PrintVariantList(outputs.get(), (size_t)FLAG_print_max_element_count),
+ iree_tooling_variant_list_fprint(
+ outputs.get(), (size_t)FLAG_print_max_element_count, stdout),
"printing results");
} else {
// Parse expected list into host-local memory that we can easily access.
@@ -164,9 +140,9 @@
IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
IREE_SV("heap"), host_allocator, host_allocator, &heap_allocator));
vm::ref<iree_vm_list_t> expected_list;
- IREE_RETURN_IF_ERROR(ParseToVariantList(heap_allocator.get(),
- FLAG_expected_outputs,
- host_allocator, &expected_list));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ heap_allocator.get(), FLAG_expected_output_list().values,
+ FLAG_expected_output_list().count, host_allocator, &expected_list));
// Compare expected vs actual lists and output diffs.
bool did_match = iree_tooling_compare_variant_lists(
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index 577e0b6..57a2053 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -38,7 +38,7 @@
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
iree_host_size_t device_uri_count = 0;
- iree_string_view_t* device_uris = NULL;
+ const iree_string_view_t* device_uris = NULL;
iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
device_uris);