[metal] Add option for strong resource reference in command buffers
diff --git a/experimental/metal/api.h b/experimental/metal/api.h
index 24ca999..0a5112b 100644
--- a/experimental/metal/api.h
+++ b/experimental/metal/api.h
@@ -27,6 +27,13 @@
IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL = 1,
} iree_hal_metal_command_dispatch_type_t;
+typedef enum iree_hal_metal_command_buffer_resource_reference_mode_e {
+ // Do not maintain strong references to resources used in command buffers.
+ IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED = 0,
+ // Maintain strong references to resources used in command buffers.
+ IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED = 1,
+} iree_hal_metal_command_buffer_resource_reference_mode_t;
+
typedef enum iree_hal_metal_resource_hazard_tracking_mode_e {
// Do not track resource hazards. Hosting applications are responsible for
// ensuring that resources are not modified while in use.
@@ -51,6 +58,12 @@
// debugging in certain cases.
iree_hal_metal_command_dispatch_type_t command_dispatch_type;
+ // Resource reference mode in command buffers.
+ // Normally we track resource lifetime in IREE explicitly, so we don't need to
+ // incur Metal runtime overhead to do that. But good for debugging purposes.
+ iree_hal_metal_command_buffer_resource_reference_mode_t
+ command_buffer_resource_reference_mode;
+
// Resource hazard tracking mode.
// IREE is following explicit GPU API model and tracks resource dependency by
// itself. So normally we don't need to let Metal runtime to track resource
diff --git a/experimental/metal/direct_command_buffer.h b/experimental/metal/direct_command_buffer.h
index 01abd08..09c1478 100644
--- a/experimental/metal/direct_command_buffer.h
+++ b/experimental/metal/direct_command_buffer.h
@@ -9,6 +9,7 @@
#import <Metal/Metal.h>
+#include "experimental/metal/api.h"
#include "experimental/metal/builtin_executables.h"
#include "iree/base/internal/arena.h"
#include "iree/hal/api.h"
@@ -37,8 +38,11 @@
iree_status_t iree_hal_metal_direct_command_buffer_create(
iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
- iree_host_size_t binding_capacity, id<MTLCommandQueue> queue,
- iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool,
+ iree_host_size_t binding_capacity,
+ iree_hal_metal_command_buffer_resource_reference_mode_t
+ resource_reference_mode,
+ id<MTLCommandQueue> queue, iree_allocator_t host_allocator,
+ iree_arena_block_pool_t* block_pool,
iree_hal_metal_builtin_executable_t* builtin_executable,
iree_hal_command_buffer_t** out_command_buffer);
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m
index 1ab3583..bfb6081 100644
--- a/experimental/metal/direct_command_buffer.m
+++ b/experimental/metal/direct_command_buffer.m
@@ -135,6 +135,7 @@
iree_status_t iree_hal_metal_direct_command_buffer_create(
iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity,
+ iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode,
id<MTLCommandQueue> queue, iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool,
iree_hal_metal_builtin_executable_t* builtin_executable,
iree_hal_command_buffer_t** out_command_buffer) {
@@ -163,8 +164,13 @@
@autoreleasepool { // Use @autoreleasepool to trigger the autorelease within encoder creation.
// We track resource lifetime by ourselves in IREE; so just do unretained references to
// resources in Metal command buffer, which avoids overhead and gives better performance.
+ MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1
+ descriptor.retainedReferences =
+ resource_reference_mode == IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
+ descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
command_buffer->command_buffer =
- [[queue commandBufferWithUnretainedReferences] retain]; // +1
+ [[queue commandBufferWithDescriptor:descriptor] retain]; // +1
+ [descriptor release]; // -1
}
const iree_hal_metal_device_params_t* params = iree_hal_metal_device_params(device);
command_buffer->dispatch_type =
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index 39fc78b..6fb90b7 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -40,6 +40,9 @@
// We can relax this to support multiple queues when needed later.
id<MTLCommandQueue> queue;
+ iree_hal_metal_command_buffer_resource_reference_mode_t command_buffer_resource_reference_mode;
+
+ // For polyfilling fill/copy/update buffers that are not directly supported by Metal APIs.
iree_hal_metal_builtin_executable_t* builtin_executable;
// A dispatch queue and associated event listener for running Objective-C blocks to signal
@@ -67,6 +70,8 @@
memset(out_params, 0, sizeof(*out_params));
out_params->arena_block_size = 32 * 1024;
out_params->command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
+ out_params->command_buffer_resource_reference_mode =
+ IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED;
out_params->resource_hazard_tracking_mode =
IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_UNTRACKED;
}
@@ -105,6 +110,7 @@
device->host_allocator = host_allocator;
device->device = [metal_device retain]; // +1
device->queue = [metal_device newCommandQueue]; // +1
+ device->command_buffer_resource_reference_mode = params->command_buffer_resource_reference_mode;
device->builtin_executable = builtin_executable;
dispatch_queue_attr_t queue_attr = dispatch_queue_attr_make_with_qos_class(
DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INITIATED, /*relative_priority=*/0);
@@ -215,8 +221,9 @@
if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT))
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented multi-shot command buffer");
return iree_hal_metal_direct_command_buffer_create(
- base_device, mode, command_categories, binding_capacity, device->queue,
- device->host_allocator, &device->block_pool, device->builtin_executable, out_command_buffer);
+ base_device, mode, command_categories, binding_capacity,
+ device->command_buffer_resource_reference_mode, device->queue, device->host_allocator,
+ &device->block_pool, device->builtin_executable, out_command_buffer);
}
static iree_status_t iree_hal_metal_device_create_descriptor_set_layout(
@@ -307,14 +314,20 @@
@autoreleasepool {
// First create a new command buffer and encode wait commands for all wait semaphores.
if (wait_semaphore_list.count > 0) {
+ MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1
+ descriptor.retainedReferences =
+ device->command_buffer_resource_reference_mode ==
+ IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
+ descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
id<MTLCommandBuffer> wait_command_buffer =
- [device->queue commandBufferWithUnretainedReferences]; // autoreleased
+ [device->queue commandBufferWithDescriptor:descriptor]; // autoreleased
for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
[wait_command_buffer
encodeWaitForEvent:iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i])
value:wait_semaphore_list.payload_values[i]];
}
[wait_command_buffer commit];
+ [descriptor release]; // -1
}
// Then commit all recorded compute command buffers.
@@ -324,14 +337,20 @@
// Finally create a new command buffer and encode signal commands for all signal semaphores.
if (signal_semaphore_list.count > 0) {
+ MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1
+ descriptor.retainedReferences =
+ device->command_buffer_resource_reference_mode ==
+ IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
+ descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
id<MTLCommandBuffer> signal_command_buffer =
- [device->queue commandBufferWithUnretainedReferences]; // autoreleased
+ [device->queue commandBufferWithDescriptor:descriptor]; // autoreleased
for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
[signal_command_buffer encodeSignalEvent:iree_hal_metal_shared_event_handle(
signal_semaphore_list.semaphores[i])
value:signal_semaphore_list.payload_values[i]];
}
[signal_command_buffer commit];
+ [descriptor release]; // -1
}
}
diff --git a/experimental/metal/registration/driver_module.c b/experimental/metal/registration/driver_module.c
index 52b0e2b..da35698 100644
--- a/experimental/metal/registration/driver_module.c
+++ b/experimental/metal/registration/driver_module.c
@@ -16,7 +16,11 @@
#include "iree/base/tracing.h"
IREE_FLAG(bool, metal_serial_command_dispatch, false,
- "Run all commands in command encoder sequentially");
+ "Serializes all commands within command buffers as if there were "
+ "barriers between each");
+IREE_FLAG(bool, metal_command_buffer_retain_resources, false,
+ "Enables automatic Metal resource reference counting for diagnosing "
+ "resource lifetime issues");
IREE_FLAG(bool, metal_resource_hazard_tracking, false,
"Enables automatic Metal hazard tracking for diagnosing concurrency "
"issues");
@@ -58,6 +62,10 @@
FLAG_metal_serial_command_dispatch
? IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL
: IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
+ device_params.command_buffer_resource_reference_mode =
+ FLAG_metal_command_buffer_retain_resources
+ ? IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED
+ : IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED;
device_params.resource_hazard_tracking_mode =
FLAG_metal_resource_hazard_tracking
? IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_TRACKED