[metal] Fix staging buffer alignment calculation (#15272)
We need to align the offset to the required alignment when grabbing
space from the staging buffer for argument buffer encoding. This fixes
the following API validation error:
```
validateComputeFunctionArguments:1077: failed assertion `Compute Function(_abs):
the offset into the buffer spvDescriptorSet0 that is bound at buffer index 0
must be a multiple of 8 but was set to 4.'
```
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.m b/runtime/src/iree/hal/drivers/metal/staging_buffer.m
index 7e868ac..ca0128f 100644
--- a/runtime/src/iree/hal/drivers/metal/staging_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.m
@@ -54,8 +54,7 @@
iree_host_size_t alignment,
iree_byte_span_t* out_reservation,
uint32_t* out_offset) {
- iree_host_size_t aligned_length = iree_host_align(length, alignment);
- if (aligned_length > staging_buffer->capacity) {
+ if (length > staging_buffer->capacity) {
// This will never fit in the staging buffer.
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"reservation (%" PRIhsz " bytes) exceeds the maximum capacity of "
@@ -64,17 +63,17 @@
}
iree_slim_mutex_lock(&staging_buffer->offset_mutex);
- uint32_t offset = staging_buffer->offset;
- if (offset + aligned_length > staging_buffer->capacity) {
+ uint32_t aligned_offset = iree_host_align(staging_buffer->offset, alignment);
+ if (aligned_offset + length > staging_buffer->capacity) {
iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"failed to reserve %" PRIhsz " bytes in staging buffer", length);
}
- staging_buffer->offset += aligned_length;
+ staging_buffer->offset = aligned_offset + length;
iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
- *out_reservation = iree_make_byte_span(staging_buffer->host_buffer + offset, aligned_length);
- *out_offset = offset;
+ *out_reservation = iree_make_byte_span(staging_buffer->host_buffer + aligned_offset, length);
+ *out_offset = aligned_offset;
return iree_ok_status();
}