Add scalar input support to IREE runners

PiperOrigin-RevId: 293073521
diff --git a/iree/tools/test/scalars.mlir b/iree/tools/test/scalars.mlir
index a63e94e..a6f38e8 100644
--- a/iree/tools/test/scalars.mlir
+++ b/iree/tools/test/scalars.mlir
@@ -1,14 +1,13 @@
-// RUN: (iree-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=multi_input) | IreeFileCheck %s
+// RUN: (iree-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=scalar --inputs="i32=42") | IreeFileCheck %s
 
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=multi_input) | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=scalar --inputs="i32=42") | IreeFileCheck %s)
 
-// RUN: (iree-run-mlir --iree-hal-target-backends=interpreter-bytecode %s) | IreeFileCheck %s
+// RUN: (iree-run-mlir --iree-hal-target-backends=interpreter-bytecode --input-value=i32=42 %s) | IreeFileCheck %s
 
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv  %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv --input-value=i32=42 %s | IreeFileCheck %s)
 
-// CHECK-LABEL: EXEC @multi_input
-func @multi_input() -> i32 attributes { iree.module.export } {
-  %c = iree.unfoldable_constant 42 : i32
-  return %c : i32
+// CHECK-LABEL: EXEC @scalar
+func @scalar(%arg0 : i32) -> i32 attributes { iree.module.export } {
+  return %arg0 : i32
 }
 // CHECK: i32=42
diff --git a/iree/tools/vm_util.cc b/iree/tools/vm_util.cc
index e61be5e..ae0669c 100644
--- a/iree/tools/vm_util.cc
+++ b/iree/tools/vm_util.cc
@@ -16,7 +16,9 @@
 
 #include <ostream>
 
+#include "absl/strings/numbers.h"
 #include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
 #include "absl/types/span.h"
 #include "iree/base/api_util.h"
 #include "iree/base/buffer_string_util.h"
@@ -28,6 +30,7 @@
 #include "iree/modules/hal/hal_module.h"
 #include "iree/vm/bytecode_module.h"
 #include "iree/vm/module.h"
+#include "iree/vm/variant_list.h"
 
 namespace iree {
 
@@ -80,46 +83,81 @@
 StatusOr<iree_vm_variant_list_t*> ParseToVariantList(
     absl::Span<const RawSignatureParser::Description> descs,
     iree_hal_allocator_t* allocator,
-    absl::Span<const std::string> buf_strings) {
-  if (buf_strings.size() != descs.size()) {
+    absl::Span<const std::string> input_strings) {
+  if (input_strings.size() != descs.size()) {
     return FailedPreconditionErrorBuilder(IREE_LOC)
            << "Signature mismatch; expected " << descs.size()
-           << " buffer strings but received " << buf_strings.size();
+           << " buffer strings but received " << input_strings.size();
   }
   iree_vm_variant_list_t* variant_list = nullptr;
   RETURN_IF_ERROR(FromApiStatus(
-      iree_vm_variant_list_alloc(buf_strings.size(), IREE_ALLOCATOR_SYSTEM,
+      iree_vm_variant_list_alloc(input_strings.size(), IREE_ALLOCATOR_SYSTEM,
                                  &variant_list),
       IREE_LOC));
-  for (const auto& buf_string : buf_strings) {
-    // TODO(gcmn) Handle scalar variants.
-    ASSIGN_OR_RETURN(auto shaped_buffer,
-                     ParseShapedBufferFromString(buf_string),
-                     _ << "Parsing value '" << buf_string << "'");
-    iree_hal_buffer_t* buf = nullptr;
-    // TODO(benvanik): combined function for linear to optimal upload.
-    iree_device_size_t allocation_size =
-        shaped_buffer.shape().element_count() * shaped_buffer.element_size();
-    RETURN_IF_ERROR(FromApiStatus(
-        iree_hal_allocator_allocate_buffer(
-            allocator,
-            static_cast<iree_hal_memory_type_t>(
-                IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
-                IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
-            static_cast<iree_hal_buffer_usage_t>(
-                IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT),
-            allocation_size, &buf),
-        IREE_LOC))
-        << "Allocating buffer";
-    RETURN_IF_ERROR(FromApiStatus(
-        iree_hal_buffer_write_data(buf, 0, shaped_buffer.contents().data(),
-                                   shaped_buffer.contents().size()),
-        IREE_LOC))
-        << "Populating buffer contents ";
-    auto buf_ref = iree_hal_buffer_move_ref(buf);
-    RETURN_IF_ERROR(FromApiStatus(
-        iree_vm_variant_list_append_ref_move(variant_list, &buf_ref),
-        IREE_LOC));
+  for (int i = 0; i < input_strings.size(); ++i) {
+    auto input_string = input_strings[i];
+    auto desc = descs[i];
+    std::string desc_str;
+    desc.ToString(desc_str);
+    switch (desc.type) {
+      case RawSignatureParser::Type::kScalar: {
+        if (desc.scalar.type != AbiConstants::ScalarType::kSint32) {
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Unsupported signature scalar type: " << desc_str;
+        }
+        absl::string_view input_view = absl::StripAsciiWhitespace(input_string);
+        input_view = absl::StripPrefix(input_view, "\"");
+        input_view = absl::StripSuffix(input_view, "\"");
+        if (!absl::ConsumePrefix(&input_view, "i32=")) {
+          return InvalidArgumentErrorBuilder(IREE_LOC)
+                 << "Parsing '" << input_string
+                 << "'. Has i32 descriptor but does not start with 'i32='";
+        }
+        int32_t val;
+        if (!absl::SimpleAtoi(input_view, &val)) {
+          return InvalidArgumentErrorBuilder(IREE_LOC)
+                 << "Converting '" << input_view << "' to i32 when parsing '"
+                 << input_string << "'";
+        }
+        iree_vm_variant_list_append_value(variant_list,
+                                          IREE_VM_VALUE_MAKE_I32(val));
+        break;
+      }
+      case RawSignatureParser::Type::kBuffer: {
+        ASSIGN_OR_RETURN(auto shaped_buffer,
+                         ParseShapedBufferFromString(input_string),
+                         _ << "Parsing value '" << input_string << "'");
+        iree_hal_buffer_t* buf = nullptr;
+        // TODO(benvanik): combined function for linear to optimal upload.
+        iree_device_size_t allocation_size =
+            shaped_buffer.shape().element_count() *
+            shaped_buffer.element_size();
+        RETURN_IF_ERROR(FromApiStatus(
+            iree_hal_allocator_allocate_buffer(
+                allocator,
+                static_cast<iree_hal_memory_type_t>(
+                    IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                    IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
+                static_cast<iree_hal_buffer_usage_t>(
+                    IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT),
+                allocation_size, &buf),
+            IREE_LOC))
+            << "Allocating buffer";
+        RETURN_IF_ERROR(FromApiStatus(
+            iree_hal_buffer_write_data(buf, 0, shaped_buffer.contents().data(),
+                                       shaped_buffer.contents().size()),
+            IREE_LOC))
+            << "Populating buffer contents ";
+        auto buf_ref = iree_hal_buffer_move_ref(buf);
+        RETURN_IF_ERROR(FromApiStatus(
+            iree_vm_variant_list_append_ref_move(variant_list, &buf_ref),
+            IREE_LOC));
+        break;
+      }
+      default:
+        return UnimplementedErrorBuilder(IREE_LOC)
+               << "Unsupported signature type: " << desc_str;
+    }
   }
   return variant_list;
 }
@@ -147,12 +185,12 @@
                  << static_cast<int>(variant->value_type)
                  << " but descriptor information " << desc_str;
         }
-        if (desc.scalar.type == AbiConstants::ScalarType::kSint32) {
-          *os << "i32=" << variant->i32 << "\n";
-          break;
+        if (desc.scalar.type != AbiConstants::ScalarType::kSint32) {
+          return UnimplementedErrorBuilder(IREE_LOC)
+                 << "Unsupported signature scalar type: " << desc_str;
         }
-        return UnimplementedErrorBuilder(IREE_LOC)
-               << "Unsupported signature scalar type: " << desc_str;
+        *os << "i32=" << variant->i32 << "\n";
+        break;
       }
       case RawSignatureParser::Type::kBuffer: {
         if (variant->value_type != IREE_VM_VALUE_TYPE_NONE) {
diff --git a/iree/tools/vm_util.h b/iree/tools/vm_util.h
index d8323f1..64023db 100644
--- a/iree/tools/vm_util.h
+++ b/iree/tools/vm_util.h
@@ -40,24 +40,29 @@
 StatusOr<std::vector<RawSignatureParser::Description>> ParseOutputSignature(
     const iree_vm_function_t& function);
 
-// Parses a list of shapes and values into VM buffers.
-// Expects strings in the IREE standard shaped buffer format:
+// Parses |input_strings| into a variant list of VM scalars and buffers.
+// Scalars should be in the format:
+//   type=value
+// Buffers should be in the IREE standard shaped buffer format:
 //   [shape]xtype=[value]
 // described in
 // https://github.com/google/iree/tree/master/iree/base/buffer_string_util.h
-// Uses |allocator| to allocate the buffers, validating them against the type
-// descriptors in |descs|. The returned variant list must be freed by the
-// caller.
+// Uses |allocator| to allocate the buffers.
+// Uses descriptors in |descs| for type information and validation.
+// The returned variant list must be freed by the caller.
 StatusOr<iree_vm_variant_list_t*> ParseToVariantList(
     absl::Span<const RawSignatureParser::Description> descs,
-    iree_hal_allocator_t* allocator, absl::Span<const std::string> buf_strings);
+    iree_hal_allocator_t* allocator,
+    absl::Span<const std::string> input_strings);
 
 // Prints a variant list of VM scalars and buffers to |os|.
-// Uses the IREE standard shaped buffer format:
+// Prints scalars in the format:
+//   type=value
+// Prints buffers in the IREE standard shaped buffer format:
 //   [shape]xtype=[value]
 // described in
 // https://github.com/google/iree/tree/master/iree/base/buffer_string_util.h
-// Uses |descs| for type information and validation.
+// Uses descriptors in |descs| for type information and validation.
 Status PrintVariantList(absl::Span<const RawSignatureParser::Description> descs,
                         iree_vm_variant_list_t* variant_list,
                         std::ostream* os = &std::cout);
diff --git a/iree/tools/vm_util_test.cc b/iree/tools/vm_util_test.cc
index db01496..a54f96d 100644
--- a/iree/tools/vm_util_test.cc
+++ b/iree/tools/vm_util_test.cc
@@ -35,130 +35,13 @@
     allocator_ = iree_hal_device_allocator(device_);
   }
 
-  virtual void TearDown() {
-    IREE_ASSERT_OK(iree_hal_device_release(device_));
-    if (outputs_ != nullptr) {
-      IREE_ASSERT_OK(iree_vm_variant_list_free(outputs_));
-    }
-  }
+  virtual void TearDown() { IREE_ASSERT_OK(iree_hal_device_release(device_)); }
 
   iree_hal_device_t* device_ = nullptr;
-  iree_vm_variant_list_t* outputs_ = nullptr;
   iree_hal_allocator_t* allocator_ = nullptr;
 };
 
-TEST_F(VmUtilTest, PrintVariantListScalar) {
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_alloc(1, IREE_ALLOCATOR_SYSTEM, &outputs_));
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_append_value(outputs_, IREE_VM_VALUE_MAKE_I32(42)));
-  RawSignatureParser::Description desc;
-  desc.type = RawSignatureParser::Type::kScalar;
-  desc.scalar.type = AbiConstants::ScalarType::kSint32;
-  std::stringstream os;
-  ASSERT_OK(PrintVariantList({desc}, outputs_, &os));
-  EXPECT_EQ(os.str(), "i32=42\n");
-}
-
-TEST_F(VmUtilTest, PrintVariantListMultiple) {
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_alloc(2, IREE_ALLOCATOR_SYSTEM, &outputs_));
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_append_value(outputs_, IREE_VM_VALUE_MAKE_I32(42)));
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_append_value(outputs_, IREE_VM_VALUE_MAKE_I32(13)));
-  RawSignatureParser::Description desc;
-  desc.type = RawSignatureParser::Type::kScalar;
-  desc.scalar.type = AbiConstants::ScalarType::kSint32;
-  std::stringstream os;
-  ASSERT_OK(PrintVariantList({desc, desc}, outputs_, &os));
-  EXPECT_EQ(os.str(),
-            "i32=42\n"
-            "i32=13\n");
-}
-
-TEST_F(VmUtilTest, PrintVariantListBuffer) {
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_alloc(1, IREE_ALLOCATOR_SYSTEM, &outputs_));
-  iree_hal_buffer_t* buf = nullptr;
-  std::array<int32_t, 4> buf_data{42, -43, 44, 45};
-  iree_device_size_t allocation_size = sizeof(buf_data);
-  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
-      allocator_,
-      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
-                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
-      static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
-                                           IREE_HAL_BUFFER_USAGE_CONSTANT),
-      allocation_size, &buf));
-  IREE_ASSERT_OK(iree_hal_buffer_write_data(
-      buf, 0, reinterpret_cast<uint8_t*>(buf_data.data()), allocation_size));
-  auto buf_ref = iree_hal_buffer_move_ref(buf);
-  IREE_ASSERT_OK(iree_vm_variant_list_append_ref_move(outputs_, &buf_ref));
-  RawSignatureParser::Description desc;
-  desc.type = RawSignatureParser::Type::kBuffer;
-  desc.buffer.scalar_type = AbiConstants::ScalarType::kSint32;
-  desc.dims = {2, 2};
-  std::stringstream os;
-  ASSERT_OK(PrintVariantList({desc}, outputs_, &os));
-  EXPECT_EQ(os.str(), "2x2xi32=[42 -43][44 45]\n");
-}
-
-TEST_F(VmUtilTest, PrintVariantListMultiBuffer) {
-  IREE_ASSERT_OK(
-      iree_vm_variant_list_alloc(2, IREE_ALLOCATOR_SYSTEM, &outputs_));
-
-  // Buffer 1
-  iree_hal_buffer_t* buf1 = nullptr;
-  int element_count1 = 4;
-  int element_size1 = sizeof(int32_t);
-  iree_device_size_t allocation_size1 = element_count1 * element_size1;
-  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
-      allocator_,
-      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
-                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
-      static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
-                                           IREE_HAL_BUFFER_USAGE_CONSTANT),
-      allocation_size1, &buf1));
-  std::array<int32_t, 4> buf1_data{42, 43, 44, 45};
-  IREE_ASSERT_OK(iree_hal_buffer_write_data(
-      buf1, 0, reinterpret_cast<uint8_t*>(buf1_data.data()), allocation_size1));
-  auto buf1_ref = iree_hal_buffer_move_ref(buf1);
-  IREE_ASSERT_OK(iree_vm_variant_list_append_ref_move(outputs_, &buf1_ref));
-  RawSignatureParser::Description desc1;
-  desc1.type = RawSignatureParser::Type::kBuffer;
-  desc1.buffer.scalar_type = AbiConstants::ScalarType::kSint32;
-  desc1.dims = {2, 2};
-
-  // Buffer 2
-  iree_hal_buffer_t* buf2 = nullptr;
-  int element_count2 = 6;
-  int element_size2 = sizeof(double);
-  iree_device_size_t allocation_size2 = element_count2 * element_size2;
-  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
-      allocator_,
-      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
-                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
-      static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
-                                           IREE_HAL_BUFFER_USAGE_CONSTANT),
-      allocation_size2, &buf2));
-  std::array<double, 6> buf2_data{1, 2, 3, 4, 5, 6};
-  IREE_ASSERT_OK(iree_hal_buffer_write_data(
-      buf2, 0, reinterpret_cast<uint8_t*>(buf2_data.data()), allocation_size2));
-  auto buf2_ref = iree_hal_buffer_move_ref(buf2);
-  IREE_ASSERT_OK(iree_vm_variant_list_append_ref_move(outputs_, &buf2_ref));
-  RawSignatureParser::Description desc2;
-  desc2.type = RawSignatureParser::Type::kBuffer;
-  desc2.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat64;
-  desc2.dims = {2, 3};
-
-  std::stringstream os;
-  ASSERT_OK(PrintVariantList({desc1, desc2}, outputs_, &os));
-  EXPECT_EQ(os.str(),
-            "2x2xi32=[42 43][44 45]\n"
-            "2x3xf64=[1 2 3][4 5 6]\n");
-}
-
-TEST_F(VmUtilTest, ParsePrint) {
+TEST_F(VmUtilTest, ParsePrintBuffer) {
   auto buf_string = "2x2xi32=[42 43][44 45]";
   RawSignatureParser::Description desc;
   desc.type = RawSignatureParser::Type::kBuffer;
@@ -174,6 +57,21 @@
   IREE_ASSERT_OK(iree_vm_variant_list_free(variant_list));
 }
 
+TEST_F(VmUtilTest, ParsePrintScalar) {
+  auto input_string = "i32=42";
+  RawSignatureParser::Description desc;
+  desc.type = RawSignatureParser::Type::kScalar;
+  desc.scalar.type = AbiConstants::ScalarType::kSint32;
+
+  ASSERT_OK_AND_ASSIGN(auto* variant_list,
+                       ParseToVariantList({desc}, allocator_, {input_string}));
+  std::stringstream os;
+  ASSERT_OK(PrintVariantList({desc}, variant_list, &os));
+  EXPECT_EQ(os.str(), absl::StrCat(input_string, "\n"));
+
+  IREE_ASSERT_OK(iree_vm_variant_list_free(variant_list));
+}
+
 TEST_F(VmUtilTest, ParsePrintRank0Buffer) {
   auto buf_string = "i32=42";
   RawSignatureParser::Description desc;