Adding support for reading binary files in tool --function_input=. (#9328)
Raw binary data read from a file can now be used to initialize buffer
view contents. No interpretation is done on the data. The shape and
element type information is still required.
Example:
```
iree-benchmark-module ... --function_input=4x2xi32=@some/file.bin
```
diff --git a/runtime/src/iree/hal/string_util.c b/runtime/src/iree/hal/string_util.c
index 6e794fa..d9a662c 100644
--- a/runtime/src/iree/hal/string_util.c
+++ b/runtime/src/iree/hal/string_util.c
@@ -183,6 +183,54 @@
: iree_ok_status();
}
+IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
+ iree_string_view_t value, iree_host_size_t shape_capacity,
+ iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank,
+ iree_hal_element_type_t* out_element_type) {
+ *out_shape_rank = 0;
+ *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+
+ // Strip surrounding whitespace (linefeeds/etc) and a leading/trailing '"'.
+ value = iree_string_view_trim(value);
+ value = iree_string_view_strip_prefix(value, IREE_SV("\""));
+ value = iree_string_view_strip_suffix(value, IREE_SV("\""));
+ if (iree_string_view_is_empty(value)) {
+ // Empty lines are invalid; need at least the shape/type information.
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "empty string input");
+ }
+
+ // The part of the string corresponding to the shape, e.g. 1x2x3.
+ iree_string_view_t shape_str = iree_string_view_empty();
+ // The part of the string corresponding to the type, e.g. f32
+ iree_string_view_t type_str = iree_string_view_empty();
+ // The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6
+ // (whatever follows the '='). It is split off below and otherwise ignored.
+ iree_string_view_t data_str = iree_string_view_empty();
+
+ iree_string_view_t shape_and_type_str = value;
+ iree_string_view_split(value, '=', &shape_and_type_str, &data_str);
+ iree_host_size_t last_x_index = iree_string_view_find_last_of(
+ shape_and_type_str, IREE_SV("x"), IREE_STRING_VIEW_NPOS);
+ if (last_x_index == IREE_STRING_VIEW_NPOS) {
+ // Scalar: no 'x' separator, so the whole string is the element type.
+ type_str = shape_and_type_str;
+ } else {
+ // Has a shape: dims precede the last 'x' and the type follows it.
+ shape_str = iree_string_view_substr(shape_and_type_str, 0, last_x_index);
+ type_str = iree_string_view_substr(shape_and_type_str, last_x_index + 1,
+ IREE_STRING_VIEW_NPOS);
+ }
+
+ // AxBxC... (left empty for scalars)
+ IREE_RETURN_IF_ERROR(iree_hal_parse_shape(shape_str, shape_capacity,
+ out_shape, out_shape_rank));
+
+ // f32, i32, etc
+ IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, out_element_type));
+
+ return iree_ok_status();
+}
+
// Parses a string of two character pairs representing hex numbers into bytes.
static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to,
ptrdiff_t num) {
diff --git a/runtime/src/iree/hal/string_util.h b/runtime/src/iree/hal/string_util.h
index 3e8b1bf..4750339 100644
--- a/runtime/src/iree/hal/string_util.h
+++ b/runtime/src/iree/hal/string_util.h
@@ -48,6 +48,14 @@
iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity,
char* buffer, iree_host_size_t* out_buffer_length);
+// Parses a shape and type from a `[shape]x[type]` string |value|.
+// Behaves the same as calling iree_hal_parse_shape and
+// iree_hal_parse_element_type. Ignores a trailing `=` and anything after it.
+IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
+ iree_string_view_t value, iree_host_size_t shape_capacity,
+ iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank,
+ iree_hal_element_type_t* out_element_type);
+
// Parses a serialized element of |element_type| to its in-memory form.
// |data_ptr| must be at least large enough to contain the bytes of the element.
// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
diff --git a/runtime/src/iree/hal/string_util_test.cc b/runtime/src/iree/hal/string_util_test.cc
index 453f99c..b8034f0 100644
--- a/runtime/src/iree/hal/string_util_test.cc
+++ b/runtime/src/iree/hal/string_util_test.cc
@@ -35,7 +35,7 @@
StatusOr<Shape> ParseShape(const std::string& value) {
Shape shape(6);
iree_host_size_t actual_rank = 0;
- iree_status_t status;
+ iree_status_t status = iree_ok_status();
do {
status =
iree_hal_parse_shape(iree_string_view_t{value.data(), value.size()},
@@ -50,7 +50,7 @@
StatusOr<std::string> FormatShape(iree::span<const iree_hal_dim_t> value) {
std::string buffer(16, '\0');
iree_host_size_t actual_length = 0;
- iree_status_t status;
+ iree_status_t status = iree_ok_status();
do {
status =
iree_hal_format_shape(value.data(), value.size(), buffer.size() + 1,
@@ -77,7 +77,7 @@
StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
std::string buffer(16, '\0');
iree_host_size_t actual_length = 0;
- iree_status_t status;
+ iree_status_t status = iree_ok_status();
do {
status = iree_hal_format_element_type(value, buffer.size() + 1, &buffer[0],
&actual_length);
@@ -87,6 +87,34 @@
return std::move(buffer);
}
+struct ShapeAndType {  // Shape/type pair; compared memberwise by operator==.
+ Shape shape;
+ iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+ ShapeAndType() = default;
+ ShapeAndType(Shape shape, iree_hal_element_type_t element_type)
+ : shape(std::move(shape)), element_type(element_type) {}
+};
+static bool operator==(const ShapeAndType& lhs,
+ const ShapeAndType& rhs) noexcept {
+ return lhs.shape == rhs.shape && lhs.element_type == rhs.element_type;
+}
+
+// Parses a serialized shape and element type, retrying while out-of-range.
+StatusOr<ShapeAndType> ParseShapeAndElementType(const std::string& value) {
+ Shape shape(6);  // Initial capacity; resized to the reported rank below.
+ iree_host_size_t actual_rank = 0;
+ iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+ iree_status_t status = iree_ok_status();
+ do {
+ status = iree_hal_parse_shape_and_element_type(
+ iree_string_view_t{value.data(), value.size()}, shape.size(),
+ shape.data(), &actual_rank, &element_type);
+ shape.resize(actual_rank);
+ } while (iree_status_is_out_of_range(status));
+ IREE_RETURN_IF_ERROR(std::move(status));
+ return ShapeAndType(std::move(shape), element_type);
+}
+
// Parses a serialized element of |element_type| to its in-memory form.
// |buffer| be at least large enough to contain the bytes of the element.
// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
@@ -593,6 +621,50 @@
IsOkAndHolds(Eq("f4")));
}
+TEST(StringUtilTest, ParseShapeAndElementType) {
+ EXPECT_THAT(
+ ParseShapeAndElementType("1xi8"),
+ IsOkAndHolds(Eq(ShapeAndType(Shape{1}, IREE_HAL_ELEMENT_TYPE_INT_8))));
+ EXPECT_THAT(ParseShapeAndElementType("1x2xi16"),
+ IsOkAndHolds(
+ Eq(ShapeAndType(Shape{1, 2}, IREE_HAL_ELEMENT_TYPE_INT_16))));
+ EXPECT_THAT(  // Rank above the helper's initial 6; data after '=' ignored.
+ ParseShapeAndElementType("1x2x3x4x5x6x7x8x9xi32=invalid stuff here"),
+ IsOkAndHolds(Eq(ShapeAndType(Shape{1, 2, 3, 4, 5, 6, 7, 8, 9},
+ IREE_HAL_ELEMENT_TYPE_INT_32))));
+}
+
+TEST(StringUtilTest, ParseShapeAndElementTypeInvalid) {  // Each must fail.
+ EXPECT_THAT(ParseShapeAndElementType(""),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("0"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("="),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("abc"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("1xf"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("1xff23"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("1xn3"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("x"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("x1"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("1x"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("x1x2="),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("1xx2"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("1x2x"),
+ StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseShapeAndElementType("0x-1"),
+ StatusIs(StatusCode::kInvalidArgument));
+}
+
TEST(ElementStringUtilTest, ParseElement) {
EXPECT_THAT(ParseElement<int8_t>("-128"), IsOkAndHolds(Eq(INT8_MIN)));
EXPECT_THAT(ParseElement<int8_t>("127"), IsOkAndHolds(Eq(INT8_MAX)));
diff --git a/runtime/src/iree/tools/utils/vm_util.cc b/runtime/src/iree/tools/utils/vm_util.cc
index 12564d3..27f6037 100644
--- a/runtime/src/iree/tools/utils/vm_util.cc
+++ b/runtime/src/iree/tools/utils/vm_util.cc
@@ -23,14 +23,88 @@
namespace iree {
-Status ParseToVariantList(iree_hal_allocator_t* allocator,
+// Creates a HAL buffer view with the given |metadata| and reads the contents
+// from the file at |file_path|. |metadata| supplies the shape and element type.
+// The file contents are directly read in to memory with no processing; a
+// short read fails with OUT_OF_RANGE ("file contents truncated").
+static iree_status_t CreateBufferViewFromFile(
+ iree_string_view_t metadata, iree_string_view_t file_path,
+ iree_hal_allocator_t* device_allocator,
+ iree_hal_buffer_view_t** out_buffer_view) {
+ *out_buffer_view = NULL;
+
+ // Parse shape and element type used to allocate the buffer view. The first
+ // call passes no shape storage and is only used to discover the rank.
+ iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+ iree_host_size_t shape_rank = 0;
+ iree_status_t shape_result = iree_hal_parse_shape_and_element_type(
+ metadata, 0, NULL, &shape_rank, &element_type);
+ if (!iree_status_is_ok(shape_result) &&
+ !iree_status_is_out_of_range(shape_result)) {
+ return shape_result;
+ }
+ // Consume the tolerated out-of-range status before any other return path.
+ shape_result = iree_status_ignore(shape_result);
+ if (shape_rank > 128) {
+ return iree_make_status(
+ IREE_STATUS_RESOURCE_EXHAUSTED,
+ "a shape rank of %zu is just a little bit excessive, eh?", shape_rank);
+ }
+ iree_hal_dim_t* shape =
+ (iree_hal_dim_t*)iree_alloca(shape_rank * sizeof(iree_hal_dim_t));
+ IREE_RETURN_IF_ERROR(iree_hal_parse_shape_and_element_type(
+ metadata, shape_rank, shape, &shape_rank, &element_type));
+
+ // TODO(benvanik): allow specifying the encoding.
+ iree_hal_encoding_type_t encoding_type =
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
+
+ // Open the file for reading; fopen needs a NUL-terminated copy of the path.
+ std::string file_path_str(file_path.data, file_path.size);
+ FILE* file = std::fopen(file_path_str.c_str(), "rb");
+ if (!file) {
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "failed to open file '%.*s'", (int)file_path.size,
+ file_path.data);
+ }
+
+ iree_hal_buffer_params_t buffer_params = {0};
+ buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+ buffer_params.usage =
+ IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER;
+ struct read_params_t {
+ FILE* file;
+ } read_params = {file};
+ iree_status_t status = iree_hal_buffer_view_generate_buffer(
+ device_allocator, shape, shape_rank, element_type, encoding_type,
+ buffer_params,
+ +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
+ auto* read_params = reinterpret_cast<read_params_t*>(user_data);
+ size_t bytes_read =
+ std::fread(mapping->contents.data, 1, mapping->contents.data_length,
+ read_params->file);
+ if (bytes_read != mapping->contents.data_length) {
+ return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+ "file contents truncated; expected %zu bytes "
+ "based on buffer view size",
+ mapping->contents.data_length);
+ }
+ return iree_ok_status();
+ },
+ &read_params, out_buffer_view);
+
+ std::fclose(file);
+ return status;
+}
+
+Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
iree::span<const std::string> input_strings,
iree_vm_list_t** out_list) {
*out_list = NULL;
vm::ref<iree_vm_list_t> variant_list;
- IREE_RETURN_IF_ERROR(
- iree_vm_list_create(/*element_type=*/nullptr, input_strings.size(),
- iree_allocator_system(), &variant_list));
+ IREE_RETURN_IF_ERROR(iree_vm_list_create(
+ /*element_type=*/nullptr, input_strings.size(),
+ iree_hal_allocator_host_allocator(device_allocator), &variant_list));
for (size_t i = 0; i < input_strings.size(); ++i) {
iree_string_view_t input_view = iree_string_view_trim(iree_make_string_view(
input_strings[i].data(), input_strings[i].size()));
@@ -43,9 +117,22 @@
bool is_storage_reference = iree_string_view_consume_prefix(
&input_view, iree_make_cstring_view("&"));
iree_hal_buffer_view_t* buffer_view = nullptr;
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_view_parse(input_view, allocator, &buffer_view),
- "parsing value '%.*s'", (int)input_view.size, input_view.data);
+ bool has_at = iree_string_view_find_char(input_view, '@', 0) !=
+ IREE_STRING_VIEW_NPOS;
+ if (has_at) {
+ // Referencing an external file; split into the portion used to
+ // initialize the buffer view and the file contents.
+ iree_string_view_t metadata, file_path;
+ iree_string_view_split(input_view, '@', &metadata, &file_path);
+ iree_string_view_consume_suffix(&metadata, iree_make_cstring_view("="));
+ IREE_RETURN_IF_ERROR(CreateBufferViewFromFile(
+ metadata, file_path, device_allocator, &buffer_view));
+ } else {
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_parse(
+ input_view, device_allocator, &buffer_view),
+ "parsing value '%.*s'", (int)input_view.size,
+ input_view.data);
+ }
if (is_storage_reference) {
// Storage buffer reference; just take the storage for the buffer view -
// it'll still have whatever contents were specified (or 0) but we'll
diff --git a/runtime/src/iree/tools/utils/vm_util.h b/runtime/src/iree/tools/utils/vm_util.h
index f9ae892..b363e34 100644
--- a/runtime/src/iree/tools/utils/vm_util.h
+++ b/runtime/src/iree/tools/utils/vm_util.h
@@ -29,10 +29,9 @@
// Buffers should be in the IREE standard shaped buffer format:
// [shape]xtype=[value]
// described in iree/hal/api.h
-// Uses |allocator| to allocate the buffers.
-// Uses descriptors in |descs| for type information and validation.
+// Uses |device_allocator| to allocate the buffers.
// The returned variant list must be freed by the caller.
-Status ParseToVariantList(iree_hal_allocator_t* allocator,
+Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
iree::span<const std::string> input_strings,
iree_vm_list_t** out_list);
@@ -43,7 +42,6 @@
// [shape]xtype=[value]
// described in
// https://github.com/google/iree/tree/main/iree/hal/api.h
-// Uses descriptors in |descs| for type information and validation.
Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count,
std::ostream* os);
inline Status PrintVariantList(iree_vm_list_t* variant_list, std::ostream* os) {
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index d17d8e9..37619f8 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -124,6 +124,8 @@
" 2x2xi32=1 2 3 4\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
+ "Raw binary files can be read to provide buffer contents:\n"
+ " 2x2xi32=@some/file.bin\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 7155d1b..7f8f6f9 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -72,6 +72,8 @@
" 2x2xi32=1 2 3 4\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
+ "Raw binary files can be read to provide buffer contents:\n"
+ " 2x2xi32=@some/file.bin\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");