[metal] Make staging buffer impl thread safe (#14624)
This commit guards the staging buffer's write offset with a mutex and
makes its pending command buffer count an atomic.
Progress towards https://github.com/openxla/iree/issues/14049
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index a833ed9..d3b52c6 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -395,7 +395,7 @@
// Increase command buffer refcount in the shared staging buffer. We tie this to the command
// buffer's lifetime to avoid resource leak.
- iree_hal_metal_staging_buffer_increase_refcount(staging_buffer);
+ iree_hal_metal_staging_buffer_increase_command_buffer_refcount(staging_buffer);
// Retain the device given that we refer to builtin executables and staging buffers whose
// lifetime is associated with the device.
iree_hal_resource_retain(device);
@@ -432,7 +432,7 @@
// Decrease command buffer refcount in the shared staging buffer, and potentially reclaim
// resources. We tie this to the command buffer's lifetime to avoid resource leak.
if (command_buffer->staging_buffer) {
- iree_hal_metal_staging_buffer_decrease_refcount(command_buffer->staging_buffer);
+ iree_hal_metal_staging_buffer_decrease_command_buffer_refcount(command_buffer->staging_buffer);
}
iree_hal_metal_command_buffer_destroy_internal(base_command_buffer);
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.h b/runtime/src/iree/hal/drivers/metal/staging_buffer.h
index 63c6ea1..8217a3b 100644
--- a/runtime/src/iree/hal/drivers/metal/staging_buffer.h
+++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.h
@@ -10,6 +10,7 @@
#import <Metal/Metal.h>
#include "iree/base/api.h"
+#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"
#ifdef __cplusplus
@@ -40,8 +41,7 @@
// * Argument buffers for descriptor sets
// * Source buffer for buffer update commands
//
-// TODO(#14049): Use proper atomics/mutexes for concurrent command buffer
-// recording and execution.
+// Thread-safe; multiple threads can reserve space concurrently.
typedef struct iree_hal_metal_staging_buffer_t {
// Maximum number of bytes in the buffer.
uint32_t capacity;
@@ -51,14 +51,17 @@
// Host pointer to the buffer.
uint8_t* host_buffer;
+ // Non-recursive mutex guarding access to the offset field.
+ iree_slim_mutex_t offset_mutex;
+
// Current write offset of the device buffer.
- uint32_t offset;
+ uint32_t offset IREE_GUARDED_BY(offset_mutex);
// The number of command buffers that are being recorded or executed on
// device. If this reaches zero, we know that there are no users of the
// staging buffer so we can discard the contents and reset the offset to
// zero.
- uint32_t pending_command_buffers;
+ iree_atomic_int32_t pending_command_buffers;
} iree_hal_metal_staging_buffer_t;
// Initializes |out_staging_buffer| with the given |buffer_capacity|.
@@ -87,12 +90,12 @@
iree_hal_metal_staging_buffer_t* staging_buffer);
// Increases the command buffer using this staging buffer by one.
-void iree_hal_metal_staging_buffer_increase_refcount(
+void iree_hal_metal_staging_buffer_increase_command_buffer_refcount(
iree_hal_metal_staging_buffer_t* staging_buffer);
// Decreases the command buffer using this staging buffer by one, which may
// trigger reclaiming of resources.
-void iree_hal_metal_staging_buffer_decrease_refcount(
+void iree_hal_metal_staging_buffer_decrease_command_buffer_refcount(
iree_hal_metal_staging_buffer_t* staging_buffer);
#ifdef __cplusplus
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.m b/runtime/src/iree/hal/drivers/metal/staging_buffer.m
index a7eb806..7e868ac 100644
--- a/runtime/src/iree/hal/drivers/metal/staging_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.m
@@ -35,14 +35,17 @@
out_staging_buffer->capacity = (uint32_t)buffer_capacity;
out_staging_buffer->metal_buffer = metal_buffer;
out_staging_buffer->host_buffer = metal_buffer.contents;
+ iree_slim_mutex_initialize(&out_staging_buffer->offset_mutex);
out_staging_buffer->offset = 0;
- out_staging_buffer->pending_command_buffers = 0;
+ iree_atomic_store_int32(&out_staging_buffer->pending_command_buffers, 0,
+ iree_memory_order_relaxed);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
void iree_hal_metal_staging_buffer_deinitialize(iree_hal_metal_staging_buffer_t* staging_buffer) {
+ iree_slim_mutex_deinitialize(&staging_buffer->offset_mutex);
[staging_buffer->metal_buffer release]; // -1
}
@@ -58,14 +61,21 @@
"reservation (%" PRIhsz " bytes) exceeds the maximum capacity of "
"the staging buffer (%" PRIu32 " bytes)",
length, staging_buffer->capacity);
- } else if (staging_buffer->offset + aligned_length > staging_buffer->capacity) {
+ }
+
+ iree_slim_mutex_lock(&staging_buffer->offset_mutex);
+ uint32_t offset = staging_buffer->offset;
+ if (offset + aligned_length > staging_buffer->capacity) {
+ iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"failed to reserve %" PRIhsz " bytes in staging buffer", length);
}
- *out_reservation =
- iree_make_byte_span(staging_buffer->host_buffer + staging_buffer->offset, aligned_length);
- *out_offset = staging_buffer->offset;
staging_buffer->offset += aligned_length;
+ iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
+
+ *out_reservation = iree_make_byte_span(staging_buffer->host_buffer + offset, aligned_length);
+ *out_offset = offset;
+
return iree_ok_status();
}
@@ -81,18 +91,21 @@
}
void iree_hal_metal_staging_buffer_reset(iree_hal_metal_staging_buffer_t* staging_buffer) {
+ iree_slim_mutex_lock(&staging_buffer->offset_mutex);
staging_buffer->offset = 0;
+ iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
}
-void iree_hal_metal_staging_buffer_increase_refcount(
+void iree_hal_metal_staging_buffer_increase_command_buffer_refcount(
iree_hal_metal_staging_buffer_t* staging_buffer) {
- ++staging_buffer->pending_command_buffers;
+ iree_atomic_fetch_add_int32(&staging_buffer->pending_command_buffers, 1,
+ iree_memory_order_relaxed);
}
-void iree_hal_metal_staging_buffer_decrease_refcount(
+void iree_hal_metal_staging_buffer_decrease_command_buffer_refcount(
iree_hal_metal_staging_buffer_t* staging_buffer) {
- IREE_ASSERT(staging_buffer->pending_command_buffers > 0);
- if (--staging_buffer->pending_command_buffers == 0) {
+ if (iree_atomic_fetch_sub_int32(&staging_buffer->pending_command_buffers, 1,
+ iree_memory_order_acq_rel) == 1) {
iree_hal_metal_staging_buffer_reset(staging_buffer);
}
}