Fixing check test methods to use the buffer view byte length.
Previously they were assuming the entire underlying storage buffer was
equal to the valid byte length which caused problems now that we are
packing multiple buffers together.
diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h
index fefe383..6d849fe 100644
--- a/iree/hal/buffer_view.h
+++ b/iree/hal/buffer_view.h
@@ -147,6 +147,11 @@
// Returns the buffer underlying the buffer view.
// The caller must retain the returned buffer if they want to continue using it.
+//
+// NOTE: the returned buffer length will almost always be larger than the valid
+// bytes representing this buffer view due to padding. Always query the actual
+// valid length with iree_hal_buffer_view_byte_length instead of assuming the
+// buffer is already clamped.
IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL
iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view);
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc
index 64f9ac2..c684296 100644
--- a/iree/modules/check/native_module.cc
+++ b/iree/modules/check/native_module.cc
@@ -201,10 +201,11 @@
iree_hal_element_type_t element_type =
iree_hal_buffer_view_element_type(view);
iree_hal_buffer_t* buf = iree_hal_buffer_view_buffer(view);
+ iree_device_size_t size = iree_hal_buffer_view_byte_length(view);
iree_hal_buffer_mapping_t mapped_memory;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
- buf, IREE_HAL_MEMORY_ACCESS_READ,
- /*byte_offset=*/0, IREE_WHOLE_BUFFER, &mapped_memory));
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_map_range(buf, IREE_HAL_MEMORY_ACCESS_READ,
+ /*byte_offset=*/0, size, &mapped_memory));
IREE_RETURN_IF_ERROR(
::iree::ExpectAllTrue(mapped_memory.contents, element_type));
iree_hal_buffer_unmap_range(&mapped_memory);
@@ -215,6 +216,8 @@
vm::ref<iree_hal_buffer_view_t> rhs_ref) {
auto* lhs = lhs_ref.get();
auto* rhs = rhs_ref.get();
+
+ iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs);
size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs);
std::vector<iree_hal_dim_t> lhs_shape(lhs_rank);
if (lhs_rank > 0) {
@@ -222,6 +225,7 @@
iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr));
}
+ iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs);
size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs);
std::vector<iree_hal_dim_t> rhs_shape(rhs_rank);
if (rhs_rank > 0) {
@@ -238,12 +242,12 @@
iree_hal_buffer_mapping_t lhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
lhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
- /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory));
+ /*byte_offset=*/0, lhs_size, &lhs_mapped_memory));
iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs);
iree_hal_buffer_mapping_t rhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
rhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
- /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory));
+ /*byte_offset=*/0, rhs_size, &rhs_mapped_memory));
bool element_types_eq = lhs_element_type == rhs_element_type;
bool shape_eq = lhs_shape == rhs_shape;
@@ -288,6 +292,8 @@
vm::ref<iree_hal_buffer_view_t> rhs_ref) {
auto* lhs = lhs_ref.get();
auto* rhs = rhs_ref.get();
+
+ iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs);
size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs);
std::vector<iree_hal_dim_t> lhs_shape(lhs_rank);
if (lhs_rank > 0) {
@@ -295,6 +301,7 @@
iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr));
}
+ iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs);
size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs);
std::vector<iree_hal_dim_t> rhs_shape(rhs_rank);
if (rhs_rank > 0) {
@@ -311,12 +318,12 @@
iree_hal_buffer_mapping_t lhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
lhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
- /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory));
+ /*byte_offset=*/0, lhs_size, &lhs_mapped_memory));
iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs);
iree_hal_buffer_mapping_t rhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
rhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
- /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory));
+ /*byte_offset=*/0, rhs_size, &rhs_mapped_memory));
bool element_types_eq = lhs_element_type == rhs_element_type;
bool shape_eq = lhs_shape == rhs_shape;