Adding support for reading binary files in tool --function_input=. (#9328)

Raw binary data read from a file can now be used to initialize buffer
view contents. No interpretation is done on the data. The shape and
element type information is still required.

Example:
```
iree-benchmark-module ... --function_input=4x2xi32=@some/file.bin
```
diff --git a/runtime/src/iree/hal/string_util.c b/runtime/src/iree/hal/string_util.c
index 6e794fa..d9a662c 100644
--- a/runtime/src/iree/hal/string_util.c
+++ b/runtime/src/iree/hal/string_util.c
@@ -183,6 +183,54 @@
                               : iree_ok_status();
 }
 
+IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
+    iree_string_view_t value, iree_host_size_t shape_capacity,
+    iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank,
+    iree_hal_element_type_t* out_element_type) {
+  *out_shape_rank = 0;
+  *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+
+  // Strip whitespace that may come along (linefeeds/etc).
+  value = iree_string_view_trim(value);
+  value = iree_string_view_strip_prefix(value, IREE_SV("\""));
+  value = iree_string_view_strip_suffix(value, IREE_SV("\""));
+  if (iree_string_view_is_empty(value)) {
+    // Empty lines are invalid; need at least the shape/type information.
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "empty string input");
+  }
+
+  // The part of the string corresponding to the shape, e.g. 1x2x3.
+  iree_string_view_t shape_str = iree_string_view_empty();
+  // The part of the string corresponding to the type, e.g. f32
+  iree_string_view_t type_str = iree_string_view_empty();
+  // The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6
+  // We ignore this.
+  iree_string_view_t data_str = iree_string_view_empty();
+
+  iree_string_view_t shape_and_type_str = value;
+  iree_string_view_split(value, '=', &shape_and_type_str, &data_str);
+  iree_host_size_t last_x_index = iree_string_view_find_last_of(
+      shape_and_type_str, IREE_SV("x"), IREE_STRING_VIEW_NPOS);
+  if (last_x_index == IREE_STRING_VIEW_NPOS) {
+    // Scalar.
+    type_str = shape_and_type_str;
+  } else {
+    // Has a shape.
+    shape_str = iree_string_view_substr(shape_and_type_str, 0, last_x_index);
+    type_str = iree_string_view_substr(shape_and_type_str, last_x_index + 1,
+                                       IREE_STRING_VIEW_NPOS);
+  }
+
+  // AxBxC...
+  IREE_RETURN_IF_ERROR(iree_hal_parse_shape(shape_str, shape_capacity,
+                                            out_shape, out_shape_rank));
+
+  // f32, i32, etc
+  IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, out_element_type));
+
+  return iree_ok_status();
+}
+
 // Parses a string of two character pairs representing hex numbers into bytes.
 static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to,
                                          ptrdiff_t num) {
diff --git a/runtime/src/iree/hal/string_util.h b/runtime/src/iree/hal/string_util.h
index 3e8b1bf..4750339 100644
--- a/runtime/src/iree/hal/string_util.h
+++ b/runtime/src/iree/hal/string_util.h
@@ -48,6 +48,14 @@
     iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity,
     char* buffer, iree_host_size_t* out_buffer_length);
 
+// Parses a shape and type from a `[shape]x[type]` string |value|.
+// Behaves the same as calling iree_hal_parse_shape and
+// iree_hal_parse_element_type. Ignores any training `=`.
+IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
+    iree_string_view_t value, iree_host_size_t shape_capacity,
+    iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank,
+    iree_hal_element_type_t* out_element_type);
+
 // Parses a serialized element of |element_type| to its in-memory form.
 // |data_ptr| must be at least large enough to contain the bytes of the element.
 // For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
diff --git a/runtime/src/iree/hal/string_util_test.cc b/runtime/src/iree/hal/string_util_test.cc
index 453f99c..b8034f0 100644
--- a/runtime/src/iree/hal/string_util_test.cc
+++ b/runtime/src/iree/hal/string_util_test.cc
@@ -35,7 +35,7 @@
 StatusOr<Shape> ParseShape(const std::string& value) {
   Shape shape(6);
   iree_host_size_t actual_rank = 0;
-  iree_status_t status;
+  iree_status_t status = iree_ok_status();
   do {
     status =
         iree_hal_parse_shape(iree_string_view_t{value.data(), value.size()},
@@ -50,7 +50,7 @@
 StatusOr<std::string> FormatShape(iree::span<const iree_hal_dim_t> value) {
   std::string buffer(16, '\0');
   iree_host_size_t actual_length = 0;
-  iree_status_t status;
+  iree_status_t status = iree_ok_status();
   do {
     status =
         iree_hal_format_shape(value.data(), value.size(), buffer.size() + 1,
@@ -77,7 +77,7 @@
 StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
   std::string buffer(16, '\0');
   iree_host_size_t actual_length = 0;
-  iree_status_t status;
+  iree_status_t status = iree_ok_status();
   do {
     status = iree_hal_format_element_type(value, buffer.size() + 1, &buffer[0],
                                           &actual_length);
@@ -87,6 +87,34 @@
   return std::move(buffer);
 }
 
+struct ShapeAndType {
+  Shape shape;
+  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+  ShapeAndType() = default;
+  ShapeAndType(Shape shape, iree_hal_element_type_t element_type)
+      : shape(std::move(shape)), element_type(element_type) {}
+};
+static bool operator==(const ShapeAndType& lhs,
+                       const ShapeAndType& rhs) noexcept {
+  return lhs.shape == rhs.shape && lhs.element_type == rhs.element_type;
+}
+
+// Parses a serialized set of shape dimensions and an element type.
+StatusOr<ShapeAndType> ParseShapeAndElementType(const std::string& value) {
+  Shape shape(6);
+  iree_host_size_t actual_rank = 0;
+  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+  iree_status_t status = iree_ok_status();
+  do {
+    status = iree_hal_parse_shape_and_element_type(
+        iree_string_view_t{value.data(), value.size()}, shape.size(),
+        shape.data(), &actual_rank, &element_type);
+    shape.resize(actual_rank);
+  } while (iree_status_is_out_of_range(status));
+  IREE_RETURN_IF_ERROR(std::move(status));
+  return ShapeAndType(std::move(shape), element_type);
+}
+
 // Parses a serialized element of |element_type| to its in-memory form.
 // |buffer| be at least large enough to contain the bytes of the element.
 // For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
@@ -593,6 +621,50 @@
               IsOkAndHolds(Eq("f4")));
 }
 
+TEST(StringUtilTest, ParseShapeAndElementType) {
+  EXPECT_THAT(
+      ParseShapeAndElementType("1xi8"),
+      IsOkAndHolds(Eq(ShapeAndType(Shape{1}, IREE_HAL_ELEMENT_TYPE_INT_8))));
+  EXPECT_THAT(ParseShapeAndElementType("1x2xi16"),
+              IsOkAndHolds(
+                  Eq(ShapeAndType(Shape{1, 2}, IREE_HAL_ELEMENT_TYPE_INT_16))));
+  EXPECT_THAT(
+      ParseShapeAndElementType("1x2x3x4x5x6x7x8x9xi32=invalid stuff here"),
+      IsOkAndHolds(Eq(ShapeAndType(Shape{1, 2, 3, 4, 5, 6, 7, 8, 9},
+                                   IREE_HAL_ELEMENT_TYPE_INT_32))));
+}
+
+TEST(StringUtilTest, ParseShapeAndElementTypeInvalid) {
+  EXPECT_THAT(ParseShapeAndElementType(""),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("0"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("="),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("abc"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("1xf"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("1xff23"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("1xn3"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("x"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("x1"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("1x"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("x1x2="),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("1xx2"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("1x2x"),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(ParseShapeAndElementType("0x-1"),
+              StatusIs(StatusCode::kInvalidArgument));
+}
+
 TEST(ElementStringUtilTest, ParseElement) {
   EXPECT_THAT(ParseElement<int8_t>("-128"), IsOkAndHolds(Eq(INT8_MIN)));
   EXPECT_THAT(ParseElement<int8_t>("127"), IsOkAndHolds(Eq(INT8_MAX)));
diff --git a/runtime/src/iree/tools/utils/vm_util.cc b/runtime/src/iree/tools/utils/vm_util.cc
index 12564d3..27f6037 100644
--- a/runtime/src/iree/tools/utils/vm_util.cc
+++ b/runtime/src/iree/tools/utils/vm_util.cc
@@ -23,14 +23,88 @@
 
 namespace iree {
 
-Status ParseToVariantList(iree_hal_allocator_t* allocator,
+// Creates a HAL buffer view with the given |metadata| and reads the contents
+// from the file at |file_path|.
+//
+// The file contents are directly read in to memory with no processing.
+static iree_status_t CreateBufferViewFromFile(
+    iree_string_view_t metadata, iree_string_view_t file_path,
+    iree_hal_allocator_t* device_allocator,
+    iree_hal_buffer_view_t** out_buffer_view) {
+  *out_buffer_view = NULL;
+
+  // Parse shape and element type used to allocate the buffer view.
+  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
+  iree_host_size_t shape_rank = 0;
+  iree_status_t shape_result = iree_hal_parse_shape_and_element_type(
+      metadata, 0, NULL, &shape_rank, &element_type);
+  if (!iree_status_is_ok(shape_result) &&
+      !iree_status_is_out_of_range(shape_result)) {
+    return shape_result;
+  } else if (shape_rank > 128) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "a shape rank of %zu is just a little bit excessive, eh?", shape_rank);
+  }
+  shape_result = iree_status_ignore(shape_result);
+  iree_hal_dim_t* shape =
+      (iree_hal_dim_t*)iree_alloca(shape_rank * sizeof(iree_hal_dim_t));
+  IREE_RETURN_IF_ERROR(iree_hal_parse_shape_and_element_type(
+      metadata, shape_rank, shape, &shape_rank, &element_type));
+
+  // TODO(benvanik): allow specifying the encoding.
+  iree_hal_encoding_type_t encoding_type =
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
+
+  // Open the file for reading.
+  std::string file_path_str(file_path.data, file_path.size);
+  FILE* file = std::fopen(file_path_str.c_str(), "rb");
+  if (!file) {
+    return iree_make_status(iree_status_code_from_errno(errno),
+                            "failed to open file '%.*s'", (int)file_path.size,
+                            file_path.data);
+  }
+
+  iree_hal_buffer_params_t buffer_params = {0};
+  buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  buffer_params.usage =
+      IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER;
+  struct read_params_t {
+    FILE* file;
+  } read_params = {
+      file,
+  };
+  iree_status_t status = iree_hal_buffer_view_generate_buffer(
+      device_allocator, shape, shape_rank, element_type, encoding_type,
+      buffer_params,
+      +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
+        auto* read_params = reinterpret_cast<read_params_t*>(user_data);
+        size_t bytes_read =
+            std::fread(mapping->contents.data, 1, mapping->contents.data_length,
+                       read_params->file);
+        if (bytes_read != mapping->contents.data_length) {
+          return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                                  "file contents truncated; expected %zu bytes "
+                                  "based on buffer view size",
+                                  mapping->contents.data_length);
+        }
+        return iree_ok_status();
+      },
+      &read_params, out_buffer_view);
+
+  std::fclose(file);
+
+  return status;
+}
+
+Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
                           iree::span<const std::string> input_strings,
                           iree_vm_list_t** out_list) {
   *out_list = NULL;
   vm::ref<iree_vm_list_t> variant_list;
-  IREE_RETURN_IF_ERROR(
-      iree_vm_list_create(/*element_type=*/nullptr, input_strings.size(),
-                          iree_allocator_system(), &variant_list));
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(
+      /*element_type=*/nullptr, input_strings.size(),
+      iree_hal_allocator_host_allocator(device_allocator), &variant_list));
   for (size_t i = 0; i < input_strings.size(); ++i) {
     iree_string_view_t input_view = iree_string_view_trim(iree_make_string_view(
         input_strings[i].data(), input_strings[i].size()));
@@ -43,9 +117,22 @@
       bool is_storage_reference = iree_string_view_consume_prefix(
           &input_view, iree_make_cstring_view("&"));
       iree_hal_buffer_view_t* buffer_view = nullptr;
-      IREE_RETURN_IF_ERROR(
-          iree_hal_buffer_view_parse(input_view, allocator, &buffer_view),
-          "parsing value '%.*s'", (int)input_view.size, input_view.data);
+      bool has_at = iree_string_view_find_char(input_view, '@', 0) !=
+                    IREE_STRING_VIEW_NPOS;
+      if (has_at) {
+        // Referencing an external file; split into the portion used to
+        // initialize the buffer view and the file contents.
+        iree_string_view_t metadata, file_path;
+        iree_string_view_split(input_view, '@', &metadata, &file_path);
+        iree_string_view_consume_suffix(&metadata, iree_make_cstring_view("="));
+        IREE_RETURN_IF_ERROR(CreateBufferViewFromFile(
+            metadata, file_path, device_allocator, &buffer_view));
+      } else {
+        IREE_RETURN_IF_ERROR(iree_hal_buffer_view_parse(
+                                 input_view, device_allocator, &buffer_view),
+                             "parsing value '%.*s'", (int)input_view.size,
+                             input_view.data);
+      }
       if (is_storage_reference) {
         // Storage buffer reference; just take the storage for the buffer view -
         // it'll still have whatever contents were specified (or 0) but we'll
diff --git a/runtime/src/iree/tools/utils/vm_util.h b/runtime/src/iree/tools/utils/vm_util.h
index f9ae892..b363e34 100644
--- a/runtime/src/iree/tools/utils/vm_util.h
+++ b/runtime/src/iree/tools/utils/vm_util.h
@@ -29,10 +29,9 @@
 // Buffers should be in the IREE standard shaped buffer format:
 //   [shape]xtype=[value]
 // described in iree/hal/api.h
-// Uses |allocator| to allocate the buffers.
-// Uses descriptors in |descs| for type information and validation.
+// Uses |device_allocator| to allocate the buffers.
 // The returned variant list must be freed by the caller.
-Status ParseToVariantList(iree_hal_allocator_t* allocator,
+Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
                           iree::span<const std::string> input_strings,
                           iree_vm_list_t** out_list);
 
@@ -43,7 +42,6 @@
 //   [shape]xtype=[value]
 // described in
 // https://github.com/google/iree/tree/main/iree/hal/api.h
-// Uses descriptors in |descs| for type information and validation.
 Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count,
                         std::ostream* os);
 inline Status PrintVariantList(iree_vm_list_t* variant_list, std::ostream* os) {
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index d17d8e9..37619f8 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -124,6 +124,8 @@
     "  2x2xi32=1 2 3 4\n"
     "Optionally, brackets may be used to separate the element values:\n"
     "  2x2xi32=[[1 2][3 4]]\n"
+    "Raw binary files can be read to provide buffer contents:\n"
+    "  2x2xi32=@some/file.bin\n"
     "Each occurrence of the flag indicates an input in the order they were\n"
     "specified on the command line.");
 
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 7155d1b..7f8f6f9 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -72,6 +72,8 @@
     "  2x2xi32=1 2 3 4\n"
     "Optionally, brackets may be used to separate the element values:\n"
     "  2x2xi32=[[1 2][3 4]]\n"
+    "Raw binary files can be read to provide buffer contents:\n"
+    "  2x2xi32=@some/file.bin\n"
     "Each occurrence of the flag indicates an input in the order they were\n"
     "specified on the command line.");