[metal] Use staging buffer for argument buffers and update sources

This commit switches the Metal HAL driver to use a staging buffer
for recording argument buffers and uploading buffer update source
data. This avoids creating lots of small buffers as we did
previously, and avoids relying on command buffer completion
callbacks to manage their lifetime.
diff --git a/experimental/metal/CMakeLists.txt b/experimental/metal/CMakeLists.txt
index 0b0fc22..9552b58 100644
--- a/experimental/metal/CMakeLists.txt
+++ b/experimental/metal/CMakeLists.txt
@@ -35,6 +35,8 @@
     "nop_executable_cache.m"
     "pipeline_layout.h"
     "pipeline_layout.m"
+    "staging_buffer.h"
+    "staging_buffer.m"
   DEPS
     iree::base
     iree::base::core_headers
diff --git a/experimental/metal/api.h b/experimental/metal/api.h
index 0a5112b..b2c9728 100644
--- a/experimental/metal/api.h
+++ b/experimental/metal/api.h
@@ -46,11 +46,16 @@
 // Must be initialized with iree_hal_metal_device_params_initialize prior to
 // use.
 typedef struct iree_hal_metal_device_params_t {
-  // Total size of each block in the device shared block pool.
+  // Total size in bytes of each block in the device shared block pool.
   // Larger sizes will lower overhead and ensure the heap isn't hit for
   // transient allocations while also increasing memory consumption.
   iree_host_size_t arena_block_size;
 
+  // Total size in bytes of per-queue uniform buffers for uploading parameters
+  // to the GPU (including argument buffers and update source buffers).
+  // Larger sizes better support more concurrent/complex command buffers.
+  iree_host_size_t queue_uniform_buffer_size;
+
   // Command dispatch type in command buffers.
   // Normally we want to dispatch commands in command buffers in parallel, given
   // that IREE performs explicit dependency tracking and synchronization by
diff --git a/experimental/metal/direct_command_buffer.h b/experimental/metal/direct_command_buffer.h
index 09c1478..3c929e5 100644
--- a/experimental/metal/direct_command_buffer.h
+++ b/experimental/metal/direct_command_buffer.h
@@ -11,6 +11,7 @@
 
 #include "experimental/metal/api.h"
 #include "experimental/metal/builtin_executables.h"
+#include "experimental/metal/staging_buffer.h"
 #include "iree/base/internal/arena.h"
 #include "iree/hal/api.h"
 
@@ -43,6 +44,7 @@
         resource_reference_mode,
     id<MTLCommandQueue> queue, iree_allocator_t host_allocator,
     iree_arena_block_pool_t* block_pool,
+    iree_hal_metal_staging_buffer_t* staging_buffer,
     iree_hal_metal_builtin_executable_t* builtin_executable,
     iree_hal_command_buffer_t** out_command_buffer);
 
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m
index 7117969..a6f3fbf 100644
--- a/experimental/metal/direct_command_buffer.m
+++ b/experimental/metal/direct_command_buffer.m
@@ -13,6 +13,7 @@
 #include "experimental/metal/metal_device.h"
 #include "experimental/metal/metal_kernel_library.h"
 #include "experimental/metal/pipeline_layout.h"
+#include "experimental/metal/staging_buffer.h"
 #include "iree/base/api.h"
 #include "iree/base/target_platform.h"
 #include "iree/base/tracing.h"
@@ -168,6 +169,10 @@
   // Arena used for all allocations; references the shared device block pool.
   iree_arena_allocator_t arena;
 
+  // Per-queue shared uniform staging buffer for uploading parameters to the GPU, including argument
+  // buffers and buffer update source buffers.
+  iree_hal_metal_staging_buffer_t* staging_buffer;
+
   // Linked list of command segments to be recorded into a command buffer.
   iree_hal_metal_command_segment_list_t segments;
 
@@ -319,6 +324,7 @@
     iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity,
     iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode,
     id<MTLCommandQueue> queue, iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool,
+    iree_hal_metal_staging_buffer_t* staging_buffer,
     iree_hal_metal_builtin_executable_t* builtin_executable,
     iree_hal_command_buffer_t** out_command_buffer) {
   IREE_ASSERT_ARGUMENT(device);
@@ -344,6 +350,7 @@
     command_buffer->queue = [queue retain];  // +1
     command_buffer->builtin_executable = builtin_executable;
     iree_arena_initialize(block_pool, &command_buffer->arena);
+    command_buffer->staging_buffer = staging_buffer;
     iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
     @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
       // We track resource lifetime by ourselves in IREE; so just do unretained references to
@@ -737,6 +744,7 @@
         segment->target_buffer, segment->target_offset, segment->length);
   }
 
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
@@ -744,20 +752,19 @@
     iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
     iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
     iree_device_size_t target_offset, iree_device_size_t length) {
-  // There are no direct corresponding APIs in Metal. We emulate it by creating a buffer with the
-  // content and then copy it over.
   iree_hal_metal_command_buffer_t* command_buffer =
       iree_hal_metal_command_buffer_cast(base_command_buffer);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  id<MTLDevice> device = command_buffer->command_buffer.device;
-  MTLResourceOptions options = MTLResourceStorageModeShared | MTLResourceCPUCacheModeWriteCombined;
-  id<MTLBuffer> data_buffer = [device newBufferWithBytes:((uint8_t*)source_buffer + source_offset)
-                                                  length:length
-                                                 options:options];  // +1
-  [command_buffer->command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cmdbuf) {
-    [data_buffer release];  // -1
-  }];
+  // There is no directly corresponding API in Metal. We copy the source data into the staging
+  // buffer and then record a copy from there into the target buffer.
+
+  iree_const_byte_span_t source_data_span =
+      iree_make_const_byte_span((uint8_t*)source_buffer + source_offset, length);
+  uint32_t offset = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_metal_staging_buffer_append(command_buffer->staging_buffer, source_data_span,
+                                               /*alignment=*/4, &offset));
 
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer));
@@ -767,8 +774,8 @@
   target_offset += iree_hal_buffer_byte_offset(target_buffer);
 
   iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer(
-      command_buffer, data_buffer, /*source_offset=*/0, target_device_buffer, target_offset,
-      length);
+      command_buffer, command_buffer->staging_buffer->metal_buffer, offset, target_device_buffer,
+      target_offset, length);
 
   IREE_TRACE_ZONE_END(z0);
   return status;
@@ -931,37 +938,6 @@
   return iree_ok_status();
 }
 
-// Creates an argument encoder and its backing argument buffer for the given kernel |function|'s
-// |buffer_index|. The argument encoder will be set to encode into the newly created argument
-// buffer. Callers are expected to release both the argument encoder and buffer.
-static iree_status_t iree_hal_metal_create_argument_encoder(
-    id<MTLDevice> device, id<MTLCommandBuffer> command_buffer, id<MTLFunction> function,
-    uint32_t buffer_index, id<MTLArgumentEncoder>* out_encoder, id<MTLBuffer>* out_buffer) {
-  id<MTLArgumentEncoder> argument_encoder =
-      [function newArgumentEncoderWithBufferIndex:buffer_index];  // +1
-  IREE_ASSERT(argument_encoder != nil);
-
-  __block id<MTLBuffer> argument_buffer =
-      [device newBufferWithLength:argument_encoder.encodedLength
-                          options:MTLResourceStorageModeShared];  // +1
-  if (!argument_buffer) {
-    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
-                            "failed to create argument buffer with size = %ld bytes",
-                            argument_encoder.encodedLength);
-  }
-
-  // The arugment encoder and buffer can be deleted once the command buffer completes.
-  [command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cmdbuf) {
-    [argument_buffer release];   // -1
-    [argument_encoder release];  // -1
-  }];
-
-  [argument_encoder setArgumentBuffer:argument_buffer offset:0];
-  *out_encoder = argument_encoder;
-  *out_buffer = argument_buffer;
-  return iree_ok_status();
-}
-
 // Prepares kernels and argument buffers needed for kernel dispatches.
 static iree_status_t iree_hal_metal_command_segment_create_dispatch(
     iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
@@ -1033,17 +1009,23 @@
 
   // Record argument buffers for all descriptors and record buffer usages.
   iree_hal_metal_descriptor_t* descriptors = segment->descriptors;
-  iree_host_size_t i = 0;
-  while (i < segment->descriptor_count) {
+  for (iree_host_size_t i = 0; i < segment->descriptor_count;) {
     uint32_t current_set = descriptors[i].set;
 
     // Build argument encoder and argument buffer for the current descriptor set.
-    id<MTLArgumentEncoder> argument_encoder;
-    id<MTLBuffer> argument_buffer;
+    id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer;
+    id<MTLArgumentEncoder> argument_encoder =
+        [segment->kernel_params.function newArgumentEncoderWithBufferIndex:current_set];  // +1
+    IREE_ASSERT(argument_encoder != nil);
+
+    // Reserve space for the argument buffer from shared staging buffer.
+    iree_byte_span_t reservation;
+    uint32_t argument_buffer_offset;
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hal_metal_create_argument_encoder(
-                command_buffer->command_buffer.device, command_buffer->command_buffer,
-                segment->kernel_params.function, current_set, &argument_encoder, &argument_buffer));
+        z0, iree_hal_metal_staging_buffer_reserve(
+                command_buffer->staging_buffer, argument_encoder.encodedLength,
+                argument_encoder.alignment, &reservation, &argument_buffer_offset));
+    [argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset];
 
     // Now record all bound buffers belonging to the current set into the argument buffer.
     for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) {
@@ -1058,7 +1040,9 @@
       [compute_encoder useResource:current_buffer usage:descriptors[i].usage];
     }
     // Record the argument buffer.
-    [compute_encoder setBuffer:argument_buffer offset:0 atIndex:current_set];
+    [compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:current_set];
+
+    [argument_encoder release];  // -1
   }
 
   // Record the dispatch, either direct or indirect.
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index 3ce0f33..70a775f 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -13,6 +13,7 @@
 #include "experimental/metal/metal_shared_event.h"
 #include "experimental/metal/nop_executable_cache.h"
 #include "experimental/metal/pipeline_layout.h"
+#include "experimental/metal/staging_buffer.h"
 #include "iree/base/api.h"
 #include "iree/base/tracing.h"
 #include "iree/hal/api.h"
@@ -29,8 +30,10 @@
   // contain inlined data uploads).
   iree_arena_block_pool_t block_pool;
 
-  // Original driver that owns this device.
-  iree_hal_driver_t* driver;
+  // Per-queue staging buffer for parameter uploads.
+  iree_hal_metal_staging_buffer_t staging_buffer;
+
+  iree_hal_metal_device_params_t params;
 
   iree_allocator_t host_allocator;
   iree_hal_allocator_t* device_allocator;
@@ -69,6 +72,7 @@
 void iree_hal_metal_device_params_initialize(iree_hal_metal_device_params_t* out_params) {
   memset(out_params, 0, sizeof(*out_params));
   out_params->arena_block_size = 32 * 1024;
+  out_params->queue_uniform_buffer_size = IREE_HAL_METAL_STAGING_BUFFER_DEFAULT_CAPACITY;
   out_params->command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
   out_params->command_buffer_resource_reference_mode =
       IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED;
@@ -98,6 +102,7 @@
   iree_status_t status = iree_hal_metal_allocator_create((iree_hal_device_t*)device, metal_device,
                                                          params->resource_hazard_tracking_mode,
                                                          host_allocator, &device->device_allocator);
+
   iree_hal_metal_builtin_executable_t* builtin_executable = NULL;
   if (iree_status_is_ok(status)) {
     status =
@@ -105,13 +110,20 @@
   } else {
     iree_hal_device_release((iree_hal_device_t*)device);
   }
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_metal_staging_buffer_initialize(
+        metal_device, params->queue_uniform_buffer_size, &device->staging_buffer);
+  } else {
+    iree_hal_device_release((iree_hal_device_t*)device);
+  }
+
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_metal_device_vtable, &device->resource);
     iree_string_view_append_to_buffer(identifier, &device->identifier,
                                       (char*)device + iree_sizeof_struct(*device));
     iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, &device->block_pool);
-    device->driver = driver;
-    iree_hal_driver_retain(device->driver);
+    device->params = *params;
     device->host_allocator = host_allocator;
     device->device = [metal_device retain];          // +1
     device->queue = [metal_device newCommandQueue];  // +1
@@ -124,7 +136,6 @@
     device->event_listener = [[MTLSharedEventListener alloc]
         initWithDispatchQueue:device->semaphore_notification_queue];  // +1
     device->capture_manager = NULL;
-
     *out_device = (iree_hal_device_t*)device;
   }
   return status;
@@ -158,6 +169,7 @@
   [device->queue release];   // -1
   [device->device release];  // -1
 
+  iree_hal_metal_staging_buffer_deinitialize(&device->staging_buffer);
   iree_arena_block_pool_deinitialize(&device->block_pool);
 
   iree_allocator_free(host_allocator, device);
@@ -221,14 +233,21 @@
     iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity,
     iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) {
   iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
+
   if (iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED))
     return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "nested command buffer not yet supported");
   if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT))
-    return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented multi-shot command buffer");
-  return iree_hal_metal_direct_command_buffer_create(
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "multi-shot command buffer not yet supported");
+
+  iree_status_t status = iree_hal_metal_direct_command_buffer_create(
       base_device, mode, command_categories, binding_capacity,
       device->command_buffer_resource_reference_mode, device->queue, device->host_allocator,
-      &device->block_pool, device->builtin_executable, out_command_buffer);
+      &device->block_pool, &device->staging_buffer, device->builtin_executable, out_command_buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_metal_staging_buffer_increase_refcount(&device->staging_buffer);
+  }
+  return status;
 }
 
 static iree_status_t iree_hal_metal_device_create_descriptor_set_layout(
@@ -354,6 +373,9 @@
       id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
       [handle addCompletedHandler:^(id<MTLCommandBuffer> cb) {
         iree_hal_command_buffer_release(command_buffer);  // -1
+        // Decrease command buffer refcount in the shared staging buffer, and potentially reclaim
+        // resources. This is fine right now given we only support one-shot command buffers.
+        iree_hal_metal_staging_buffer_decrease_refcount(&device->staging_buffer);
       }];
       [handle commit];
     }
diff --git a/experimental/metal/staging_buffer.h b/experimental/metal/staging_buffer.h
new file mode 100644
index 0000000..c46df94
--- /dev/null
+++ b/experimental/metal/staging_buffer.h
@@ -0,0 +1,102 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_METAL_STAGING_BUFFER_H_
+#define IREE_EXPERIMENTAL_METAL_STAGING_BUFFER_H_
+
+#import <Metal/Metal.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Size, in bytes, of the shared storage mode staging buffer.
+// The given amount of system memory will be allocated and is accessible to both
+// the CPU and the GPU.
+//
+// Larger values here will use more memory but allow more concurrent/complex
+// command buffers. As most models that run in these environments are only a few
+// hundred dispatches per command buffer we can approximate an average
+// consumption of 500 dispatches x worst-case 256B per dispatch of parameters
+// and get 128KB.
+#define IREE_HAL_METAL_STAGING_BUFFER_DEFAULT_CAPACITY (128 * 1024)
+
+// A staging uniform buffer used for uploading parameters to the device.
+// This allows for high-frequency writes of parameters at appropriate alignment.
+//
+// Intended usage is to retain one of these per device queue and use them during
+// command buffer recording targeting that particular queue. This avoids
+// allocating a lot of small buffers. The underlying buffer has shared storage
+// mode; so it resides in system memory and is accessible to both the CPU and
+// the GPU.
+//
+// Parameters handled by this buffer include:
+// * Argument buffers for descriptor sets
+// * Source buffer for buffer update commands
+//
+// TODO(#14049): Use proper atomics/mutexes for concurrent command buffer
+// recording and execution.
+typedef struct iree_hal_metal_staging_buffer_t {
+  // Maximum number of bytes in the buffer.
+  uint32_t capacity;
+
+  // Device handle to the buffer; released on deinitialize.
+  id<MTLBuffer> metal_buffer;
+  // Host pointer to the start of |metal_buffer|'s contents.
+  uint8_t* host_buffer;
+
+  // Byte offset from the buffer start where the next reservation is placed.
+  uint32_t offset;
+
+  // The number of command buffers that are being recorded or executed on
+  // device. If this reaches zero, we know that there are no users of the
+  // staging buffer so we can discard the contents and reset the offset to
+  // zero.
+  uint32_t pending_command_buffers;
+} iree_hal_metal_staging_buffer_t;
+
+// Initializes |out_staging_buffer| with the given |buffer_capacity|.
+iree_status_t iree_hal_metal_staging_buffer_initialize(
+    id<MTLDevice> device, iree_host_size_t buffer_capacity,
+    iree_hal_metal_staging_buffer_t* out_staging_buffer);
+
+void iree_hal_metal_staging_buffer_deinitialize(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+// Reserves |length| bytes from the staging buffer; returns the host span in
+// |out_reservation| and its offset into the Metal buffer in |out_offset|.
+iree_status_t iree_hal_metal_staging_buffer_reserve(
+    iree_hal_metal_staging_buffer_t* staging_buffer, iree_host_size_t length,
+    iree_host_size_t alignment, iree_byte_span_t* out_reservation,
+    uint32_t* out_offset);
+
+// Appends |source| to the staging buffer; |out_offset| receives its offset.
+iree_status_t iree_hal_metal_staging_buffer_append(
+    iree_hal_metal_staging_buffer_t* staging_buffer,
+    iree_const_byte_span_t source, iree_host_size_t alignment,
+    uint32_t* out_offset);
+
+// Resets the staging buffer to discard all its contents.
+void iree_hal_metal_staging_buffer_reset(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+// Increases the count of command buffers using this staging buffer by one.
+void iree_hal_metal_staging_buffer_increase_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+// Decreases the count of command buffers using this staging buffer by one,
+// which resets the staging buffer when the count reaches zero.
+void iree_hal_metal_staging_buffer_decrease_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_METAL_STAGING_BUFFER_H_
diff --git a/experimental/metal/staging_buffer.m b/experimental/metal/staging_buffer.m
new file mode 100644
index 0000000..08791b4
--- /dev/null
+++ b/experimental/metal/staging_buffer.m
@@ -0,0 +1,98 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/metal/staging_buffer.h"
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+// Allocates a shared-storage MTLBuffer of |buffer_capacity| bytes on |device| and points
+// |out_staging_buffer| at it, with the write offset and pending command buffer count at zero.
+iree_status_t iree_hal_metal_staging_buffer_initialize(
+    id<MTLDevice> device, iree_host_size_t buffer_capacity,
+    iree_hal_metal_staging_buffer_t* out_staging_buffer) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(out_staging_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  memset(out_staging_buffer, 0, sizeof(*out_staging_buffer));  // Also zeroes offset and refcount.
+
+  // From Metal Best Practices Guide:
+  // "For small-sized data that changes frequently, choose the Shared mode. The overhead of copying
+  // data to video memory may be more expensive than the overhead of the GPU accessing system memory
+  // directly."
+  MTLResourceOptions options = MTLResourceStorageModeShared | MTLResourceCPUCacheModeWriteCombined;
+  id<MTLBuffer> metal_buffer = [device newBufferWithLength:buffer_capacity options:options];  // +1
+  if (!metal_buffer) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "failed to allocate staging buffer with size = %" PRIhsz " bytes",
+                            buffer_capacity);
+  }
+
+  out_staging_buffer->capacity = (uint32_t)buffer_capacity;
+  out_staging_buffer->metal_buffer = metal_buffer;
+  out_staging_buffer->host_buffer = metal_buffer.contents;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_metal_staging_buffer_deinitialize(iree_hal_metal_staging_buffer_t* staging_buffer) {
+  [staging_buffer->metal_buffer release];  // -1; balances the +1 taken in initialize.
+}
+
+// Reserves |length| bytes starting at an |alignment|-aligned offset of the staging buffer. On
+// success |out_reservation| spans the host memory and |out_offset| is its Metal buffer offset.
+iree_status_t iree_hal_metal_staging_buffer_reserve(
+    iree_hal_metal_staging_buffer_t* staging_buffer, iree_host_size_t length,
+    iree_host_size_t alignment, iree_byte_span_t* out_reservation, uint32_t* out_offset) {
+  if (length > staging_buffer->capacity) {  // Can never fit, even in an empty buffer.
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "reservation (%" PRIhsz " bytes) exceeds the maximum capacity of "
+                            "the staging buffer (%" PRIu32 " bytes)",
+                            length, staging_buffer->capacity);
+  }
+  // Align the offset, not the length: prior reservations may have used a smaller alignment.
+  iree_host_size_t aligned_offset = iree_host_align(staging_buffer->offset, alignment);
+  if (aligned_offset > staging_buffer->capacity - length) {  // Overflow-safe: length <= capacity.
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "failed to reserve %" PRIhsz " bytes in staging buffer", length);
+  }
+  *out_reservation = iree_make_byte_span(staging_buffer->host_buffer + aligned_offset, length);
+  *out_offset = (uint32_t)aligned_offset;
+  staging_buffer->offset = (uint32_t)(aligned_offset + length);
+  return iree_ok_status();
+}
+
+// Copies |source| into a new reservation; |out_offset| receives its offset in the Metal buffer.
+iree_status_t iree_hal_metal_staging_buffer_append(
+    iree_hal_metal_staging_buffer_t* staging_buffer, iree_const_byte_span_t source,
+    iree_host_size_t alignment, uint32_t* out_offset) {
+  iree_byte_span_t staged;
+  IREE_RETURN_IF_ERROR(iree_hal_metal_staging_buffer_reserve(staging_buffer, source.data_length,
+                                                             alignment, &staged, out_offset));
+  memcpy(staged.data, source.data, source.data_length);
+  return iree_ok_status();
+}
+
+void iree_hal_metal_staging_buffer_reset(iree_hal_metal_staging_buffer_t* staging_buffer) {
+  staging_buffer->offset = 0;  // Only rewinds the write offset; the bytes are not cleared.
+}
+
+void iree_hal_metal_staging_buffer_increase_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer) {
+  ++staging_buffer->pending_command_buffers;  // Plain increment; not atomic (see TODO(#14049)).
+}
+
+void iree_hal_metal_staging_buffer_decrease_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer) {
+  IREE_ASSERT(staging_buffer->pending_command_buffers > 0);  // Unbalanced decrement is a bug.
+  if (--staging_buffer->pending_command_buffers == 0) {
+    iree_hal_metal_staging_buffer_reset(staging_buffer);  // Last user gone; reuse from offset 0.
+  }
+}