Add scalar input support to IREE runners
PiperOrigin-RevId: 293073521
diff --git a/iree/tools/test/scalars.mlir b/iree/tools/test/scalars.mlir
index a63e94e..a6f38e8 100644
--- a/iree/tools/test/scalars.mlir
+++ b/iree/tools/test/scalars.mlir
@@ -1,14 +1,13 @@
-// RUN: (iree-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=multi_input) | IreeFileCheck %s
+// RUN: (iree-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-run-module --entry_function=scalar --inputs="i32=42") | IreeFileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=multi_input) | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=scalar --inputs="i32=42") | IreeFileCheck %s)
-// RUN: (iree-run-mlir --iree-hal-target-backends=interpreter-bytecode %s) | IreeFileCheck %s
+// RUN: (iree-run-mlir --iree-hal-target-backends=interpreter-bytecode --input-value=i32=42 %s) | IreeFileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv --input-value=i32=42 %s | IreeFileCheck %s)
-// CHECK-LABEL: EXEC @multi_input
-func @multi_input() -> i32 attributes { iree.module.export } {
- %c = iree.unfoldable_constant 42 : i32
- return %c : i32
+// CHECK-LABEL: EXEC @scalar
+func @scalar(%arg0 : i32) -> i32 attributes { iree.module.export } {
+ return %arg0 : i32
}
// CHECK: i32=42
diff --git a/iree/tools/vm_util.cc b/iree/tools/vm_util.cc
index e61be5e..ae0669c 100644
--- a/iree/tools/vm_util.cc
+++ b/iree/tools/vm_util.cc
@@ -16,7 +16,9 @@
#include <ostream>
+#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
#include "absl/types/span.h"
#include "iree/base/api_util.h"
#include "iree/base/buffer_string_util.h"
@@ -28,6 +30,7 @@
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/module.h"
+#include "iree/vm/variant_list.h"
namespace iree {
@@ -80,46 +83,81 @@
StatusOr<iree_vm_variant_list_t*> ParseToVariantList(
absl::Span<const RawSignatureParser::Description> descs,
iree_hal_allocator_t* allocator,
- absl::Span<const std::string> buf_strings) {
- if (buf_strings.size() != descs.size()) {
+ absl::Span<const std::string> input_strings) {
+ if (input_strings.size() != descs.size()) {
return FailedPreconditionErrorBuilder(IREE_LOC)
<< "Signature mismatch; expected " << descs.size()
- << " buffer strings but received " << buf_strings.size();
+ << " buffer strings but received " << input_strings.size();
}
iree_vm_variant_list_t* variant_list = nullptr;
RETURN_IF_ERROR(FromApiStatus(
- iree_vm_variant_list_alloc(buf_strings.size(), IREE_ALLOCATOR_SYSTEM,
+ iree_vm_variant_list_alloc(input_strings.size(), IREE_ALLOCATOR_SYSTEM,
&variant_list),
IREE_LOC));
- for (const auto& buf_string : buf_strings) {
- // TODO(gcmn) Handle scalar variants.
- ASSIGN_OR_RETURN(auto shaped_buffer,
- ParseShapedBufferFromString(buf_string),
- _ << "Parsing value '" << buf_string << "'");
- iree_hal_buffer_t* buf = nullptr;
- // TODO(benvanik): combined function for linear to optimal upload.
- iree_device_size_t allocation_size =
- shaped_buffer.shape().element_count() * shaped_buffer.element_size();
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_allocator_allocate_buffer(
- allocator,
- static_cast<iree_hal_memory_type_t>(
- IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- static_cast<iree_hal_buffer_usage_t>(
- IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT),
- allocation_size, &buf),
- IREE_LOC))
- << "Allocating buffer";
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_buffer_write_data(buf, 0, shaped_buffer.contents().data(),
- shaped_buffer.contents().size()),
- IREE_LOC))
- << "Populating buffer contents ";
- auto buf_ref = iree_hal_buffer_move_ref(buf);
- RETURN_IF_ERROR(FromApiStatus(
- iree_vm_variant_list_append_ref_move(variant_list, &buf_ref),
- IREE_LOC));
+ for (int i = 0; i < input_strings.size(); ++i) {
+ auto input_string = input_strings[i];
+ auto desc = descs[i];
+ std::string desc_str;
+ desc.ToString(desc_str);
+ switch (desc.type) {
+ case RawSignatureParser::Type::kScalar: {
+ if (desc.scalar.type != AbiConstants::ScalarType::kSint32) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unsupported signature scalar type: " << desc_str;
+ }
+ absl::string_view input_view = absl::StripAsciiWhitespace(input_string);
+ input_view = absl::StripPrefix(input_view, "\"");
+ input_view = absl::StripSuffix(input_view, "\"");
+ if (!absl::ConsumePrefix(&input_view, "i32=")) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Parsing '" << input_string
+ << "'. Has i32 descriptor but does not start with 'i32='";
+ }
+ int32_t val;
+ if (!absl::SimpleAtoi(input_view, &val)) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Converting '" << input_view << "' to i32 when parsing '"
+ << input_string << "'";
+ }
+ iree_vm_variant_list_append_value(variant_list,
+ IREE_VM_VALUE_MAKE_I32(val));
+ break;
+ }
+ case RawSignatureParser::Type::kBuffer: {
+ ASSIGN_OR_RETURN(auto shaped_buffer,
+ ParseShapedBufferFromString(input_string),
+ _ << "Parsing value '" << input_string << "'");
+ iree_hal_buffer_t* buf = nullptr;
+ // TODO(benvanik): combined function for linear to optimal upload.
+ iree_device_size_t allocation_size =
+ shaped_buffer.shape().element_count() *
+ shaped_buffer.element_size();
+ RETURN_IF_ERROR(FromApiStatus(
+ iree_hal_allocator_allocate_buffer(
+ allocator,
+ static_cast<iree_hal_memory_type_t>(
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+ IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
+ static_cast<iree_hal_buffer_usage_t>(
+ IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT),
+ allocation_size, &buf),
+ IREE_LOC))
+ << "Allocating buffer";
+ RETURN_IF_ERROR(FromApiStatus(
+ iree_hal_buffer_write_data(buf, 0, shaped_buffer.contents().data(),
+ shaped_buffer.contents().size()),
+ IREE_LOC))
+ << "Populating buffer contents ";
+ auto buf_ref = iree_hal_buffer_move_ref(buf);
+ RETURN_IF_ERROR(FromApiStatus(
+ iree_vm_variant_list_append_ref_move(variant_list, &buf_ref),
+ IREE_LOC));
+ break;
+ }
+ default:
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unsupported signature type: " << desc_str;
+ }
}
return variant_list;
}
@@ -147,12 +185,12 @@
<< static_cast<int>(variant->value_type)
<< " but descriptor information " << desc_str;
}
- if (desc.scalar.type == AbiConstants::ScalarType::kSint32) {
- *os << "i32=" << variant->i32 << "\n";
- break;
+ if (desc.scalar.type != AbiConstants::ScalarType::kSint32) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "Unsupported signature scalar type: " << desc_str;
}
- return UnimplementedErrorBuilder(IREE_LOC)
- << "Unsupported signature scalar type: " << desc_str;
+ *os << "i32=" << variant->i32 << "\n";
+ break;
}
case RawSignatureParser::Type::kBuffer: {
if (variant->value_type != IREE_VM_VALUE_TYPE_NONE) {
diff --git a/iree/tools/vm_util.h b/iree/tools/vm_util.h
index d8323f1..64023db 100644
--- a/iree/tools/vm_util.h
+++ b/iree/tools/vm_util.h
@@ -40,24 +40,29 @@
StatusOr<std::vector<RawSignatureParser::Description>> ParseOutputSignature(
const iree_vm_function_t& function);
-// Parses a list of shapes and values into VM buffers.
-// Expects strings in the IREE standard shaped buffer format:
+// Parses |input_strings| into a variant list of VM scalars and buffers.
+// Scalars should be in the format:
+// type=value
+// Buffers should be in the IREE standard shaped buffer format:
// [shape]xtype=[value]
// described in
// https://github.com/google/iree/tree/master/iree/base/buffer_string_util.h
-// Uses |allocator| to allocate the buffers, validating them against the type
-// descriptors in |descs|. The returned variant list must be freed by the
-// caller.
+// Uses |allocator| to allocate the buffers.
+// Uses descriptors in |descs| for type information and validation.
+// The returned variant list must be freed by the caller.
StatusOr<iree_vm_variant_list_t*> ParseToVariantList(
absl::Span<const RawSignatureParser::Description> descs,
- iree_hal_allocator_t* allocator, absl::Span<const std::string> buf_strings);
+ iree_hal_allocator_t* allocator,
+ absl::Span<const std::string> input_strings);
// Prints a variant list of VM scalars and buffers to |os|.
-// Uses the IREE standard shaped buffer format:
+// Prints scalars in the format:
+// type=value
+// Prints buffers in the IREE standard shaped buffer format:
// [shape]xtype=[value]
// described in
// https://github.com/google/iree/tree/master/iree/base/buffer_string_util.h
-// Uses |descs| for type information and validation.
+// Uses descriptors in |descs| for type information and validation.
Status PrintVariantList(absl::Span<const RawSignatureParser::Description> descs,
iree_vm_variant_list_t* variant_list,
std::ostream* os = &std::cout);
diff --git a/iree/tools/vm_util_test.cc b/iree/tools/vm_util_test.cc
index db01496..a54f96d 100644
--- a/iree/tools/vm_util_test.cc
+++ b/iree/tools/vm_util_test.cc
@@ -35,130 +35,13 @@
allocator_ = iree_hal_device_allocator(device_);
}
- virtual void TearDown() {
- IREE_ASSERT_OK(iree_hal_device_release(device_));
- if (outputs_ != nullptr) {
- IREE_ASSERT_OK(iree_vm_variant_list_free(outputs_));
- }
- }
+ virtual void TearDown() { IREE_ASSERT_OK(iree_hal_device_release(device_)); }
iree_hal_device_t* device_ = nullptr;
- iree_vm_variant_list_t* outputs_ = nullptr;
iree_hal_allocator_t* allocator_ = nullptr;
};
-TEST_F(VmUtilTest, PrintVariantListScalar) {
- IREE_ASSERT_OK(
- iree_vm_variant_list_alloc(1, IREE_ALLOCATOR_SYSTEM, &outputs_));
- IREE_ASSERT_OK(
- iree_vm_variant_list_append_value(outputs_, IREE_VM_VALUE_MAKE_I32(42)));
- RawSignatureParser::Description desc;
- desc.type = RawSignatureParser::Type::kScalar;
- desc.scalar.type = AbiConstants::ScalarType::kSint32;
- std::stringstream os;
- ASSERT_OK(PrintVariantList({desc}, outputs_, &os));
- EXPECT_EQ(os.str(), "i32=42\n");
-}
-
-TEST_F(VmUtilTest, PrintVariantListMultiple) {
- IREE_ASSERT_OK(
- iree_vm_variant_list_alloc(2, IREE_ALLOCATOR_SYSTEM, &outputs_));
- IREE_ASSERT_OK(
- iree_vm_variant_list_append_value(outputs_, IREE_VM_VALUE_MAKE_I32(42)));
- IREE_ASSERT_OK(
- iree_vm_variant_list_append_value(outputs_, IREE_VM_VALUE_MAKE_I32(13)));
- RawSignatureParser::Description desc;
- desc.type = RawSignatureParser::Type::kScalar;
- desc.scalar.type = AbiConstants::ScalarType::kSint32;
- std::stringstream os;
- ASSERT_OK(PrintVariantList({desc, desc}, outputs_, &os));
- EXPECT_EQ(os.str(),
- "i32=42\n"
- "i32=13\n");
-}
-
-TEST_F(VmUtilTest, PrintVariantListBuffer) {
- IREE_ASSERT_OK(
- iree_vm_variant_list_alloc(1, IREE_ALLOCATOR_SYSTEM, &outputs_));
- iree_hal_buffer_t* buf = nullptr;
- std::array<int32_t, 4> buf_data{42, -43, 44, 45};
- iree_device_size_t allocation_size = sizeof(buf_data);
- IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
- allocator_,
- static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
- IREE_HAL_BUFFER_USAGE_CONSTANT),
- allocation_size, &buf));
- IREE_ASSERT_OK(iree_hal_buffer_write_data(
- buf, 0, reinterpret_cast<uint8_t*>(buf_data.data()), allocation_size));
- auto buf_ref = iree_hal_buffer_move_ref(buf);
- IREE_ASSERT_OK(iree_vm_variant_list_append_ref_move(outputs_, &buf_ref));
- RawSignatureParser::Description desc;
- desc.type = RawSignatureParser::Type::kBuffer;
- desc.buffer.scalar_type = AbiConstants::ScalarType::kSint32;
- desc.dims = {2, 2};
- std::stringstream os;
- ASSERT_OK(PrintVariantList({desc}, outputs_, &os));
- EXPECT_EQ(os.str(), "2x2xi32=[42 -43][44 45]\n");
-}
-
-TEST_F(VmUtilTest, PrintVariantListMultiBuffer) {
- IREE_ASSERT_OK(
- iree_vm_variant_list_alloc(2, IREE_ALLOCATOR_SYSTEM, &outputs_));
-
- // Buffer 1
- iree_hal_buffer_t* buf1 = nullptr;
- int element_count1 = 4;
- int element_size1 = sizeof(int32_t);
- iree_device_size_t allocation_size1 = element_count1 * element_size1;
- IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
- allocator_,
- static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
- IREE_HAL_BUFFER_USAGE_CONSTANT),
- allocation_size1, &buf1));
- std::array<int32_t, 4> buf1_data{42, 43, 44, 45};
- IREE_ASSERT_OK(iree_hal_buffer_write_data(
- buf1, 0, reinterpret_cast<uint8_t*>(buf1_data.data()), allocation_size1));
- auto buf1_ref = iree_hal_buffer_move_ref(buf1);
- IREE_ASSERT_OK(iree_vm_variant_list_append_ref_move(outputs_, &buf1_ref));
- RawSignatureParser::Description desc1;
- desc1.type = RawSignatureParser::Type::kBuffer;
- desc1.buffer.scalar_type = AbiConstants::ScalarType::kSint32;
- desc1.dims = {2, 2};
-
- // Buffer 2
- iree_hal_buffer_t* buf2 = nullptr;
- int element_count2 = 6;
- int element_size2 = sizeof(double);
- iree_device_size_t allocation_size2 = element_count2 * element_size2;
- IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
- allocator_,
- static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
- IREE_HAL_BUFFER_USAGE_CONSTANT),
- allocation_size2, &buf2));
- std::array<double, 6> buf2_data{1, 2, 3, 4, 5, 6};
- IREE_ASSERT_OK(iree_hal_buffer_write_data(
- buf2, 0, reinterpret_cast<uint8_t*>(buf2_data.data()), allocation_size2));
- auto buf2_ref = iree_hal_buffer_move_ref(buf2);
- IREE_ASSERT_OK(iree_vm_variant_list_append_ref_move(outputs_, &buf2_ref));
- RawSignatureParser::Description desc2;
- desc2.type = RawSignatureParser::Type::kBuffer;
- desc2.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat64;
- desc2.dims = {2, 3};
-
- std::stringstream os;
- ASSERT_OK(PrintVariantList({desc1, desc2}, outputs_, &os));
- EXPECT_EQ(os.str(),
- "2x2xi32=[42 43][44 45]\n"
- "2x3xf64=[1 2 3][4 5 6]\n");
-}
-
-TEST_F(VmUtilTest, ParsePrint) {
+TEST_F(VmUtilTest, ParsePrintBuffer) {
auto buf_string = "2x2xi32=[42 43][44 45]";
RawSignatureParser::Description desc;
desc.type = RawSignatureParser::Type::kBuffer;
@@ -174,6 +57,21 @@
IREE_ASSERT_OK(iree_vm_variant_list_free(variant_list));
}
+TEST_F(VmUtilTest, ParsePrintScalar) {
+ auto input_string = "i32=42";
+ RawSignatureParser::Description desc;
+ desc.type = RawSignatureParser::Type::kScalar;
+ desc.scalar.type = AbiConstants::ScalarType::kSint32;
+
+ ASSERT_OK_AND_ASSIGN(auto* variant_list,
+ ParseToVariantList({desc}, allocator_, {input_string}));
+ std::stringstream os;
+ ASSERT_OK(PrintVariantList({desc}, variant_list, &os));
+ EXPECT_EQ(os.str(), absl::StrCat(input_string, "\n"));
+
+ IREE_ASSERT_OK(iree_vm_variant_list_free(variant_list));
+}
+
TEST_F(VmUtilTest, ParsePrintRank0Buffer) {
auto buf_string = "i32=42";
RawSignatureParser::Description desc;