[metal] Make staging buffer impl thread safe (#14624)

This commit guards the staging buffer's write offset with a mutex and
makes its pending command buffer count atomic.

Progress towards https://github.com/openxla/iree/issues/14049
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index a833ed9..d3b52c6 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -395,7 +395,7 @@
 
     // Increase command buffer refcount in the shared staging buffer. We tie this to the command
     // buffer's lifetime to avoid resource leak.
-    iree_hal_metal_staging_buffer_increase_refcount(staging_buffer);
+    iree_hal_metal_staging_buffer_increase_command_buffer_refcount(staging_buffer);
     // Retain the device given that we refer to builtin executables and staging buffers whose
     // lifetime is associated with the device.
     iree_hal_resource_retain(device);
@@ -432,7 +432,7 @@
   // Decrease command buffer refcount in the shared staging buffer, and potentially reclaim
   // resources. We tie this to the command buffer's lifetime to avoid resource leak.
   if (command_buffer->staging_buffer) {
-    iree_hal_metal_staging_buffer_decrease_refcount(command_buffer->staging_buffer);
+    iree_hal_metal_staging_buffer_decrease_command_buffer_refcount(command_buffer->staging_buffer);
   }
 
   iree_hal_metal_command_buffer_destroy_internal(base_command_buffer);
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.h b/runtime/src/iree/hal/drivers/metal/staging_buffer.h
index 63c6ea1..8217a3b 100644
--- a/runtime/src/iree/hal/drivers/metal/staging_buffer.h
+++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.h
@@ -10,6 +10,7 @@
 #import <Metal/Metal.h>
 
 #include "iree/base/api.h"
+#include "iree/base/internal/synchronization.h"
 #include "iree/hal/api.h"
 
 #ifdef __cplusplus
@@ -40,8 +41,7 @@
 // * Argument buffers for descriptor sets
 // * Source buffer for buffer update commands
 //
-// TODO(#14049): Use proper atomics/mutexes for concurrent command buffer
-// recording and execution.
+// Thread-safe; multiple threads can reserve space concurrently.
 typedef struct iree_hal_metal_staging_buffer_t {
   // Maximum number of bytes in the buffer.
   uint32_t capacity;
@@ -51,14 +51,17 @@
   // Host pointer to the buffer.
   uint8_t* host_buffer;
 
+  // Non-recursive mutex guarding access to the offset field.
+  iree_slim_mutex_t offset_mutex;
+
   // Current write offset of the device buffer.
-  uint32_t offset;
+  uint32_t offset IREE_GUARDED_BY(offset_mutex);
 
   // The number of command buffers that are being recorded or executed on
   // device. If this reaches zero, we know that there are no users of the
   // staging buffer so we can discard the contents and reset the offset to
   // zero.
-  uint32_t pending_command_buffers;
+  iree_atomic_int32_t pending_command_buffers;
 } iree_hal_metal_staging_buffer_t;
 
 // Initializes |out_staging_buffer| with the given |buffer_capacity|.
@@ -87,12 +90,12 @@
     iree_hal_metal_staging_buffer_t* staging_buffer);
 
 // Increases the command buffer using this staging buffer by one.
-void iree_hal_metal_staging_buffer_increase_refcount(
+void iree_hal_metal_staging_buffer_increase_command_buffer_refcount(
     iree_hal_metal_staging_buffer_t* staging_buffer);
 
 // Decreases the command buffer using this staging buffer by one, which may
 // trigger reclaiming of resources.
-void iree_hal_metal_staging_buffer_decrease_refcount(
+void iree_hal_metal_staging_buffer_decrease_command_buffer_refcount(
     iree_hal_metal_staging_buffer_t* staging_buffer);
 
 #ifdef __cplusplus
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.m b/runtime/src/iree/hal/drivers/metal/staging_buffer.m
index a7eb806..7e868ac 100644
--- a/runtime/src/iree/hal/drivers/metal/staging_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.m
@@ -35,14 +35,17 @@
   out_staging_buffer->capacity = (uint32_t)buffer_capacity;
   out_staging_buffer->metal_buffer = metal_buffer;
   out_staging_buffer->host_buffer = metal_buffer.contents;
+  iree_slim_mutex_initialize(&out_staging_buffer->offset_mutex);
   out_staging_buffer->offset = 0;
-  out_staging_buffer->pending_command_buffers = 0;
+  iree_atomic_store_int32(&out_staging_buffer->pending_command_buffers, 0,
+                          iree_memory_order_relaxed);
 
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
 }
 
 void iree_hal_metal_staging_buffer_deinitialize(iree_hal_metal_staging_buffer_t* staging_buffer) {
+  iree_slim_mutex_deinitialize(&staging_buffer->offset_mutex);
   [staging_buffer->metal_buffer release];  // -1
 }
 
@@ -58,14 +61,21 @@
                             "reservation (%" PRIhsz " bytes) exceeds the maximum capacity of "
                             "the staging buffer (%" PRIu32 " bytes)",
                             length, staging_buffer->capacity);
-  } else if (staging_buffer->offset + aligned_length > staging_buffer->capacity) {
+  }
+
+  iree_slim_mutex_lock(&staging_buffer->offset_mutex);
+  uint32_t offset = staging_buffer->offset;
+  if (offset + aligned_length > staging_buffer->capacity) {
+    iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
     return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                             "failed to reserve %" PRIhsz " bytes in staging buffer", length);
   }
-  *out_reservation =
-      iree_make_byte_span(staging_buffer->host_buffer + staging_buffer->offset, aligned_length);
-  *out_offset = staging_buffer->offset;
   staging_buffer->offset += aligned_length;
+  iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
+
+  *out_reservation = iree_make_byte_span(staging_buffer->host_buffer + offset, aligned_length);
+  *out_offset = offset;
+
   return iree_ok_status();
 }
 
@@ -81,18 +91,21 @@
 }
 
 void iree_hal_metal_staging_buffer_reset(iree_hal_metal_staging_buffer_t* staging_buffer) {
+  iree_slim_mutex_lock(&staging_buffer->offset_mutex);
   staging_buffer->offset = 0;
+  iree_slim_mutex_unlock(&staging_buffer->offset_mutex);
 }
 
-void iree_hal_metal_staging_buffer_increase_refcount(
+void iree_hal_metal_staging_buffer_increase_command_buffer_refcount(
     iree_hal_metal_staging_buffer_t* staging_buffer) {
-  ++staging_buffer->pending_command_buffers;
+  iree_atomic_fetch_add_int32(&staging_buffer->pending_command_buffers, 1,
+                              iree_memory_order_relaxed);
 }
 
-void iree_hal_metal_staging_buffer_decrease_refcount(
+void iree_hal_metal_staging_buffer_decrease_command_buffer_refcount(
     iree_hal_metal_staging_buffer_t* staging_buffer) {
-  IREE_ASSERT(staging_buffer->pending_command_buffers > 0);
-  if (--staging_buffer->pending_command_buffers == 0) {
+  if (iree_atomic_fetch_sub_int32(&staging_buffer->pending_command_buffers, 1,
+                                  iree_memory_order_acq_rel) == 1) {
     iree_hal_metal_staging_buffer_reset(staging_buffer);
   }
 }