[metal] Manage staging buffer refcount in command buffer lifetime

This avoids the potential refcount mismatches due to that we can
have command buffers created but never submitted and such.
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m
index 1835259..45f1448 100644
--- a/experimental/metal/direct_command_buffer.m
+++ b/experimental/metal/direct_command_buffer.m
@@ -381,6 +381,10 @@
 
   *out_command_buffer = &command_buffer->base;
 
+  // Increase command buffer refcount in the shared staging buffer. We tie this to the command
+  // buffer's lifetime to avoid resource leak.
+  iree_hal_metal_staging_buffer_increase_refcount(staging_buffer);
+
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
@@ -392,6 +396,10 @@
 
   iree_hal_metal_command_buffer_reset(command_buffer);
 
+  // Decrease command buffer refcount in the shared staging buffer, and potentially reclaim
+  // resources. We tie this to the command buffer's lifetime to avoid resource leak.
+  iree_hal_metal_staging_buffer_decrease_refcount(command_buffer->staging_buffer);
+
   [command_buffer->state.encoder_event release];  // -1
   IREE_ASSERT_EQ(command_buffer->state.compute_encoder, nil);
   IREE_ASSERT_EQ(command_buffer->state.blit_encoder, nil);
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index 16aa574..089aaea 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -246,15 +246,11 @@
     return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                             "multi-shot command buffer not yet supported");
 
-  iree_status_t status = iree_hal_metal_direct_command_buffer_create(
+  return iree_hal_metal_direct_command_buffer_create(
       base_device, mode, command_categories, binding_capacity,
       device->command_buffer_resource_reference_mode, device->queue, &device->block_pool,
       &device->staging_buffer, device->builtin_executable, device->host_allocator,
       out_command_buffer);
-  if (iree_status_is_ok(status)) {
-    iree_hal_metal_staging_buffer_increase_refcount(&device->staging_buffer);
-  }
-  return status;
 }
 
 static iree_status_t iree_hal_metal_device_create_descriptor_set_layout(
@@ -374,9 +370,6 @@
       id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
       [handle addCompletedHandler:^(id<MTLCommandBuffer> cb) {
         iree_hal_command_buffer_release(command_buffer);  // -1
-        // Decrease command buffer refcount in the shared staging buffer, and potentially reclaim
-        // resources. This is fine right now given we only support one-shot command buffers.
-        iree_hal_metal_staging_buffer_decrease_refcount(&device->staging_buffer);
       }];
       [handle commit];
     }