Adding indirect command buffer emulation for Metal. (#17849)
Until we switch to using MTLIndirectCommandBuffer, any reusable command
buffers or ones with indirect bindings will need to be recorded into
deferred command buffers and replayed upon submission.
diff --git a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
index 0d60ae2..c3186a1 100644
--- a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt
@@ -41,6 +41,7 @@
iree::base::internal::flatcc::parsing
iree::hal
iree::hal::drivers::metal::builtin
+ iree::hal::utils::deferred_command_buffer
iree::hal::utils::file_transfer
iree::hal::utils::memory_file
iree::hal::utils::resource_set
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index 620ff57..21386f8 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -17,6 +17,7 @@
#include "iree/hal/drivers/metal/pipeline_layout.h"
#include "iree/hal/drivers/metal/shared_event.h"
#include "iree/hal/drivers/metal/staging_buffer.h"
+#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"
#include "iree/hal/utils/resource_set.h"
@@ -247,12 +248,17 @@
iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) {
iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
- if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "multi-shot command buffer not yet supported");
- } else if (binding_capacity > 0) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet supported");
+ // Native Metal command buffers are not reusable so we emulate by recording into our own reusable
+ // instance. This will be replayed against a Metal command buffer upon submission.
+ //
+ // TODO(indirect-cmd): natively support indirect command buffers in Metal via
+ // MTLIndirectCommandBuffer. We could switch to exclusively using that for all modes to keep the
+ // number of code paths down. MTLIndirectCommandBuffer is both reusable and has what we require
+ // for argument buffer updates to pass in binding tables.
+ if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) || binding_capacity > 0) {
+ return iree_hal_deferred_command_buffer_create(
+ device->device_allocator, mode, command_categories, binding_capacity, &device->block_pool,
+ device->host_allocator, out_command_buffer);
}
return iree_hal_metal_direct_command_buffer_create(
@@ -390,6 +396,38 @@
return loop_status;
}
+// Replays |deferred_command_buffer| against a newly created one-shot direct command buffer using
+// the provided |binding_table|. On success the new command buffer is stored in
+// |out_direct_command_buffer|; if the replay fails the new command buffer is released and
+// |out_direct_command_buffer| is not written.
+static iree_status_t iree_hal_metal_replay_command_buffer(
+    iree_hal_metal_device_t* device, iree_hal_command_buffer_t* deferred_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    iree_hal_command_buffer_t** out_direct_command_buffer) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // The replay target is transient: it is one-shot and needs no binding capacity of its own
+  // because it is replayed exactly once with every binding already resolved. It is limited to the
+  // command categories the deferred command buffer allows.
+  iree_hal_command_buffer_t* replay_target = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_metal_direct_command_buffer_create(
+              (iree_hal_device_t*)device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+              iree_hal_command_buffer_allowed_categories(deferred_command_buffer),
+              /*binding_capacity=*/0, device->command_buffer_resource_reference_mode, device->queue,
+              &device->block_pool, &device->staging_buffer, device->builtin_executable,
+              device->host_allocator, &replay_target));
+
+  // Issue every recorded command against the replay target. A successful initial recording does
+  // not guarantee success here: the replay fails if any binding does not meet the requirements.
+  iree_status_t replay_status = iree_hal_deferred_command_buffer_apply(
+      deferred_command_buffer, replay_target, binding_table);
+  if (!iree_status_is_ok(replay_status)) {
+    // Drop the partially recorded command buffer; the caller receives only the failure status.
+    iree_hal_command_buffer_release(replay_target);
+  } else {
+    *out_direct_command_buffer = replay_target;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return replay_status;
+}
+
static iree_status_t iree_hal_metal_device_queue_execute(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
@@ -403,20 +441,46 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_allocate(&device->block_pool, &resource_set));
- iree_status_t status =
- iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers);
-
// Put the full semaphore list into a resource set, which retains them--we will need to access
// them until the command buffer completes.
- if (iree_status_is_ok(status)) {
- status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
- wait_semaphore_list.semaphores);
- }
+ iree_status_t status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
+ wait_semaphore_list.semaphores);
if (iree_status_is_ok(status)) {
status = iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
signal_semaphore_list.semaphores);
}
+ // Translate any deferred command buffers into real Metal command buffers.
+ // We do this prior to beginning execution so that if we fail we don't leave the system in an
+ // inconsistent state.
+ iree_hal_command_buffer_t** direct_command_buffers = (iree_hal_command_buffer_t**)iree_alloca(
+ command_buffer_count * sizeof(iree_hal_command_buffer_t*));
+ if (iree_status_is_ok(status)) {
+ for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
+ iree_hal_command_buffer_t* command_buffer = command_buffers[i];
+ iree_hal_command_buffer_t* direct_command_buffer = NULL;
+ if (iree_hal_deferred_command_buffer_isa(command_buffer)) {
+ // Create a temporary command buffer and replay the deferred command buffer with the
+ // binding table provided. Note that any resources used will be retained by the command
+ // buffer so we only need to retain the command buffer itself instead of the binding
+ // tables provided.
+ iree_hal_buffer_binding_table_t binding_table =
+ binding_tables ? binding_tables[i] : iree_hal_buffer_binding_table_empty();
+ @autoreleasepool {
+ status = iree_hal_metal_replay_command_buffer(device, command_buffer, binding_table,
+ &direct_command_buffer);
+ }
+ } else {
+ // Retain the command buffer until the submission has completed.
+ direct_command_buffer = command_buffer;
+ }
+ if (!iree_status_is_ok(status)) break;
+ status = iree_hal_resource_set_insert(resource_set, 1, &direct_command_buffer);
+ if (!iree_status_is_ok(status)) break;
+ direct_command_buffers[i] = direct_command_buffer;
+ }
+ }
+
if (iree_status_is_ok(status)) {
@autoreleasepool {
// First create a new command buffer and encode wait commands for all wait semaphores.
@@ -436,8 +500,14 @@
// up with semaphore signaling.
id<MTLCommandBuffer> signal_command_buffer = nil;
for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
- iree_hal_command_buffer_t* command_buffer = command_buffers[i];
- id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
+ // NOTE: translation happens above such that we always know these are direct command
+ // buffers.
+ //
+ // TODO(indirect-cmd): support indirect command buffers and switch here, or only use
+ // indirect command buffers and assume that instead.
+ iree_hal_command_buffer_t* direct_command_buffer = direct_command_buffers[i];
+ id<MTLCommandBuffer> handle =
+ iree_hal_metal_direct_command_buffer_handle(direct_command_buffer);
if (i + 1 != command_buffer_count) [handle commit];
signal_command_buffer = handle;
}