Converting check+custom_modules to use the HAL buffer print/parse API.
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index 87f83a0..863b189 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -80,7 +80,6 @@
deps = [
"//iree/base:api",
"//iree/base:api_util",
- "//iree/base:buffer_string_util",
"//iree/base:status",
"//iree/hal:api",
"//iree/modules/hal",
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index fb6d66e..804b844 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -80,7 +80,6 @@
absl::strings
iree::base::api
iree::base::api_util
- iree::base::buffer_string_util
iree::base::status
iree::hal::api
iree::modules::hal
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc
index d90d521..4647643 100644
--- a/iree/modules/check/native_module.cc
+++ b/iree/modules/check/native_module.cc
@@ -24,7 +24,6 @@
#include "absl/strings/str_cat.h"
#include "iree/base/api.h"
#include "iree/base/api_util.h"
-#include "iree/base/buffer_string_util.h"
#include "iree/base/status.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
@@ -47,6 +46,22 @@
bytes.data_length / sizeof(T));
}
+StatusOr<std::string> BufferViewToString(iree_hal_buffer_view_t* buffer_view) {
+ std::string result_str(4096, '\0');
+ iree_status_t status;
+ do {
+ iree_host_size_t actual_length = 0;
+ status = iree_hal_buffer_view_format(
+ buffer_view, /*max_element_count=*/1024, result_str.size() + 1,
+ &result_str[0], &actual_length);
+ result_str.resize(actual_length);
+ } while (iree_status_is_out_of_range(status));
+ if (!iree_status_is_ok(status)) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return std::move(result_str);
+}
+
template <typename T>
Status ExpectAllTrue(iree_byte_span_t bytes) {
EXPECT_THAT(AbslSpan<T>(bytes), Each(Not(0)));
@@ -225,6 +240,9 @@
bool shape_eq = lhs_shape == rhs_shape;
bool contents_eq =
EqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents);
+ iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory);
+ iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory);
+
if (!element_types_eq || !shape_eq || !contents_eq) {
std::ostringstream os;
os << "Expected equality of these values.";
@@ -241,43 +259,18 @@
os << "\n"
" lhs:\n"
" ";
- char lhs_element_type_str[16];
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_format_element_type(iree_hal_buffer_view_element_type(lhs),
- sizeof(lhs_element_type_str),
- lhs_element_type_str, nullptr),
- IREE_LOC));
- // TODO(b/146898896): Remove dependence on Shape.
- PrintShapedTypeToStream(Shape{lhs_shape}, lhs_element_type_str, &os);
- os << "=";
- RETURN_IF_ERROR(
- PrintNumericalDataToStream(Shape{lhs_shape}, lhs_element_type_str,
- {lhs_mapped_memory.contents.data,
- lhs_mapped_memory.contents.data_length},
- /*max_entries=*/1024, &os));
+ ASSIGN_OR_RETURN(auto lhs_str, BufferViewToString(lhs));
+ os << lhs_str;
os << "\n"
" rhs:\n"
" ";
- char rhs_element_type_str[16];
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_format_element_type(iree_hal_buffer_view_element_type(rhs),
- sizeof(rhs_element_type_str),
- rhs_element_type_str, nullptr),
- IREE_LOC));
- PrintShapedTypeToStream(Shape{rhs_shape}, rhs_element_type_str, &os);
- os << "=";
- RETURN_IF_ERROR(
- PrintNumericalDataToStream(Shape{rhs_shape}, rhs_element_type_str,
- {rhs_mapped_memory.contents.data,
- rhs_mapped_memory.contents.data_length},
- /*max_entries=*/1024, &os));
+ ASSIGN_OR_RETURN(auto rhs_str, BufferViewToString(rhs));
+ os << rhs_str;
// TODO(b/146898896): Use ADD_FAILURE_AT to propagate source location.
ADD_FAILURE() << os.str();
}
- iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory);
- iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory);
return OkStatus();
}
@@ -328,6 +321,9 @@
AlmostEqByteSpan(lhs_mapped_memory.contents,
rhs_mapped_memory.contents, lhs_element_type));
}
+ iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory);
+ iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory);
+
if (!element_types_eq || !shape_eq || !contents_could_be_almost_eq) {
std::ostringstream os;
os << "Expected near equality of these values.";
@@ -344,43 +340,18 @@
os << "\n"
" lhs:\n"
" ";
- char lhs_element_type_str[16];
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_format_element_type(iree_hal_buffer_view_element_type(lhs),
- sizeof(lhs_element_type_str),
- lhs_element_type_str, nullptr),
- IREE_LOC));
- // TODO(b/146898896): Remove dependence on Shape.
- PrintShapedTypeToStream(Shape{lhs_shape}, lhs_element_type_str, &os);
- os << "=";
- RETURN_IF_ERROR(
- PrintNumericalDataToStream(Shape{lhs_shape}, lhs_element_type_str,
- {lhs_mapped_memory.contents.data,
- lhs_mapped_memory.contents.data_length},
- /*max_entries=*/1024, &os));
+ ASSIGN_OR_RETURN(auto lhs_str, BufferViewToString(lhs));
+ os << lhs_str;
os << "\n"
" rhs:\n"
" ";
- char rhs_element_type_str[16];
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_format_element_type(iree_hal_buffer_view_element_type(rhs),
- sizeof(rhs_element_type_str),
- rhs_element_type_str, nullptr),
- IREE_LOC));
- PrintShapedTypeToStream(Shape{rhs_shape}, rhs_element_type_str, &os);
- os << "=";
- RETURN_IF_ERROR(
- PrintNumericalDataToStream(Shape{rhs_shape}, rhs_element_type_str,
- {rhs_mapped_memory.contents.data,
- rhs_mapped_memory.contents.data_length},
- /*max_entries=*/1024, &os));
+ ASSIGN_OR_RETURN(auto rhs_str, BufferViewToString(rhs));
+ os << rhs_str;
// TODO(b/146898896): Use ADD_FAILURE_AT to propagate source location.
ADD_FAILURE() << os.str();
}
- iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory);
- iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory);
return OkStatus();
}
diff --git a/iree/samples/custom_modules/BUILD b/iree/samples/custom_modules/BUILD
index 411e69c..7a65e51 100644
--- a/iree/samples/custom_modules/BUILD
+++ b/iree/samples/custom_modules/BUILD
@@ -55,9 +55,6 @@
deps = [
"//iree/base:api",
"//iree/base:api_util",
- "//iree/base:buffer_string_util",
- "//iree/base:shape",
- "//iree/base:shaped_buffer_string_util",
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
diff --git a/iree/samples/custom_modules/CMakeLists.txt b/iree/samples/custom_modules/CMakeLists.txt
index 0998857..44224a4 100644
--- a/iree/samples/custom_modules/CMakeLists.txt
+++ b/iree/samples/custom_modules/CMakeLists.txt
@@ -59,9 +59,6 @@
DEPS
iree::base::api
iree::base::api_util
- iree::base::buffer_string_util
- iree::base::shape
- iree::base::shaped_buffer_string_util
iree::hal::api
iree::modules::hal
iree::vm
diff --git a/iree/samples/custom_modules/native_module.cc b/iree/samples/custom_modules/native_module.cc
index 93ffcf9..ade2f9c 100644
--- a/iree/samples/custom_modules/native_module.cc
+++ b/iree/samples/custom_modules/native_module.cc
@@ -19,9 +19,6 @@
#include "iree/base/api.h"
#include "iree/base/api_util.h"
-#include "iree/base/buffer_string_util.h"
-#include "iree/base/shape.h"
-#include "iree/base/shaped_buffer_string_util.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/module_abi_cc.h"
@@ -133,56 +130,36 @@
Status Initialize(int32_t unique_id) {
// Allocate a unique ID to demonstrate per-context state.
auto str_buffer = "ctx_" + std::to_string(unique_id);
- return FromApiStatus(
+ RETURN_IF_ERROR(FromApiStatus(
iree_custom_message_create(iree_make_cstring_view(str_buffer.c_str()),
allocator_, &unique_message_),
- IREE_LOC);
+ IREE_LOC));
+
+ // Setup a host-local allocator we can use because this sample doesn't have
+ // a real device allocator.
+ RETURN_IF_ERROR(FromApiStatus(iree_hal_allocator_create_host_local(
+ allocator_, &host_local_allocator_),
+ IREE_LOC));
+
+ return OkStatus();
}
// custom.buffer_to_message(%buffer_view) -> %result
StatusOr<vm::ref<iree_custom_message_t>> BufferToMessage(
vm::ref<iree_hal_buffer_view_t> buffer_view) {
- iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view.get());
-
- // Map the buffer memory so we can read it back.
- iree_hal_mapped_memory_t mapped_memory;
- RETURN_IF_ERROR(
- FromApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ,
- 0, IREE_WHOLE_BUFFER, &mapped_memory),
- IREE_LOC));
-
- // NOTE: these string methods take the old Shape type and as such have a
- // rank limit. That limit is just an artifact of those APIs, not the
- // buffer view shape type.
- absl::InlinedVector<int32_t, kMaxRank> shape(kMaxRank);
- size_t rank = 0;
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_buffer_view_shape(buffer_view.get(), shape.capacity(),
- shape.data(), &rank),
- IREE_LOC));
- shape.resize(rank);
- char element_type_str[16];
- RETURN_IF_ERROR(
- FromApiStatus(iree_hal_format_element_type(
- iree_hal_buffer_view_element_type(buffer_view.get()),
- sizeof(element_type_str), element_type_str, nullptr),
- IREE_LOC));
-
- // Print the buffer contents using our helpers.
- std::string string_value;
- RETURN_IF_ERROR(PrintNumericalDataToString(
- Shape{shape}, element_type_str,
- {mapped_memory.contents.data, mapped_memory.contents.data_length},
- /*max_entries=*/1024, &string_value));
-
- // Prefix shape/type.
- string_value =
- absl::StrCat(PrintShapedTypeToString(Shape{shape}, element_type_str),
- "=", string_value);
-
- // Unmap the buffer when we are done with it.
- RETURN_IF_ERROR(
- FromApiStatus(iree_hal_buffer_unmap(buffer, &mapped_memory), IREE_LOC));
+ // Convert the buffer view to a [shape]x[type]=[contents] string.
+ std::string string_value(4096, '\0');
+ iree_status_t status;
+ do {
+ iree_host_size_t actual_length = 0;
+ status = iree_hal_buffer_view_format(
+ buffer_view.get(), /*max_element_count=*/1024,
+ string_value.size() + 1, &string_value[0], &actual_length);
+ string_value.resize(actual_length);
+ } while (iree_status_is_out_of_range(status));
+ if (!iree_status_is_ok(status)) {
+ return FromApiStatus(status, IREE_LOC);
+ }
// Pack the string contents into a message.
vm::ref<iree_custom_message_t> message;
@@ -197,57 +174,17 @@
// custom.message_to_buffer(%message) -> %buffer_view
StatusOr<vm::ref<iree_hal_buffer_view_t>> MessageToBuffer(
vm::ref<iree_custom_message_t> message) {
- // NOTE: these old-style parsing routines need to be updated for the new
- // type system. They use different types, different shapes, etc.
- auto str_parts = BufferStringParts::ExtractFrom(
- absl::string_view(message->value.data, message->value.size));
- iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_parse_element_type(
- {str_parts.type_str.data(), str_parts.type_str.size()},
- &element_type),
- IREE_LOC));
- ASSIGN_OR_RETURN(auto shape, ParseShape(str_parts.shape_str));
-
- // TODO(benvanik): plumb through an allocator we can use.
- size_t allocation_size =
- shape.element_count() * iree_hal_element_byte_count(element_type);
- vm::ref<iree_hal_buffer_t> buffer;
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_heap_buffer_allocate(
- IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
- static_cast<iree_hal_buffer_usage_t>(
- IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT),
- allocation_size, IREE_ALLOCATOR_SYSTEM, IREE_ALLOCATOR_SYSTEM,
- &buffer),
- IREE_LOC));
- if (!str_parts.data_str.empty()) {
- // Map the buffer memory so we can write it with the data contents.
- iree_hal_mapped_memory_t mapped_memory;
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_buffer_map(buffer.get(),
- IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
- IREE_WHOLE_BUFFER, &mapped_memory),
- IREE_LOC));
-
- // Parse the data from the string right into the buffer.
- RETURN_IF_ERROR(ParseBufferDataAsType(
- str_parts.data_str, str_parts.type_str,
- absl::MakeSpan(mapped_memory.contents.data,
- mapped_memory.contents.data_length)));
-
- // Unmap the buffer when we are done with it.
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_buffer_unmap(buffer.get(), &mapped_memory), IREE_LOC));
- }
-
- // Wrap in a buffer view to pass back into the VM.
+ // Convert the [shape]x[type]=[contents] string to a buffer view.
+ auto input_string =
+ absl::string_view(message->value.data, message->value.size);
vm::ref<iree_hal_buffer_view_t> buffer_view;
- RETURN_IF_ERROR(FromApiStatus(
- iree_hal_buffer_view_create(buffer.get(), shape.data().data(),
- shape.size(), element_type,
- IREE_ALLOCATOR_SYSTEM, &buffer_view),
- IREE_LOC));
+ iree_status_t status = iree_hal_buffer_view_parse(
+ iree_string_view_t{input_string.data(), input_string.size()},
+ host_local_allocator_.get(), allocator_, &buffer_view);
+ if (!iree_status_is_ok(status)) {
+ return FromApiStatus(status, IREE_LOC)
+ << "Parsing value '" << input_string << "'";
+ }
return std::move(buffer_view);
}
@@ -289,6 +226,11 @@
// perform during operation.
iree_allocator_t allocator_ = IREE_ALLOCATOR_SYSTEM;
+ // HAL buffer allocator that uses host-local memory. This is just for this
+ // test as we don't actually use a HAL device and don't have a real device
+ // allocator to use.
+ vm::ref<iree_hal_allocator_t> host_local_allocator_;
+
// A unique message owned by the state and returned to the VM.
// This demonstrates any arbitrary per-context state one may want to store.
vm::ref<iree_custom_message_t> unique_message_;