Adding IREE_FLAG_LIST utility for repeated string flags. (#11806)
Currently only supports strings but we can extend it in the future with
some macro magic.
Many of the tools were using C++ to implement repeated flags for things
like `--function_input=` and now they can just use the common base C
APIs instead of needing extra C++ goo. Though passing giant inputs in
flags is never going to be great we now avoid the alloc+copy for each
flag that was previously required to do this.
diff --git a/runtime/src/iree/base/internal/flags.c b/runtime/src/iree/base/internal/flags.c
index 0ec1b29..86ffcef 100644
--- a/runtime/src/iree/base/internal/flags.c
+++ b/runtime/src/iree/base/internal/flags.c
@@ -147,6 +147,60 @@
}
//===----------------------------------------------------------------------===//
+// List parsing/printing
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_flag_string_list_parse(iree_string_view_t flag_name,
+ void* storage,
+ iree_string_view_t value) {
+ iree_flag_string_list_storage_t* flag =
+ (iree_flag_string_list_storage_t*)storage;
+ if (flag->count == 0) {
+ // Inline storage (common case).
+ flag->count = 1;
+ flag->inline_value = value;
+ } else if (flag->count == 1) {
+ // Switching from inline storage to external storage.
+ iree_host_size_t new_capacity = 4;
+ iree_string_view_t* values = NULL;
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+ iree_allocator_system(), sizeof(iree_string_view_t*) * new_capacity,
+ (void**)&values));
+ values[0] = flag->inline_value;
+ flag->capacity = new_capacity;
+ flag->values = values;
+ flag->values[flag->count++] = value;
+ } else {
+ // Growing external storage list.
+ iree_host_size_t new_capacity = iree_max(4, flag->capacity * 2);
+ IREE_RETURN_IF_ERROR(iree_allocator_realloc(
+ iree_allocator_system(), sizeof(iree_string_view_t*) * new_capacity,
+ (void**)&flag->values));
+ flag->capacity = new_capacity;
+ flag->values[flag->count++] = value;
+ }
+ return iree_ok_status();
+}
+
+void iree_flag_string_list_print(iree_string_view_t flag_name, void* storage,
+ FILE* file) {
+ iree_flag_string_list_storage_t* flag =
+ (iree_flag_string_list_storage_t*)storage;
+ if (flag->count == 0) {
+ fprintf(file, "# --%.*s=...\n", (int)flag_name.size, flag_name.data);
+ } else if (flag->count == 1) {
+ fprintf(file, "--%.*s=%.*s\n", (int)flag_name.size, flag_name.data,
+ (int)flag->inline_value.size, flag->inline_value.data);
+ } else {
+ for (iree_host_size_t i = 0; i < flag->count; ++i) {
+ const iree_string_view_t value = flag->values[i];
+ fprintf(file, "--%.*s=%.*s\n", (int)flag_name.size, flag_name.data,
+ (int)value.size, value.data);
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
// Flag parsing/printing
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/base/internal/flags.h b/runtime/src/iree/base/internal/flags.h
index 213c1f3..cae4ef8 100644
--- a/runtime/src/iree/base/internal/flags.h
+++ b/runtime/src/iree/base/internal/flags.h
@@ -220,6 +220,75 @@
#endif // IREE_FLAGS_ENABLE_CLI
//===----------------------------------------------------------------------===//
+// List flag utilities
+//===----------------------------------------------------------------------===//
+
+// A list of string views referencing flag storage.
+typedef struct iree_flag_string_list_t {
+ // Total number of values in the list.
+ iree_host_size_t count;
+ // Value list or NULL if no values.
+ const iree_string_view_t* values;
+} iree_flag_string_list_t;
+
+#if IREE_FLAGS_ENABLE_CLI == 1
+
+// Internal storage; do not use.
+typedef struct iree_flag_string_list_storage_t {
+ iree_host_size_t capacity;
+ iree_host_size_t count;
+ union {
+ iree_string_view_t inline_value; // only if count == 1
+ iree_string_view_t* values; // only if count > 1
+ };
+} iree_flag_string_list_storage_t;
+iree_status_t iree_flag_string_list_parse(iree_string_view_t flag_name,
+ void* storage,
+ iree_string_view_t value);
+void iree_flag_string_list_print(iree_string_view_t flag_name, void* storage,
+ FILE* file);
+
+// Defines a repeated flag representing a dynamically sized list of values.
+//
+// Usage:
+// IREE_FLAG_LIST(string, foo, "hello");
+// ...
+// const iree_flag_string_list_t list = FLAG_foo_list();
+// for (iree_host_size_t i = 0; i < list.count; ++i) {
+// printf("value: %.*s", (int)list.values[i].size, list.values[i].data);
+// }
+// ...
+// ./binary --foo=a --foo=b
+// > value: a
+// > value: b
+#define IREE_FLAG_LIST(type, name, description) \
+ static iree_flag_##type##_list_storage_t FLAG_##name##_storage = { \
+ /*.capacity=*/1 /* inline by default */, \
+ /*.count=*/0, \
+ }; \
+ IREE_FLAG_CALLBACK(iree_flag_##type##_list_parse, \
+ iree_flag_##type##_list_print, &FLAG_##name##_storage, \
+ name, description); \
+ static const iree_flag_##type##_list_t FLAG_##name##_list(void) { \
+ const iree_flag_##type##_list_t list = { \
+ /*.count=*/FLAG_##name##_storage.count, \
+ /*.values=*/FLAG_##name##_storage.count == 1 \
+ ? &FLAG_##name##_storage.inline_value \
+ : FLAG_##name##_storage.values, \
+ }; \
+ return list; \
+ }
+
+#else
+
+#define IREE_FLAG_LIST(type, name, description) \
+ static const iree_flag_##type##_list_t FLAG_##name##_list(void) { \
+ return (iree_flag_##type##_list_t){0, NULL}; \
+ }
+
+#endif // IREE_FLAGS_ENABLE_CLI
+
+//===----------------------------------------------------------------------===//
// Flag parsing
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/base/internal/flags_demo.c b/runtime/src/iree/base/internal/flags_demo.c
index 82f1213..f199a4c 100644
--- a/runtime/src/iree/base/internal/flags_demo.c
+++ b/runtime/src/iree/base/internal/flags_demo.c
@@ -38,6 +38,8 @@
IREE_FLAG_CALLBACK(parse_callback, print_callback, &callback_count,
test_callback, "Callback!");
+IREE_FLAG_LIST(string, test_strings, "repeated");
+
int main(int argc, char** argv) {
// Parse flags, updating argc/argv with position arguments.
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
@@ -50,6 +52,14 @@
printf("FLAG[test_string] = %s\n", FLAG_test_string);
printf("FLAG[test_callback] = %d\n", callback_count);
+ iree_flag_string_list_t strings = FLAG_test_strings_list();
+ printf("FLAG[test_strings] = %" PRIhsz ": ", strings.count);
+ for (iree_host_size_t i = 0; i < strings.count; ++i) {
+ if (i > 0) printf(", ");
+ printf("%.*s", (int)strings.values[i].size, strings.values[i].data);
+ }
+ printf("\n");
+
// Report positional arguments:
for (int i = 0; i < argc; ++i) {
printf("ARG(%d) = %s\n", i, argv[i]);
diff --git a/runtime/src/iree/base/internal/flags_test.txt b/runtime/src/iree/base/internal/flags_test.txt
index 74f14ed..d292d5b 100644
--- a/runtime/src/iree/base/internal/flags_test.txt
+++ b/runtime/src/iree/base/internal/flags_test.txt
@@ -68,6 +68,13 @@
// RUN: ( flags_demo --test_callback=FORCE_FAILURE 2>&1 || [[ $? == 1 ]] ) | FileCheck --check-prefix=FLAG-CALLBACK-ERROR %s
// FLAG-CALLBACK-ERROR: INTERNAL; callbacks can do verification
+// RUN: ( flags_demo ) | FileCheck --check-prefix=FLAG-LIST-0 %s
+// FLAG-LIST-0: FLAG[test_strings] = 0
+// RUN: ( flags_demo --test_strings=a ) | FileCheck --check-prefix=FLAG-LIST-1 %s
+// FLAG-LIST-1: FLAG[test_strings] = 1: a
+// RUN: ( flags_demo --test_strings=a --test_strings=b ) | FileCheck --check-prefix=FLAG-LIST-2 %s
+// FLAG-LIST-2: FLAG[test_strings] = 2: a, b
+
// RUN: ( flags_demo arg1 ) | FileCheck --check-prefix=FLAG-POSITIONAL-1 %s
// FLAG-POSITIONAL-1: ARG(1) = arg1
// RUN: ( flags_demo arg1 arg2 arg3 ) | FileCheck --check-prefix=FLAG-POSITIONAL-3 %s
diff --git a/runtime/src/iree/tooling/BUILD b/runtime/src/iree/tooling/BUILD
index 374a63e..21a752e 100644
--- a/runtime/src/iree/tooling/BUILD
+++ b/runtime/src/iree/tooling/BUILD
@@ -60,8 +60,9 @@
srcs = ["comparison_test.cc"],
deps = [
":comparison",
- ":vm_util_cc",
+ ":vm_util",
"//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:span",
"//runtime/src/iree/hal",
"//runtime/src/iree/modules/hal",
"//runtime/src/iree/testing:gtest",
@@ -147,27 +148,14 @@
],
)
-cc_library(
- name = "vm_util_cc",
- srcs = ["vm_util_cc.cc"],
- hdrs = ["vm_util_cc.h"],
- deps = [
- ":vm_util",
- "//runtime/src/iree/base",
- "//runtime/src/iree/base:tracing",
- "//runtime/src/iree/base/internal:span",
- "//runtime/src/iree/hal",
- "//runtime/src/iree/vm",
- ],
-)
-
cc_test(
name = "vm_util_test",
srcs = ["vm_util_test.cc"],
deps = [
":device_util",
- ":vm_util_cc",
+ ":vm_util",
"//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal:span",
"//runtime/src/iree/hal",
"//runtime/src/iree/modules/hal",
"//runtime/src/iree/testing:gtest",
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index cf99a51..c030f22 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -67,8 +67,9 @@
"comparison_test.cc"
DEPS
::comparison
- ::vm_util_cc
+ ::vm_util
iree::base
+ iree::base::internal::span
iree::hal
iree::modules::hal
iree::testing::gtest
@@ -166,23 +167,6 @@
PUBLIC
)
-iree_cc_library(
- NAME
- vm_util_cc
- HDRS
- "vm_util_cc.h"
- SRCS
- "vm_util_cc.cc"
- DEPS
- ::vm_util
- iree::base
- iree::base::internal::span
- iree::base::tracing
- iree::hal
- iree::vm
- PUBLIC
-)
-
iree_cc_test(
NAME
vm_util_test
@@ -190,8 +174,9 @@
"vm_util_test.cc"
DEPS
::device_util
- ::vm_util_cc
+ ::vm_util
iree::base
+ iree::base::internal::span
iree::hal
iree::modules::hal
iree::testing::gtest
diff --git a/runtime/src/iree/tooling/comparison_test.cc b/runtime/src/iree/tooling/comparison_test.cc
index 5fcc966..b8e55b8 100644
--- a/runtime/src/iree/tooling/comparison_test.cc
+++ b/runtime/src/iree/tooling/comparison_test.cc
@@ -7,11 +7,12 @@
#include "iree/tooling/comparison.h"
#include "iree/base/api.h"
+#include "iree/base/internal/span.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
namespace iree {
@@ -19,6 +20,20 @@
using ::testing::HasSubstr;
+static void ParseToVariantList(iree_hal_allocator_t* device_allocator,
+ iree::span<const std::string> input_strings,
+ iree_allocator_t host_allocator,
+ iree_vm_list_t** out_list) {
+ std::vector<iree_string_view_t> input_string_views(input_strings.size());
+ for (size_t i = 0; i < input_strings.size(); ++i) {
+ input_string_views[i].data = input_strings[i].data();
+ input_string_views[i].size = input_strings[i].size();
+ }
+ IREE_CHECK_OK(iree_tooling_parse_to_variant_list(
+ device_allocator, input_string_views.data(), input_string_views.size(),
+ host_allocator, out_list));
+}
+
class ComparisonTest : public ::testing::Test {
protected:
virtual void SetUp() {
@@ -37,12 +52,13 @@
iree::span<const std::string> expected_strings,
iree::span<const std::string> actual_strings, std::string* out_string) {
vm::ref<iree_vm_list_t> expected_list;
- IREE_CHECK_OK(ParseToVariantList(device_allocator_, expected_strings,
- host_allocator_, &expected_list));
+ ParseToVariantList(device_allocator_, expected_strings, host_allocator_,
+ &expected_list);
vm::ref<iree_vm_list_t> actual_list;
- IREE_CHECK_OK(ParseToVariantList(device_allocator_, actual_strings,
- host_allocator_, &actual_list));
+ ParseToVariantList(device_allocator_, actual_strings, host_allocator_,
+ &actual_list);
+
iree_string_builder_t builder;
iree_string_builder_initialize(host_allocator_, &builder);
bool all_match = iree_tooling_compare_variant_lists_and_append(
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index 3472da5..146f35e 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -284,96 +284,25 @@
// Device selection
//===----------------------------------------------------------------------===//
-// Common case is for 1 device, so to keep our memory profiles simpler we allow
-// for storing a single device inline in the struct. As soon as more than one
-// device is used we grow it.
-typedef struct iree_hal_devices_flag_t {
- iree_host_size_t capacity;
- iree_host_size_t count;
- union {
- iree_string_view_t inline_uri; // only if count == 1
- iree_string_view_t* uris; // only if count > 1
- };
-} iree_hal_devices_flag_t;
-iree_hal_devices_flag_t iree_hal_devices_flag = {
- .capacity = 1, // inline
- .count = 0,
- .uris = NULL,
-};
-
-static iree_status_t iree_hal_flags_parse_device(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- iree_hal_devices_flag_t* flag = (iree_hal_devices_flag_t*)storage;
- if (flag->count == 0) {
- // Inline storage (common case).
- flag->count = 1;
- flag->inline_uri = value;
- } else if (flag->count == 1) {
- // Switching from inline storage to external storage.
- iree_host_size_t new_capacity = 4;
- iree_string_view_t* uris = NULL;
- IREE_RETURN_IF_ERROR(iree_allocator_malloc(
- iree_allocator_system(), sizeof(iree_string_view_t*) * new_capacity,
- (void**)&uris));
- uris[0] = flag->inline_uri;
- flag->capacity = new_capacity;
- flag->uris = uris;
- flag->uris[flag->count++] = value;
- } else {
- // Growing external storage list.
- iree_host_size_t new_capacity = iree_max(4, flag->capacity * 2);
- IREE_RETURN_IF_ERROR(iree_allocator_realloc(
- iree_allocator_system(), sizeof(iree_string_view_t*) * new_capacity,
- (void**)&flag->uris));
- flag->capacity = new_capacity;
- flag->uris[flag->count++] = value;
- }
- return iree_ok_status();
-}
-
-static void iree_hal_flags_print_devices(iree_string_view_t flag_name,
- void* storage, FILE* file) {
- iree_hal_devices_flag_t* flag = (iree_hal_devices_flag_t*)storage;
- if (flag->count == 0) {
- fprintf(file, "# --%.*s=driver://path?params\n", (int)flag_name.size,
- flag_name.data);
- } else if (flag->count == 1) {
- fprintf(file, "--%.*s=%.*s\n", (int)flag_name.size, flag_name.data,
- (int)flag->inline_uri.size, flag->inline_uri.data);
- } else {
- for (iree_host_size_t i = 0; i < flag->count; ++i) {
- const iree_string_view_t device_uri = flag->uris[i];
- fprintf(file, "--%.*s=%.*s\n", (int)flag_name.size, flag_name.data,
- (int)device_uri.size, device_uri.data);
- }
- }
-}
-
-IREE_FLAG_CALLBACK(
- iree_hal_flags_parse_device, iree_hal_flags_print_devices,
- &iree_hal_devices_flag, device,
+IREE_FLAG_LIST(
+ string, device,
"Specifies one or more HAL devices to use for execution.\n"
"Use --list_devices/--dump_devices to see available devices and their\n"
"canonical URI used with this flag.");
// TODO(#5724): remove this and replace with an iree_hal_device_set_t.
void iree_hal_get_devices_flag_list(iree_host_size_t* out_count,
- iree_string_view_t** out_list) {
- *out_count = iree_hal_devices_flag.count;
- if (iree_hal_devices_flag.count == 1) {
- *out_list = &iree_hal_devices_flag.inline_uri;
- } else {
- *out_list = iree_hal_devices_flag.uris;
- }
+ const iree_string_view_t** out_list) {
+ *out_count = FLAG_device_list().count;
+ *out_list = FLAG_device_list().values;
}
iree_status_t iree_hal_create_device_from_flags(
iree_string_view_t default_device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
iree_string_view_t device_uri = default_device;
- const iree_hal_devices_flag_t* flag = &iree_hal_devices_flag;
- if (flag->count == 0) {
+ const iree_flag_string_list_t list = FLAG_device_list();
+ if (list.count == 0) {
// No devices specified. Use default if provided.
if (iree_string_view_is_empty(default_device)) {
return iree_make_status(
@@ -381,17 +310,28 @@
"no device specified; use --list_devices to see the "
"available devices and specify one with --device=");
}
- } else if (flag->count > 1) {
+ } else if (list.count > 1) {
// Too many devices for the single device creation function.
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"too many devices specified; only one --device= "
"flag may be provided with this API");
} else {
// Exactly one device specified.
- device_uri = flag->inline_uri;
+ device_uri = list.values[0];
}
- return iree_hal_create_device(iree_hal_available_driver_registry(),
- device_uri, host_allocator, out_device);
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Create the device, which may be slow and dynamically load big dependencies
+ // (CUDA, Vulkan, etc).
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_create_device(iree_hal_available_driver_registry(),
+ device_uri, host_allocator, &device));
+
+ *out_device = device;
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/tooling/device_util.h b/runtime/src/iree/tooling/device_util.h
index 2891547..915bb4b 100644
--- a/runtime/src/iree/tooling/device_util.h
+++ b/runtime/src/iree/tooling/device_util.h
@@ -24,7 +24,7 @@
// TODO(#5724): remove this and replace with an iree_hal_device_set_t.
void iree_hal_get_devices_flag_list(iree_host_size_t* out_count,
- iree_string_view_t** out_list);
+ const iree_string_view_t** out_list);
// Creates a single device from the --device= flag.
// Uses the |default_device| if no flags were specified.
diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c
index 32b454b..506836a 100644
--- a/runtime/src/iree/tooling/vm_util.c
+++ b/runtime/src/iree/tooling/vm_util.c
@@ -146,7 +146,8 @@
}
iree_status_t iree_tooling_parse_to_variant_list(
- iree_hal_allocator_t* device_allocator, iree_string_view_t* input_strings,
+ iree_hal_allocator_t* device_allocator,
+ const iree_string_view_t* input_strings,
iree_host_size_t input_strings_count, iree_allocator_t host_allocator,
iree_vm_list_t** out_list) {
IREE_TRACE_ZONE_BEGIN(z0);
@@ -161,8 +162,7 @@
iree_status_t status = iree_ok_status();
for (size_t i = 0; i < input_strings_count; ++i) {
if (!iree_status_is_ok(status)) break;
- iree_string_view_t input_view = iree_string_view_trim(
- iree_make_string_view(input_strings[i].data, input_strings[i].size));
+ iree_string_view_t input_view = iree_string_view_trim(input_strings[i]);
if (iree_string_view_consume_prefix(&input_view, IREE_SV("@"))) {
status = iree_tooling_load_ndarrays_from_file(
input_view, device_allocator, variant_list);
diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h
index 8264254..ab51cfe 100644
--- a/runtime/src/iree/tooling/vm_util.h
+++ b/runtime/src/iree/tooling/vm_util.h
@@ -27,7 +27,8 @@
// Uses |device_allocator| to allocate the buffers.
// The returned variant list must be freed by the caller.
iree_status_t iree_tooling_parse_to_variant_list(
- iree_hal_allocator_t* device_allocator, iree_string_view_t* input_strings,
+ iree_hal_allocator_t* device_allocator,
+ const iree_string_view_t* input_strings,
iree_host_size_t input_strings_count, iree_allocator_t host_allocator,
iree_vm_list_t** out_list);
diff --git a/runtime/src/iree/tooling/vm_util_cc.cc b/runtime/src/iree/tooling/vm_util_cc.cc
deleted file mode 100644
index 57c020e..0000000
--- a/runtime/src/iree/tooling/vm_util_cc.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/tooling/vm_util_cc.h"
-
-#include <vector>
-
-#include "iree/vm/api.h"
-
-namespace iree {
-
-Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
- iree::span<const std::string> input_strings,
- iree_allocator_t host_allocator,
- iree_vm_list_t** out_list) {
- std::vector<iree_string_view_t> input_string_views(input_strings.size());
- for (size_t i = 0; i < input_strings.size(); ++i) {
- input_string_views[i].data = input_strings[i].data();
- input_string_views[i].size = input_strings[i].size();
- }
- return iree_tooling_parse_to_variant_list(
- device_allocator, input_string_views.data(), input_string_views.size(),
- host_allocator, out_list);
-}
-
-Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count,
- std::string* out_string) {
- iree_string_builder_t builder;
- iree_string_builder_initialize(iree_allocator_system(), &builder);
- IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines(
- variant_list, max_element_count, &builder));
- out_string->assign(iree_string_builder_buffer(&builder),
- iree_string_builder_size(&builder));
- iree_string_builder_deinitialize(&builder);
- return iree_ok_status();
-}
-
-} // namespace iree
diff --git a/runtime/src/iree/tooling/vm_util_cc.h b/runtime/src/iree/tooling/vm_util_cc.h
deleted file mode 100644
index fce5099..0000000
--- a/runtime/src/iree/tooling/vm_util_cc.h
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_TOOLING_VM_UTIL_CC_H_
-#define IREE_TOOLING_VM_UTIL_CC_H_
-
-#include <string>
-#include <vector>
-
-#include "iree/base/api.h"
-#include "iree/base/internal/span.h"
-#include "iree/hal/api.h"
-#include "iree/tooling/vm_util.h"
-#include "iree/vm/api.h"
-
-namespace iree {
-
-// NOTE: this file is not best-practice and needs to be rewritten; consider this
-// appropriate only for test code.
-
-// Parses |input_strings| into a variant list of VM scalars and buffers.
-// Scalars should be in the format:
-// type=value
-// Buffers should be in the IREE standard shaped buffer format:
-// [shape]xtype=[value]
-// described in iree/hal/api.h
-// Uses |device_allocator| to allocate the buffers.
-// The returned variant list must be freed by the caller.
-Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
- iree::span<const std::string> input_strings,
- iree_allocator_t host_allocator,
- iree_vm_list_t** out_list);
-
-// Prints a variant list to |out_string|.
-Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count,
- std::string* out_string);
-
-inline Status PrintVariantList(iree_vm_list_t* variant_list,
- std::string* out_string) {
- return PrintVariantList(variant_list, 1024, out_string);
-}
-
-// Prints a variant list to stdout.
-inline Status PrintVariantList(iree_vm_list_t* variant_list,
- size_t max_element_count = 1024) {
- return iree_tooling_variant_list_fprint(variant_list, max_element_count,
- stdout);
-}
-
-} // namespace iree
-
-#endif // IREE_TOOLING_VM_UTIL_CC_H_
diff --git a/runtime/src/iree/tooling/vm_util_test.cc b/runtime/src/iree/tooling/vm_util_test.cc
index 41cc3b2..5d33a68 100644
--- a/runtime/src/iree/tooling/vm_util_test.cc
+++ b/runtime/src/iree/tooling/vm_util_test.cc
@@ -4,18 +4,46 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/tooling/vm_util.h"
+
#include "iree/base/api.h"
+#include "iree/base/internal/span.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
#include "iree/vm/api.h"
namespace iree {
namespace {
+static Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
+ iree::span<const std::string> input_strings,
+ iree_allocator_t host_allocator,
+ iree_vm_list_t** out_list) {
+ std::vector<iree_string_view_t> input_string_views(input_strings.size());
+ for (size_t i = 0; i < input_strings.size(); ++i) {
+ input_string_views[i].data = input_strings[i].data();
+ input_string_views[i].size = input_strings[i].size();
+ }
+ return iree_tooling_parse_to_variant_list(
+ device_allocator, input_string_views.data(), input_string_views.size(),
+ host_allocator, out_list);
+}
+
+static Status PrintVariantList(iree_vm_list_t* variant_list,
+ std::string* out_string) {
+ iree_string_builder_t builder;
+ iree_string_builder_initialize(iree_allocator_system(), &builder);
+ IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines(
+ variant_list, /*max_element_count=*/1024, &builder));
+ out_string->assign(iree_string_builder_buffer(&builder),
+ iree_string_builder_size(&builder));
+ iree_string_builder_deinitialize(&builder);
+ return iree_ok_status();
+}
+
class VmUtilTest : public ::testing::Test {
protected:
virtual void SetUp() {
diff --git a/tools/BUILD b/tools/BUILD
index 2fd3a71..e4af67c 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -33,7 +33,7 @@
"//runtime/src/iree/modules/hal:types",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
"@com_google_benchmark//:benchmark",
],
@@ -73,7 +73,7 @@
"//runtime/src/iree/testing:gtest",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
"//runtime/src/iree/vm:bytecode_module",
],
@@ -129,7 +129,7 @@
"//runtime/src/iree/modules/hal:types",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
"//runtime/src/iree/vm:bytecode_module",
"@llvm-project//llvm:Support",
@@ -154,7 +154,7 @@
"//runtime/src/iree/tooling:comparison",
"//runtime/src/iree/tooling:context_util",
"//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/tooling:vm_util_cc",
+ "//runtime/src/iree/tooling:vm_util",
"//runtime/src/iree/vm",
],
)
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index e94fa4e..82f188e 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -53,7 +53,7 @@
iree::modules::hal::types
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
)
@@ -94,7 +94,7 @@
iree::testing::gtest
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
iree::vm::bytecode_module
TESTONLY
@@ -128,7 +128,7 @@
iree::tooling::comparison
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
)
@@ -254,7 +254,7 @@
iree::modules::hal::types
iree::tooling::context_util
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
iree::vm::bytecode_module
DATA
diff --git a/tools/android/run_module_app/CMakeLists.txt b/tools/android/run_module_app/CMakeLists.txt
index dfb229d..4d36e80 100644
--- a/tools/android/run_module_app/CMakeLists.txt
+++ b/tools/android/run_module_app/CMakeLists.txt
@@ -34,7 +34,7 @@
iree::base
iree::modules::hal
iree::tooling::device_util
- iree::tooling::vm_util_cc
+ iree::tooling::vm_util
iree::vm
LINKOPTS
"-landroid"
diff --git a/tools/android/run_module_app/src/main.cc b/tools/android/run_module_app/src/main.cc
index dfdb135..193595e 100644
--- a/tools/android/run_module_app/src/main.cc
+++ b/tools/android/run_module_app/src/main.cc
@@ -12,11 +12,12 @@
#include <sstream>
#include <string>
#include <thread>
+#include <vector>
#include "iree/base/api.h"
#include "iree/modules/hal/module.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -130,17 +131,18 @@
&function),
"looking up function '%s'", function_name.c_str());
- size_t pos = 0;
- std::string inputs_str = invocation.inputs;
- std::vector<std::string> input_views;
- while ((pos = inputs_str.find('\n')) != std::string::npos) {
- input_views.push_back(inputs_str.substr(0, pos));
- inputs_str.erase(0, pos + 1);
+ std::vector<iree_string_view_t> input_views;
+ iree_string_view_t inputs_view =
+ iree_make_string_view(invocation.inputs.data(), invocation.inputs.size());
+ while (!iree_string_view_is_empty(inputs_view)) {
+ iree_string_view_t input_view = iree_string_view_empty();
+ iree_string_view_split(inputs_view, '\n', &input_view, &inputs_view);
+ input_views.push_back(input_view);
}
vm::ref<iree_vm_list_t> inputs;
- IREE_RETURN_IF_ERROR(ParseToVariantList(iree_hal_device_allocator(device),
- input_views, iree_allocator_system(),
- &inputs));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ iree_hal_device_allocator(device), input_views.data(), input_views.size(),
+ iree_allocator_system(), &inputs));
vm::ref<iree_vm_list_t> outputs;
IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr, 16,
@@ -153,11 +155,16 @@
iree_allocator_system()),
"invoking function '%s'", function_name.c_str());
- std::string result;
- IREE_RETURN_IF_ERROR(PrintVariantList(outputs.get(), &result),
- "printing results");
+ iree_string_builder_t result_str;
+ iree_string_builder_initialize(iree_allocator_system(), &result_str);
+ IREE_RETURN_IF_ERROR(
+ iree_tooling_append_variant_list_lines(
+ outputs.get(), /*max_element_count=*/1024, &result_str),
+ "printing results");
LOGI("Execution Result:");
- LOGI("%s", result.c_str());
+ LOGI("%.*s", (int)iree_string_builder_size(&result_str),
+ iree_string_builder_buffer(&result_str));
+ iree_string_builder_deinitialize(&result_str);
inputs.reset();
outputs.reset();
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index a7fc39b..e67bd5c 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -68,7 +68,7 @@
#include "iree/modules/hal/types.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
constexpr char kNanosecondsUnitString[] = "ns";
@@ -92,30 +92,8 @@
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
-// TODO(benvanik): move --function_input= flag into a util.
-static iree_status_t parse_function_input(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- auto* list = (std::vector<std::string>*)storage;
- list->push_back(std::string(value.data, value.size));
- return iree_ok_status();
-}
-static void print_function_input(iree_string_view_t flag_name, void* storage,
- FILE* file) {
- auto* list = (std::vector<std::string>*)storage;
- if (list->empty()) {
- fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
- } else {
- for (size_t i = 0; i < list->size(); ++i) {
- fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
- list->at(i).c_str());
- }
- }
-}
-static std::vector<std::string> FLAG_function_inputs;
-IREE_FLAG_CALLBACK(
- parse_function_input, print_function_input, &FLAG_function_inputs,
- function_input,
+IREE_FLAG_LIST(
+ string, function_input,
"An input value or buffer of the format:\n"
" [shape]xtype=[value]\n"
" 2x2xi32=1 2 3 4\n"
@@ -499,10 +477,9 @@
iree_string_view_t{function_name.data(), function_name.size()},
&function));
- IREE_CHECK_OK(ParseToVariantList(
- device_allocator_.get(),
- iree::span<const std::string>{FLAG_function_inputs.data(),
- FLAG_function_inputs.size()},
+ IREE_CHECK_OK(iree_tooling_parse_to_variant_list(
+ device_allocator_.get(), FLAG_function_input_list().values,
+ FLAG_function_input_list().count,
iree_vm_instance_allocator(instance_.get()), &inputs_));
iree_string_view_t invocation_model = iree_vm_function_lookup_attr_by_name(
diff --git a/tools/iree-benchmark-trace-main.c b/tools/iree-benchmark-trace-main.c
index 266df43..1a34580 100644
--- a/tools/iree-benchmark-trace-main.c
+++ b/tools/iree-benchmark-trace-main.c
@@ -206,7 +206,7 @@
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
iree_host_size_t device_uri_count = 0;
- iree_string_view_t* device_uris = NULL;
+ const iree_string_view_t* device_uris = NULL;
iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
device_uris);
diff --git a/tools/iree-check-module-main.cc b/tools/iree-check-module-main.cc
index 38f53f6..2d5ba6a 100644
--- a/tools/iree-check-module-main.cc
+++ b/tools/iree-check-module-main.cc
@@ -22,7 +22,7 @@
#include "iree/testing/status_matchers.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index a00f1ef..217d4e6 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -1165,7 +1165,7 @@
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
iree_host_size_t device_uri_count = 0;
- iree_string_view_t* device_uris = NULL;
+ const iree_string_view_t* device_uris = NULL;
iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
device_uris);
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index ba7ffee..aaee8ca 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -52,7 +52,7 @@
#include "iree/modules/hal/types.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "llvm/ADT/STLExtras.h"
@@ -138,30 +138,8 @@
llvm::cl::ConsumeAfter,
};
-// TODO(benvanik): move --function_input= flag into a util.
-static iree_status_t parse_function_input(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- auto* list = (std::vector<std::string>*)storage;
- list->push_back(std::string(value.data, value.size));
- return iree_ok_status();
-}
-static void print_function_input(iree_string_view_t flag_name, void* storage,
- FILE* file) {
- auto* list = (std::vector<std::string>*)storage;
- if (list->empty()) {
- fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
- } else {
- for (size_t i = 0; i < list->size(); ++i) {
- fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
- list->at(i).c_str());
- }
- }
-}
-static std::vector<std::string> FLAG_function_inputs;
-IREE_FLAG_CALLBACK(
- parse_function_input, print_function_input, &FLAG_function_inputs,
- function_input,
+IREE_FLAG_LIST(
+ string, function_input,
"An input value or buffer of the format:\n"
" [shape]xtype=[value]\n"
" 2x2xi32=1 2 3 4\n"
@@ -342,11 +320,9 @@
// Parse input values from the flags.
vm::ref<iree_vm_list_t> inputs;
- IREE_RETURN_IF_ERROR(ParseToVariantList(
- device_allocator,
- iree::span<const std::string>{FLAG_function_inputs.data(),
- FLAG_function_inputs.size()},
- host_allocator, &inputs));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ device_allocator, FLAG_function_input_list().values,
+ FLAG_function_input_list().count, host_allocator, &inputs));
// If the function is async add fences so we can invoke it synchronously.
vm::ref<iree_hal_fence_t> finish_fence;
@@ -370,7 +346,8 @@
}
// Print outputs.
- IREE_RETURN_IF_ERROR(PrintVariantList(outputs.get()));
+ IREE_RETURN_IF_ERROR(iree_tooling_variant_list_fprint(
+ outputs.get(), /*max_element_count=*/1024, stdout));
return OkStatus();
}
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index ccbca90..46cadda 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -20,7 +20,7 @@
#include "iree/tooling/comparison.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util_cc.h"
+#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
IREE_FLAG(string, entry_function, "",
@@ -34,29 +34,8 @@
IREE_FLAG(bool, print_statistics, false,
"Prints runtime statistics to stderr on exit.");
-// TODO(benvanik): move --function_input= flag into a util.
-static iree_status_t parse_function_io(iree_string_view_t flag_name,
- void* storage,
- iree_string_view_t value) {
- auto* list = (std::vector<std::string>*)storage;
- list->push_back(std::string(value.data, value.size));
- return iree_ok_status();
-}
-static void print_function_io(iree_string_view_t flag_name, void* storage,
- FILE* file) {
- auto* list = (std::vector<std::string>*)storage;
- if (list->empty()) {
- fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
- } else {
- for (size_t i = 0; i < list->size(); ++i) {
- fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
- list->at(i).c_str());
- }
- }
-}
-static std::vector<std::string> FLAG_function_inputs;
-IREE_FLAG_CALLBACK(
- parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
+IREE_FLAG_LIST(
+ string, function_input,
"An input (a) value or (b) buffer of the format:\n"
" (a) scalar value\n"
" value\n"
@@ -73,14 +52,12 @@
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
-static std::vector<std::string> FLAG_expected_outputs;
-IREE_FLAG_CALLBACK(parse_function_io, print_function_io, &FLAG_expected_outputs,
- expected_output,
- "An expected function output following the same format as "
- "--function_input. When present the results of the "
- "invocation will be compared against these values and the "
- "tool will return non-zero if any differ. If the value of a "
- "particular output is not of interest provide `(ignored)`.");
+IREE_FLAG_LIST(string, expected_output,
+ "An expected function output following the same format as "
+ "--function_input. When present the results of the "
+ "invocation will be compared against these values and the "
+ "tool will return non-zero if any differ. If the value of a "
+ "particular output is not of interest provide `(ignored)`.");
namespace iree {
namespace {
@@ -122,11 +99,9 @@
IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
vm::ref<iree_vm_list_t> inputs;
- IREE_RETURN_IF_ERROR(ParseToVariantList(
- device_allocator.get(),
- iree::span<const std::string>{FLAG_function_inputs.data(),
- FLAG_function_inputs.size()},
- host_allocator, &inputs));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ device_allocator.get(), FLAG_function_input_list().values,
+ FLAG_function_input_list().count, host_allocator, &inputs));
// If the function is async add fences so we can invoke it synchronously.
vm::ref<iree_hal_fence_t> finish_fence;
@@ -153,9 +128,10 @@
IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
- if (FLAG_expected_outputs.empty()) {
+ if (FLAG_expected_output_list().count == 0) {
IREE_RETURN_IF_ERROR(
- PrintVariantList(outputs.get(), (size_t)FLAG_print_max_element_count),
+ iree_tooling_variant_list_fprint(
+ outputs.get(), (size_t)FLAG_print_max_element_count, stdout),
"printing results");
} else {
// Parse expected list into host-local memory that we can easily access.
@@ -164,9 +140,9 @@
IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
IREE_SV("heap"), host_allocator, host_allocator, &heap_allocator));
vm::ref<iree_vm_list_t> expected_list;
- IREE_RETURN_IF_ERROR(ParseToVariantList(heap_allocator.get(),
- FLAG_expected_outputs,
- host_allocator, &expected_list));
+ IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
+ heap_allocator.get(), FLAG_expected_output_list().values,
+ FLAG_expected_output_list().count, host_allocator, &expected_list));
// Compare expected vs actual lists and output diffs.
bool did_match = iree_tooling_compare_variant_lists(
diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c
index 577e0b6..57a2053 100644
--- a/tools/iree-run-trace-main.c
+++ b/tools/iree-run-trace-main.c
@@ -38,7 +38,7 @@
// file will be used.
// TODO(#5724): remove this and instead provide a device set on initialize.
iree_host_size_t device_uri_count = 0;
- iree_string_view_t* device_uris = NULL;
+ const iree_string_view_t* device_uris = NULL;
iree_hal_get_devices_flag_list(&device_uri_count, &device_uris);
iree_trace_replay_set_hal_devices_override(&replay, device_uri_count,
device_uris);