[metal] Enable real async execution on GPU This commit drops `legacy_sync` path from Metal HAL driver, where we were forcing waiting all semaphores before all queue execution. This requires us to track lifetimes of resources involved in queue execution better, particularly we need to make sure semaphores and command buffers aren't released until they are not needed by the GPU.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp index 97e8d75..4696078 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -117,10 +117,6 @@ Builder b(context); SmallVector<NamedAttribute> configItems; - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. - configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); - configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context));
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m index 71cd7c4..67a07fd 100644 --- a/experimental/metal/metal_device.m +++ b/experimental/metal/metal_device.m
@@ -314,18 +314,19 @@ @autoreleasepool { // First create a new command buffer and encode wait commands for all wait semaphores. if (wait_semaphore_list.count > 0) { - // Extract all MTLSharedEvents behind into a heap-allocated array--we will need to access them - // in command buffer completion callback. - id<MTLSharedEvent>* shared_events; + // Copy the full semaphore list to heap--we will need to access them in command buffer + // completion callback. + iree_hal_semaphore_t** saved_semaphores; + iree_host_size_t size = sizeof(iree_hal_semaphore_t*) * wait_semaphore_list.count; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(device->host_allocator, - sizeof(id<MTLSharedEvent>) * wait_semaphore_list.count, - (void**)&shared_events)); + z0, iree_allocator_malloc(device->host_allocator, size, (void**)&saved_semaphores)); + memcpy(saved_semaphores, wait_semaphore_list.semaphores, size); + + // IREE will free resources once their refcounts become zero on host. However, there are work + // happening on the GPU async still needing access. So make sure we retain all semaphores + // until command buffer completion. for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { - shared_events[i] = iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i]); - // Make sure we retain the shared event until command buffer completion. IREE might free the - // wrapping semaphore after refcount become zero. - [shared_events[i] retain]; // +1 + iree_hal_semaphore_retain(saved_semaphores[i]); // +1 } MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1 @@ -336,14 +337,15 @@ id<MTLCommandBuffer> wait_command_buffer = [device->queue commandBufferWithDescriptor:descriptor]; // autoreleased for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { - [wait_command_buffer encodeWaitForEvent:shared_events[i] - value:wait_semaphore_list.payload_values[i]]; + id<MTLSharedEvent> handle = + iree_hal_metal_shared_event_handle(wait_semaphore_list.semaphores[i]); + [wait_command_buffer encodeWaitForEvent:handle value:wait_semaphore_list.payload_values[i]]; } [wait_command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cb) { for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { - [shared_events[i] release]; // -1 + iree_hal_semaphore_release(saved_semaphores[i]); // -1 } - iree_allocator_free(device->host_allocator, shared_events); + iree_allocator_free(device->host_allocator, saved_semaphores); }]; [wait_command_buffer commit]; [descriptor release]; // -1 @@ -351,23 +353,29 @@ // Then commit all recorded compute command buffers. for (iree_host_size_t i = 0; i < command_buffer_count; ++i) { - [iree_hal_metal_direct_command_buffer_handle(command_buffers[i]) commit]; + iree_hal_command_buffer_t* command_buffer = command_buffers[i]; + iree_hal_command_buffer_retain(command_buffer); // +1 + id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer); + [handle addCompletedHandler:^(id<MTLCommandBuffer> cb) { + iree_hal_command_buffer_release(command_buffer); // -1 + }]; + [handle commit]; } // Finally create a new command buffer and encode signal commands for all signal semaphores. if (signal_semaphore_list.count > 0) { - // Extract all MTLSharedEvents behind into a heap-allocated array--we will need to access them - // in command buffer completion callback. - id<MTLSharedEvent>* shared_events; + // Copy the full semaphore list to heap--we will need to access them in command buffer + // completion callback. + iree_hal_semaphore_t** saved_semaphores; + iree_host_size_t size = sizeof(iree_hal_semaphore_t*) * signal_semaphore_list.count; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(device->host_allocator, - sizeof(id<MTLSharedEvent>) * signal_semaphore_list.count, - (void**)&shared_events)); + z0, iree_allocator_malloc(device->host_allocator, size, (void**)&saved_semaphores)); + + // IREE will free resources once their refcounts become zero on host. However, there are work + // happening on the GPU async still needing access. So make sure we retain all semaphores + // until command buffer completion. for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { - shared_events[i] = iree_hal_metal_shared_event_handle(signal_semaphore_list.semaphores[i]); - // Make sure we retain the shared event until command buffer completion. IREE might free the - // wrapping semaphore after refcount become zero. - [shared_events[i] retain]; // +1 + iree_hal_semaphore_retain(saved_semaphores[i]); // +1 } MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1 @@ -378,15 +386,16 @@ id<MTLCommandBuffer> signal_command_buffer = [device->queue commandBufferWithDescriptor:descriptor]; // autoreleased for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { - [signal_command_buffer encodeSignalEvent:iree_hal_metal_shared_event_handle( - signal_semaphore_list.semaphores[i]) + id<MTLSharedEvent> handle = + iree_hal_metal_shared_event_handle(signal_semaphore_list.semaphores[i]); + [signal_command_buffer encodeSignalEvent:handle value:signal_semaphore_list.payload_values[i]]; } [signal_command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cb) { for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { - [shared_events[i] release]; // -1 + iree_hal_semaphore_release(saved_semaphores[i]); // -1 } - iree_allocator_free(device->host_allocator, shared_events); + iree_allocator_free(device->host_allocator, saved_semaphores); }]; [signal_command_buffer commit]; [descriptor release]; // -1