[metal] Add option for strong resource reference in command buffers
diff --git a/experimental/metal/api.h b/experimental/metal/api.h
index 24ca999..0a5112b 100644
--- a/experimental/metal/api.h
+++ b/experimental/metal/api.h
@@ -27,6 +27,13 @@
   IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL = 1,
 } iree_hal_metal_command_dispatch_type_t;
 
+typedef enum iree_hal_metal_command_buffer_resource_reference_mode_e {
+  // Do not maintain strong references to resources used in command buffers.
+  IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED = 0,
+  // Maintain strong references to resources used in command buffers.
+  IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED = 1,
+} iree_hal_metal_command_buffer_resource_reference_mode_t;
+
 typedef enum iree_hal_metal_resource_hazard_tracking_mode_e {
   // Do not track resource hazards. Hosting applications are responsible for
   // ensuring that resources are not modified while in use.
@@ -51,6 +58,12 @@
   // debugging in certain cases.
   iree_hal_metal_command_dispatch_type_t command_dispatch_type;
 
+  // Resource reference mode in command buffers.
+  // Normally we track resource lifetime in IREE explicitly, so we don't need to
+  // incur Metal runtime overhead to do that. But good for debugging purposes.
+  iree_hal_metal_command_buffer_resource_reference_mode_t
+      command_buffer_resource_reference_mode;
+
   // Resource hazard tracking mode.
   // IREE is following explicit GPU API model and tracks resource dependency by
   // itself. So normally we don't need to let Metal runtime to track resource
diff --git a/experimental/metal/direct_command_buffer.h b/experimental/metal/direct_command_buffer.h
index 01abd08..09c1478 100644
--- a/experimental/metal/direct_command_buffer.h
+++ b/experimental/metal/direct_command_buffer.h
@@ -9,6 +9,7 @@
 
 #import <Metal/Metal.h>
 
+#include "experimental/metal/api.h"
 #include "experimental/metal/builtin_executables.h"
 #include "iree/base/internal/arena.h"
 #include "iree/hal/api.h"
@@ -37,8 +38,11 @@
 iree_status_t iree_hal_metal_direct_command_buffer_create(
     iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
-    iree_host_size_t binding_capacity, id<MTLCommandQueue> queue,
-    iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool,
+    iree_host_size_t binding_capacity,
+    iree_hal_metal_command_buffer_resource_reference_mode_t
+        resource_reference_mode,
+    id<MTLCommandQueue> queue, iree_allocator_t host_allocator,
+    iree_arena_block_pool_t* block_pool,
     iree_hal_metal_builtin_executable_t* builtin_executable,
     iree_hal_command_buffer_t** out_command_buffer);
 
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m
index 1ab3583..bfb6081 100644
--- a/experimental/metal/direct_command_buffer.m
+++ b/experimental/metal/direct_command_buffer.m
@@ -135,6 +135,7 @@
 iree_status_t iree_hal_metal_direct_command_buffer_create(
     iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity,
+    iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode,
     id<MTLCommandQueue> queue, iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool,
     iree_hal_metal_builtin_executable_t* builtin_executable,
     iree_hal_command_buffer_t** out_command_buffer) {
@@ -163,8 +164,13 @@
     @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
       // We track resource lifetime by ourselves in IREE; so just do unretained references to
       // resources in Metal command buffer, which avoids overhead and gives better performance.
+      MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new];  // +1
+      descriptor.retainedReferences =
+          resource_reference_mode == IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
+      descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
       command_buffer->command_buffer =
-          [[queue commandBufferWithUnretainedReferences] retain];  // +1
+          [[queue commandBufferWithDescriptor:descriptor] retain];  // +1
+      [descriptor release];                                         // -1
     }
     const iree_hal_metal_device_params_t* params = iree_hal_metal_device_params(device);
     command_buffer->dispatch_type =
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index 39fc78b..6fb90b7 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -40,6 +40,9 @@
   // We can relax this to support multiple queues when needed later.
   id<MTLCommandQueue> queue;
 
+  iree_hal_metal_command_buffer_resource_reference_mode_t command_buffer_resource_reference_mode;
+
+  // For polyfilling fill/copy/update buffers that are not directly supported by Metal APIs.
   iree_hal_metal_builtin_executable_t* builtin_executable;
 
   // A dispatch queue and associated event listener for running Objective-C blocks to signal
@@ -67,6 +70,8 @@
   memset(out_params, 0, sizeof(*out_params));
   out_params->arena_block_size = 32 * 1024;
   out_params->command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
+  out_params->command_buffer_resource_reference_mode =
+      IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED;
   out_params->resource_hazard_tracking_mode =
       IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_UNTRACKED;
 }
@@ -105,6 +110,7 @@
     device->host_allocator = host_allocator;
     device->device = [metal_device retain];          // +1
     device->queue = [metal_device newCommandQueue];  // +1
+    device->command_buffer_resource_reference_mode = params->command_buffer_resource_reference_mode;
     device->builtin_executable = builtin_executable;
     dispatch_queue_attr_t queue_attr = dispatch_queue_attr_make_with_qos_class(
         DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INITIATED, /*relative_priority=*/0);
@@ -215,8 +221,9 @@
   if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT))
     return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented multi-shot command buffer");
   return iree_hal_metal_direct_command_buffer_create(
-      base_device, mode, command_categories, binding_capacity, device->queue,
-      device->host_allocator, &device->block_pool, device->builtin_executable, out_command_buffer);
+      base_device, mode, command_categories, binding_capacity,
+      device->command_buffer_resource_reference_mode, device->queue, device->host_allocator,
+      &device->block_pool, device->builtin_executable, out_command_buffer);
 }
 
 static iree_status_t iree_hal_metal_device_create_descriptor_set_layout(
@@ -307,14 +314,20 @@
   @autoreleasepool {
     // First create a new command buffer and encode wait commands for all wait semaphores.
     if (wait_semaphore_list.count > 0) {
+      MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new];  // +1
+      descriptor.retainedReferences =
+          device->command_buffer_resource_reference_mode ==
+          IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
+      descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
       id<MTLCommandBuffer> wait_command_buffer =
-          [device->queue commandBufferWithUnretainedReferences];  // autoreleased
+          [device->queue commandBufferWithDescriptor:descriptor];  // autoreleased
       for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
         [wait_command_buffer
             encodeWaitForEvent:iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i])
                          value:wait_semaphore_list.payload_values[i]];
       }
       [wait_command_buffer commit];
+      [descriptor release];  // -1
     }
 
     // Then commit all recorded compute command buffers.
@@ -324,14 +337,20 @@
 
     // Finally create a new command buffer and encode signal commands for all signal semaphores.
     if (signal_semaphore_list.count > 0) {
+      MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new];  // +1
+      descriptor.retainedReferences =
+          device->command_buffer_resource_reference_mode ==
+          IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
+      descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
       id<MTLCommandBuffer> signal_command_buffer =
-          [device->queue commandBufferWithUnretainedReferences];  // autoreleased
+          [device->queue commandBufferWithDescriptor:descriptor];  // autoreleased
       for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
         [signal_command_buffer encodeSignalEvent:iree_hal_metal_shared_event_handle(
                                                      signal_semaphore_list.semaphores[i])
                                            value:signal_semaphore_list.payload_values[i]];
       }
       [signal_command_buffer commit];
+      [descriptor release];  // -1
     }
   }
 
diff --git a/experimental/metal/registration/driver_module.c b/experimental/metal/registration/driver_module.c
index 52b0e2b..da35698 100644
--- a/experimental/metal/registration/driver_module.c
+++ b/experimental/metal/registration/driver_module.c
@@ -16,7 +16,11 @@
 #include "iree/base/tracing.h"
 
 IREE_FLAG(bool, metal_serial_command_dispatch, false,
-          "Run all commands in command encoder sequentially");
+          "Serializes all commands within command buffers as if there were "
+          "barriers between each");
+IREE_FLAG(bool, metal_command_buffer_retain_resources, false,
+          "Enables automatic Metal resource reference counting for diagnosing "
+          "resource lifetime issues");
 IREE_FLAG(bool, metal_resource_hazard_tracking, false,
           "Enables automatic Metal hazard tracking for diagnosing concurrency "
           "issues");
@@ -58,6 +62,10 @@
       FLAG_metal_serial_command_dispatch
           ? IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL
           : IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
+  device_params.command_buffer_resource_reference_mode =
+      FLAG_metal_command_buffer_retain_resources
+          ? IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED
+          : IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED;
   device_params.resource_hazard_tracking_mode =
       FLAG_metal_resource_hazard_tracking
           ? IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_TRACKED