Adding indirect command buffer emulation for Metal. (#17849)

Until we switch to using MTLIndirectCommandBuffer, any reusable command
buffers or command buffers with indirect bindings will need to be recorded
into deferred command buffers and replayed upon submission.
diff --git a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
index 0d60ae2..c3186a1 100644
--- a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
@@ -41,6 +41,7 @@
     iree::base::internal::flatcc::parsing
     iree::hal
     iree::hal::drivers::metal::builtin
+    iree::hal::utils::deferred_command_buffer
     iree::hal::utils::file_transfer
     iree::hal::utils::memory_file
     iree::hal::utils::resource_set
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index 620ff57..21386f8 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -17,6 +17,7 @@
 #include "iree/hal/drivers/metal/pipeline_layout.h"
 #include "iree/hal/drivers/metal/shared_event.h"
 #include "iree/hal/drivers/metal/staging_buffer.h"
+#include "iree/hal/utils/deferred_command_buffer.h"
 #include "iree/hal/utils/file_transfer.h"
 #include "iree/hal/utils/memory_file.h"
 #include "iree/hal/utils/resource_set.h"
@@ -247,12 +248,17 @@
     iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) {
   iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
 
-  if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
-    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                            "multi-shot command buffer not yet supported");
-  } else if (binding_capacity > 0) {
-    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                            "indirect command buffers not yet supported");
+  // Native Metal command buffers are not reusable so we emulate by recording into our own reusable
+  // instance. This will be replayed against a Metal command buffer upon submission.
+  //
+  // TODO(indirect-cmd): natively support indirect command buffers in Metal via
+  // MTLIndirectCommandBuffer. We could switch to exclusively using that for all modes to keep the
+  // number of code paths down. MTLIndirectCommandBuffer is both reusable and has what we require
+  // for argument buffer updates to pass in binding tables.
+  if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) || binding_capacity > 0) {
+    return iree_hal_deferred_command_buffer_create(
+        device->device_allocator, mode, command_categories, binding_capacity, &device->block_pool,
+        device->host_allocator, out_command_buffer);
   }
 
   return iree_hal_metal_direct_command_buffer_create(
@@ -390,6 +396,38 @@
   return loop_status;
 }
 
+static iree_status_t iree_hal_metal_replay_command_buffer(
+    iree_hal_metal_device_t* device, iree_hal_command_buffer_t* deferred_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_command_buffer_t** out_direct_command_buffer) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Create the transient direct command buffer that receives the replay. It is one-shot with
+  // binding_capacity 0 because |binding_table| is resolved while replaying; failure returns early.
+  iree_hal_command_buffer_t* direct_command_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_metal_direct_command_buffer_create(
+              (iree_hal_device_t*)device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+              iree_hal_command_buffer_allowed_categories(deferred_command_buffer),
+              /*binding_capacity=*/0, device->command_buffer_resource_reference_mode, device->queue,
+              &device->block_pool, &device->staging_buffer, device->builtin_executable,
+              device->host_allocator, &direct_command_buffer));
+
+  // Replay every recorded command into the transient command buffer using |binding_table|. This
+  // can fail even though the original recording succeeded, e.g. if a binding does not meet the
+  // requirements of a command.
+  iree_status_t status = iree_hal_deferred_command_buffer_apply(
+      deferred_command_buffer, direct_command_buffer, binding_table);
+
+  if (iree_status_is_ok(status)) {
+    *out_direct_command_buffer = direct_command_buffer;  // reference is handed to the caller
+  } else {
+    iree_hal_command_buffer_release(direct_command_buffer);  // discard the partial replay
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
 static iree_status_t iree_hal_metal_device_queue_execute(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
     const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -403,20 +441,46 @@
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_hal_resource_set_allocate(&device->block_pool, &resource_set));
 
-  iree_status_t status =
-      iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers);
-
   // Put the full semaphore list into a resource set, which retains them--we will need to access
   // them until the command buffer completes.
-  if (iree_status_is_ok(status)) {
-    status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
-                                          wait_semaphore_list.semaphores);
-  }
+  iree_status_t status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
+                                                      wait_semaphore_list.semaphores);
   if (iree_status_is_ok(status)) {
     status = iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
                                           signal_semaphore_list.semaphores);
   }
 
+  // Translate any deferred command buffers into real Metal command buffers.
+  // We do this prior to beginning execution so that if we fail we don't leave the system in an
+  // inconsistent state.
+  iree_hal_command_buffer_t** direct_command_buffers = (iree_hal_command_buffer_t**)iree_alloca(
+      command_buffer_count * sizeof(iree_hal_command_buffer_t*));
+  if (iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
+      iree_hal_command_buffer_t* command_buffer = command_buffers[i];
+      iree_hal_command_buffer_t* direct_command_buffer = NULL;
+      if (iree_hal_deferred_command_buffer_isa(command_buffer)) {
+        // Create a temporary command buffer and replay the deferred command buffer with the
+        // binding table provided. Note that any resources used will be retained by the command
+        // buffer so we only need to retain the command buffer itself instead of the binding
+        // tables provided.
+        iree_hal_buffer_binding_table_t binding_table =
+            binding_tables ? binding_tables[i] : iree_hal_buffer_binding_table_empty();
+        @autoreleasepool {
+          status = iree_hal_metal_replay_command_buffer(device, command_buffer, binding_table,
+                                                        &direct_command_buffer);
+        }
+      } else {
+        // Retain the command buffer until the submission has completed.
+        direct_command_buffer = command_buffer;
+      }
+      if (!iree_status_is_ok(status)) break;
+      status = iree_hal_resource_set_insert(resource_set, 1, &direct_command_buffer);
+      if (!iree_status_is_ok(status)) break;
+      direct_command_buffers[i] = direct_command_buffer;
+    }
+  }
+
   if (iree_status_is_ok(status)) {
     @autoreleasepool {
       // First create a new command buffer and encode wait commands for all wait semaphores.
@@ -436,8 +500,14 @@
       // up with semaphore signaling.
       id<MTLCommandBuffer> signal_command_buffer = nil;
       for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
-        iree_hal_command_buffer_t* command_buffer = command_buffers[i];
-        id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
+        // NOTE: translation happens above such that we always know these are direct command
+        // buffers.
+        //
+        // TODO(indirect-cmd): support indirect command buffers and switch here, or only use
+        // indirect command buffers and assume that instead.
+        iree_hal_command_buffer_t* direct_command_buffer = direct_command_buffers[i];
+        id<MTLCommandBuffer> handle =
+            iree_hal_metal_direct_command_buffer_handle(direct_command_buffer);
         if (i + 1 != command_buffer_count) [handle commit];
         signal_command_buffer = handle;
       }