[metal] Avoid resource set leak in queue execution We need to check the status and free the resource set if any error happens.
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m index ba11534..1b7122b 100644 --- a/experimental/metal/metal_device.m +++ b/experimental/metal/metal_device.m
@@ -341,61 +341,69 @@ IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_resource_set_allocate(&device->block_pool, &resource_set)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers)); + iree_status_t status = + iree_hal_resource_set_insert(resource_set, command_buffer_count, command_buffers); + // Put the full semaphore list into a resource set, which retains them--we will need to access // them until the command buffer completes. - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count, - wait_semaphore_list.semaphores)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count, - signal_semaphore_list.semaphores)); + if (iree_status_is_ok(status)) { + status = iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count, + wait_semaphore_list.semaphores); + } + if (iree_status_is_ok(status)) { + status = iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count, + signal_semaphore_list.semaphores); + } - @autoreleasepool { - // First create a new command buffer and encode wait commands for all wait semaphores. - if (wait_semaphore_list.count > 0) { - id<MTLCommandBuffer> wait_command_buffer = [device->queue - commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased - for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { - id<MTLSharedEvent> handle = - iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i]); - [wait_command_buffer encodeWaitForEvent:handle value:wait_semaphore_list.payload_values[i]]; + if (iree_status_is_ok(status)) { + @autoreleasepool { + // First create a new command buffer and encode wait commands for all wait semaphores. + if (wait_semaphore_list.count > 0) { + id<MTLCommandBuffer> wait_command_buffer = [device->queue + commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased + for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { + id<MTLSharedEvent> handle = + iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i]); + [wait_command_buffer encodeWaitForEvent:handle + value:wait_semaphore_list.payload_values[i]]; + } + [wait_command_buffer commit]; } - [wait_command_buffer commit]; - } - // Then commit all recorded compute command buffers, except the last one, which we will patch up - // with semaphore signaling. - id<MTLCommandBuffer> signal_command_buffer = nil; - for (iree_host_size_t i = 0; i < command_buffer_count; ++i) { - iree_hal_command_buffer_t* command_buffer = command_buffers[i]; - id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer); - if (i + 1 != command_buffer_count) [handle commit]; - signal_command_buffer = handle; - } - if (signal_command_buffer == nil) { - signal_command_buffer = [device->queue - commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased - } + // Then commit all recorded compute command buffers, except the last one, which we will patch + // up with semaphore signaling. + id<MTLCommandBuffer> signal_command_buffer = nil; + for (iree_host_size_t i = 0; i < command_buffer_count; ++i) { + iree_hal_command_buffer_t* command_buffer = command_buffers[i]; + id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer); + if (i + 1 != command_buffer_count) [handle commit]; + signal_command_buffer = handle; + } + if (signal_command_buffer == nil) { + signal_command_buffer = [device->queue + commandBufferWithDescriptor:device->command_buffer_descriptor]; // autoreleased + } - // Finally encode signal commands for all signal semaphores. - for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { - id<MTLSharedEvent> handle = - iree_hal_metal_shared_event_handle(signal_semaphore_list.semaphores[i]); - [signal_command_buffer encodeSignalEvent:handle - value:signal_semaphore_list.payload_values[i]]; - } + // Finally encode signal commands for all signal semaphores. + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + id<MTLSharedEvent> handle = + iree_hal_metal_shared_event_handle(signal_semaphore_list.semaphores[i]); + [signal_command_buffer encodeSignalEvent:handle + value:signal_semaphore_list.payload_values[i]]; + } - [signal_command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cb) { - // Now we can release all retained resources. - iree_hal_resource_set_free(resource_set); - }]; - [signal_command_buffer commit]; + [signal_command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cb) { + // Now we can release all retained resources. + iree_hal_resource_set_free(resource_set); + }]; + [signal_command_buffer commit]; + } + } else { + iree_hal_resource_set_free(resource_set); } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } static iree_status_t iree_hal_metal_device_queue_flush(iree_hal_device_t* base_device,