[metal] Avoid resource set leak in queue execution
We need to check the status and free the resource set if any
error happens.
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index ba11534..1b7122b 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -341,61 +341,69 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_allocate(&device->block_pool, &resource_set));
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers));
+ iree_status_t status =
+ iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers);
+
// Put the full semaphore list into a resource set, which retains them--we will need to access
// them until the command buffer completes.
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
- wait_semaphore_list.semaphores));
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
- signal_semaphore_list.semaphores));
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
+ wait_semaphore_list.semaphores);
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
+ signal_semaphore_list.semaphores);
+ }
- @autoreleasepool {
- // First create a new command buffer and encode wait commands for all wait semaphores.
- if (wait_semaphore_list.count > 0) {
- id<MTLCommandBuffer> wait_command_buffer = [device->queue
- commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased
- for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
- id<MTLSharedEvent> handle =
- iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i]);
- [wait_command_buffer encodeWaitForEvent:handle value:wait_semaphore_list.payload_values[i]];
+ if (iree_status_is_ok(status)) {
+ @autoreleasepool {
+ // First create a new command buffer and encode wait commands for all wait semaphores.
+ if (wait_semaphore_list.count > 0) {
+ id<MTLCommandBuffer> wait_command_buffer = [device->queue
+ commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased
+ for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
+ id<MTLSharedEvent> handle =
+ iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i]);
+ [wait_command_buffer encodeWaitForEvent:handle
+ value:wait_semaphore_list.payload_values[i]];
+ }
+ [wait_command_buffer commit];
}
- [wait_command_buffer commit];
- }
- // Then commit all recorded compute command buffers, except the last one, which we will patch up
- // with semaphore signaling.
- id<MTLCommandBuffer> signal_command_buffer = nil;
- for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
- iree_hal_command_buffer_t* command_buffer = command_buffers[i];
- id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
- if (i + 1 != command_buffer_count) [handle commit];
- signal_command_buffer = handle;
- }
- if (signal_command_buffer == nil) {
- signal_command_buffer = [device->queue
- commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased
- }
+ // Then commit all recorded compute command buffers, except the last one, which we will patch
+ // up with semaphore signaling.
+ id<MTLCommandBuffer> signal_command_buffer = nil;
+ for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
+ iree_hal_command_buffer_t* command_buffer = command_buffers[i];
+ id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer);
+ if (i + 1 != command_buffer_count) [handle commit];
+ signal_command_buffer = handle;
+ }
+ if (signal_command_buffer == nil) {
+ signal_command_buffer = [device->queue
+ commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased
+ }
- // Finally encode signal commands for all signal semaphores.
- for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
- id<MTLSharedEvent> handle =
- iree_hal_metal_shared_event_handle(signal_semaphore_list.semaphores[i]);
- [signal_command_buffer encodeSignalEvent:handle
- value:signal_semaphore_list.payload_values[i]];
- }
+ // Finally encode signal commands for all signal semaphores.
+ for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
+ id<MTLSharedEvent> handle =
+ iree_hal_metal_shared_event_handle(signal_semaphore_list.semaphores[i]);
+ [signal_command_buffer encodeSignalEvent:handle
+ value:signal_semaphore_list.payload_values[i]];
+ }
- [signal_command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cb) {
- // Now we can release all retained resources.
- iree_hal_resource_set_free(resource_set);
- }];
- [signal_command_buffer commit];
+ [signal_command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cb) {
+ // Now we can release all retained resources.
+ iree_hal_resource_set_free(resource_set);
+ }];
+ [signal_command_buffer commit];
+ }
+ } else {
+ iree_hal_resource_set_free(resource_set);
}
IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
+ return status;
}
static iree_status_t iree_hal_metal_device_queue_flush(iree_hal_device_t* base_device,