[metal] Use Metal shared event to implement IREE semaphore APIs
Metal shared events can be used to synchronize changes to resources
across multiple Metal device objects, across processes, or between
a device object and CPU access to resources. It's a direct match
for IREE semaphores.
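
For reference, a minimal sketch of the MTLSharedEvent primitives this
change builds on (not part of the patch; the queue name and variables
below are illustrative and assume a valid id<MTLDevice> named
`device`): signaledValue carries the monotonically increasing 64-bit
payload, and notifyListener:atValue:block: delivers an async callback
once that payload is reached, mirroring IREE's timeline semaphore
signal/wait model.

    dispatch_queue_t queue = dispatch_queue_create("example.queue", NULL);
    MTLSharedEventListener* listener =
        [[MTLSharedEventListener alloc] initWithDispatchQueue:queue];

    id<MTLSharedEvent> event = [device newSharedEvent];
    event.signaledValue = 0;  // Timeline starts at 0.

    // Ask for a callback once the timeline reaches 42 (from GPU or CPU).
    [event notifyListener:listener
                  atValue:42
                    block:^(id<MTLSharedEvent> e, uint64_t value) {
                      // Runs on `queue` when the event value reaches 42.
                    }];

    event.signaledValue = 42;  // Host-side signal; fires the block above.
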
diff --git a/experimental/metal/CMakeLists.txt b/experimental/metal/CMakeLists.txt
index 3248c51..ec60332 100644
--- a/experimental/metal/CMakeLists.txt
+++ b/experimental/metal/CMakeLists.txt
@@ -23,6 +23,8 @@
"metal_buffer.m"
"metal_device.m"
"metal_driver.m"
+ "metal_shared_event.h"
+ "metal_shared_event.m"
DEPS
iree::base
iree::base::core_headers
diff --git a/experimental/metal/cts/CMakeLists.txt b/experimental/metal/cts/CMakeLists.txt
index 1fa66b7..6f06e17 100644
--- a/experimental/metal/cts/CMakeLists.txt
+++ b/experimental/metal/cts/CMakeLists.txt
@@ -21,5 +21,6 @@
"allocator"
"buffer_mapping"
"driver"
+ "semaphore"
)
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index dd19ffe..8099dbf 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -7,6 +7,7 @@
#include "experimental/metal/metal_device.h"
#include "experimental/metal/direct_allocator.h"
+#include "experimental/metal/metal_shared_event.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
@@ -25,6 +26,11 @@
iree_hal_allocator_t* device_allocator;
id<MTLDevice> device;
+
+  // A dispatch queue and associated event listener for running Objective-C blocks to signal
+  // semaphores and wake up threads.
+ dispatch_queue_t semaphore_notification_queue;
+ MTLSharedEventListener* event_listener;
} iree_hal_metal_device_t;
static const iree_hal_device_vtable_t iree_hal_metal_device_vtable;
@@ -54,6 +60,9 @@
iree_hal_driver_retain(device->driver);
device->host_allocator = host_allocator;
device->device = [metal_device retain]; // +1
+ device->semaphore_notification_queue = dispatch_queue_create("dev.iree.queue.metal", NULL);
+ device->event_listener = [[MTLSharedEventListener alloc]
+ initWithDispatchQueue:device->semaphore_notification_queue]; // +1
*out_device = (iree_hal_device_t*)device;
} else {
@@ -80,6 +89,9 @@
iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
+ [device->event_listener release]; // -1
+ dispatch_release(device->semaphore_notification_queue);
+
iree_hal_allocator_release(device->device_allocator);
[device->device release]; // -1
iree_hal_driver_release(device->driver);
@@ -165,12 +177,22 @@
static iree_status_t iree_hal_metal_device_create_semaphore(iree_hal_device_t* base_device,
uint64_t initial_value,
iree_hal_semaphore_t** out_semaphore) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented semaphore create");
+ iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
+ return iree_hal_metal_shared_event_create(device->device, initial_value, device->event_listener,
+ device->host_allocator, out_semaphore);
}
static iree_hal_semaphore_compatibility_t iree_hal_metal_device_query_semaphore_compatibility(
iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
- return IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE;
+ if (iree_hal_metal_shared_event_isa(semaphore)) {
+ // Fast-path for semaphores related to this device.
+ // TODO(benvanik): ensure the creating devices are compatible in cases where
+ // multiple devices are used.
+ return IREE_HAL_SEMAPHORE_COMPATIBILITY_ALL;
+ }
+ // TODO(benvanik): semaphore APIs for querying allowed export formats. We
+ // can check device caps to see what external semaphore types are supported.
+ return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY;
}
static iree_status_t iree_hal_metal_device_queue_alloca(
@@ -205,7 +227,7 @@
static iree_status_t iree_hal_metal_device_wait_semaphores(
iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented semaphore wait");
+ return iree_hal_metal_shared_event_multi_wait(wait_mode, &semaphore_list, timeout);
}
static iree_status_t iree_hal_metal_device_profiling_begin(
diff --git a/experimental/metal/metal_shared_event.h b/experimental/metal/metal_shared_event.h
new file mode 100644
index 0000000..55f627d
--- /dev/null
+++ b/experimental/metal/metal_shared_event.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_METAL_METAL_SHARED_EVENT_H_
+#define IREE_EXPERIMENTAL_METAL_METAL_SHARED_EVENT_H_
+
+#import <Metal/Metal.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Creates a Metal shared event with the given |initial_value| to implement an
+// IREE semaphore.
+//
+// |listener| is used for dispatching notifications for async execution.
+//
+// |out_semaphore| must be released by the caller (see
+// iree_hal_semaphore_release).
+iree_status_t iree_hal_metal_shared_event_create(
+ id<MTLDevice> device, uint64_t initial_value,
+ MTLSharedEventListener* listener, iree_allocator_t host_allocator,
+ iree_hal_semaphore_t** out_semaphore);
+
+// Returns true if |semaphore| is a Metal shared event.
+bool iree_hal_metal_shared_event_isa(iree_hal_semaphore_t* semaphore);
+
+// Returns the underlying Metal shared event handle for the given |semaphore|.
+id<MTLSharedEvent> iree_hal_metal_shared_event_handle(
+ iree_hal_semaphore_t* base_semaphore);
+
+// Waits on the shared events in the given |semaphore_list| according to the
+// |wait_mode| before |timeout|.
+iree_status_t iree_hal_metal_shared_event_multi_wait(
+ iree_hal_wait_mode_t wait_mode,
+ const iree_hal_semaphore_list_t* semaphore_list, iree_timeout_t timeout);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_EXPERIMENTAL_METAL_METAL_SHARED_EVENT_H_
diff --git a/experimental/metal/metal_shared_event.m b/experimental/metal/metal_shared_event.m
new file mode 100644
index 0000000..ce1f688
--- /dev/null
+++ b/experimental/metal/metal_shared_event.m
@@ -0,0 +1,276 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/metal/metal_shared_event.h"
+
+#import <Metal/Metal.h>
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+
+typedef struct iree_hal_metal_shared_event_t {
+ // Abstract resource used for injecting reference counting and vtable; must be at offset 0.
+ iree_hal_resource_t resource;
+
+ id<MTLSharedEvent> shared_event;
+ // A listener object used for dispatching notifications; owned by the device.
+ MTLSharedEventListener* event_listener;
+
+ iree_allocator_t host_allocator;
+
+  // Permanent failure state of the semaphore, set if the semaphore has failed.
+ iree_status_t failure_state;
+ // Mutex guarding access to the failure state.
+ iree_slim_mutex_t state_mutex;
+} iree_hal_metal_shared_event_t;
+
+static const iree_hal_semaphore_vtable_t iree_hal_metal_shared_event_vtable;
+
+static iree_hal_metal_shared_event_t* iree_hal_metal_shared_event_cast(
+ iree_hal_semaphore_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_shared_event_vtable);
+ return (iree_hal_metal_shared_event_t*)base_value;
+}
+
+bool iree_hal_metal_shared_event_isa(iree_hal_semaphore_t* semaphore) {
+ return iree_hal_resource_is(semaphore, &iree_hal_metal_shared_event_vtable);
+}
+
+id<MTLSharedEvent> iree_hal_metal_shared_event_handle(iree_hal_semaphore_t* base_semaphore) {
+ iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+ return semaphore->shared_event;
+}
+
+iree_status_t iree_hal_metal_shared_event_create(id<MTLDevice> device, uint64_t initial_value,
+ MTLSharedEventListener* listener,
+ iree_allocator_t host_allocator,
+ iree_hal_semaphore_t** out_semaphore) {
+ IREE_ASSERT_ARGUMENT(out_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_metal_shared_event_t* semaphore = NULL;
+ iree_status_t status =
+ iree_allocator_malloc(host_allocator, sizeof(*semaphore), (void**)&semaphore);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_metal_shared_event_vtable, &semaphore->resource);
+ semaphore->shared_event = [device newSharedEvent]; // +1
+ semaphore->shared_event.signaledValue = initial_value;
+ semaphore->event_listener = listener;
+ semaphore->host_allocator = host_allocator;
+ iree_slim_mutex_initialize(&semaphore->state_mutex);
+ semaphore->failure_state = iree_ok_status();
+ *out_semaphore = (iree_hal_semaphore_t*)semaphore;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_metal_shared_event_destroy(iree_hal_semaphore_t* base_semaphore) {
+ iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ [semaphore->shared_event release]; // -1
+ iree_slim_mutex_deinitialize(&semaphore->state_mutex);
+ iree_allocator_free(semaphore->host_allocator, semaphore);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_metal_shared_event_query(iree_hal_semaphore_t* base_semaphore,
+ uint64_t* out_value) {
+ iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+ uint64_t value = semaphore->shared_event.signaledValue;
+ if (IREE_UNLIKELY(value == UINT64_MAX)) {
+ iree_status_t status = iree_ok_status();
+ iree_slim_mutex_lock(&semaphore->state_mutex);
+ status = semaphore->failure_state;
+ iree_slim_mutex_unlock(&semaphore->state_mutex);
+ return status;
+ }
+ *out_value = value;
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_metal_shared_event_signal(iree_hal_semaphore_t* base_semaphore,
+ uint64_t new_value) {
+ iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+ uint64_t value = semaphore->shared_event.signaledValue;
+ if (IREE_UNLIKELY(value == UINT64_MAX)) {
+ iree_status_t status = iree_ok_status();
+ iree_slim_mutex_lock(&semaphore->state_mutex);
+ status = semaphore->failure_state;
+ iree_slim_mutex_unlock(&semaphore->state_mutex);
+ return status;
+ }
+ semaphore->shared_event.signaledValue = new_value;
+ return iree_ok_status();
+}
+
+static void iree_hal_metal_shared_event_fail(iree_hal_semaphore_t* base_semaphore,
+ iree_status_t status) {
+ iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_slim_mutex_lock(&semaphore->state_mutex);
+ semaphore->failure_state = status;
+ semaphore->shared_event.signaledValue = UINT64_MAX;
+ iree_slim_mutex_unlock(&semaphore->state_mutex);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_metal_shared_event_wait(iree_hal_semaphore_t* base_semaphore,
+ uint64_t value, iree_timeout_t timeout) {
+ iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+
+ iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+ uint64_t timeout_ns;
+ dispatch_time_t apple_timeout_ns;
+ if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+ timeout_ns = UINT64_MAX;
+ apple_timeout_ns = DISPATCH_TIME_FOREVER;
+ } else if (deadline_ns == IREE_TIME_INFINITE_PAST) {
+ timeout_ns = 0;
+ apple_timeout_ns = DISPATCH_TIME_NOW;
+ } else {
+ iree_time_t now_ns = iree_time_now();
+ if (deadline_ns < now_ns) {
+ return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ }
+ timeout_ns = (uint64_t)(deadline_ns - now_ns);
+ apple_timeout_ns = dispatch_time(DISPATCH_TIME_NOW, timeout_ns);
+ }
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Quick path for impatient waiting to avoid all the overhead of dispatch queues and semaphores.
+ if (timeout_ns == 0) {
+ uint64_t current_value = 0;
+    iree_status_t status = iree_hal_metal_shared_event_query(base_semaphore, &current_value);
+ if (iree_status_is_ok(status) && current_value < value) {
+ status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+  // Theoretically we don't need to mark the dispatch semaphore handle as __block: the handle
+  // itself is never modified, and the single block below would capture it by copy anyway. But
+  // marking it as __block serves as good documentation.
+ __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+
+ __block bool did_fail = false;
+
+  // Register a listener on the MTLSharedEvent so that, once the work is done on the GPU and the
+  // expected value is reached, a block signals a dispatch semaphore. The block runs on the
+  // listener's dispatch queue; the current thread waits on that dispatch semaphore.
+ [semaphore->shared_event notifyListener:semaphore->event_listener
+ atValue:value
+ block:^(id<MTLSharedEvent> se, uint64_t v) {
+ if (v == UINT64_MAX) did_fail = true;
+
+ dispatch_semaphore_signal(work_done);
+ }];
+
+  // If the work is not done immediately, dispatch_semaphore_wait first decrements the dispatch
+  // semaphore count below zero and then blocks the current thread until signaled or timed out.
+ intptr_t timed_out = dispatch_semaphore_wait(work_done, apple_timeout_ns);
+ dispatch_release(work_done);
+
+ IREE_TRACE_ZONE_END(z0);
+ if (IREE_UNLIKELY(did_fail)) return iree_status_from_code(IREE_STATUS_ABORTED);
+ if (timed_out) return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ return iree_ok_status();
+}
+
+iree_status_t iree_hal_metal_shared_event_multi_wait(
+ iree_hal_wait_mode_t wait_mode, const iree_hal_semaphore_list_t* semaphore_list,
+ iree_timeout_t timeout) {
+ if (semaphore_list->count == 0) return iree_ok_status();
+ // If there is only one semaphore, just wait on it.
+ if (semaphore_list->count == 1) {
+ return iree_hal_metal_shared_event_wait(semaphore_list->semaphores[0],
+ semaphore_list->payload_values[0], timeout);
+ }
+
+ iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+ uint64_t timeout_ns;
+ dispatch_time_t apple_timeout_ns;
+ if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+ timeout_ns = UINT64_MAX;
+ apple_timeout_ns = DISPATCH_TIME_FOREVER;
+ } else if (deadline_ns == IREE_TIME_INFINITE_PAST) {
+ timeout_ns = 0;
+ apple_timeout_ns = DISPATCH_TIME_NOW;
+ } else {
+ iree_time_t now_ns = iree_time_now();
+ if (deadline_ns < now_ns) {
+ return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ }
+ timeout_ns = (uint64_t)(deadline_ns - now_ns);
+ apple_timeout_ns = dispatch_time(DISPATCH_TIME_NOW, timeout_ns);
+ }
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Create an atomic to count how many semaphores have signaled. Mark it as `__block` so
+  // different threads share the same storage by reference.
+ __block iree_atomic_int32_t wait_count;
+ iree_atomic_store_int32(&wait_count, 0, iree_memory_order_release);
+ // The total count we are expecting to see.
+ iree_host_size_t total_count = (wait_mode == IREE_HAL_WAIT_MODE_ALL) ? semaphore_list->count : 1;
+  // Theoretically we don't need to mark the dispatch semaphore handle as __block: the handle
+  // itself is never modified, and the notification blocks below would capture it by copy anyway.
+  // But marking it as __block serves as good documentation.
+ __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+
+ __block bool did_fail = false;
+
+ for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) {
+    // Register a listener on each MTLSharedEvent so that, once the work is done on the GPU and
+    // the expected value is reached, a block bumps the wait count and eventually signals a
+    // dispatch semaphore. The blocks run on the listener's dispatch queue; the current thread
+    // waits on that dispatch semaphore.
+ iree_hal_metal_shared_event_t* semaphore =
+ iree_hal_metal_shared_event_cast(semaphore_list->semaphores[i]);
+ [semaphore->shared_event notifyListener:semaphore->event_listener
+ atValue:semaphore_list->payload_values[i]
+ block:^(id<MTLSharedEvent> se, uint64_t v) {
+ // Fail as a whole if any participating semaphore failed.
+ if (v == UINT64_MAX) did_fail = true;
+
+ int32_t old_value = iree_atomic_fetch_add_int32(
+ &wait_count, 1, iree_memory_order_release);
+                                     // The semaphore completing the expected count sends out the
+                                     // notification. Fetch-add returns the old value, so add 1.
+ if (old_value + 1 == total_count) {
+ dispatch_semaphore_signal(work_done);
+ }
+ }];
+ }
+
+  // If the work is not done immediately, dispatch_semaphore_wait first decrements the dispatch
+  // semaphore count below zero and then blocks the current thread until signaled or timed out.
+ intptr_t timed_out = dispatch_semaphore_wait(work_done, apple_timeout_ns);
+ dispatch_release(work_done);
+
+ IREE_TRACE_ZONE_END(z0);
+ if (IREE_UNLIKELY(did_fail)) return iree_status_from_code(IREE_STATUS_ABORTED);
+ if (timed_out) return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ return iree_ok_status();
+}
+
+static const iree_hal_semaphore_vtable_t iree_hal_metal_shared_event_vtable = {
+ .destroy = iree_hal_metal_shared_event_destroy,
+ .query = iree_hal_metal_shared_event_query,
+ .signal = iree_hal_metal_shared_event_signal,
+ .fail = iree_hal_metal_shared_event_fail,
+ .wait = iree_hal_metal_shared_event_wait,
+};
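
For reviewers, a rough usage sketch of how a semaphore created by this
backend flows through the generic iree_hal_semaphore_* wrappers, which
dispatch into the vtable above. This is not part of the patch: the
listener/queue setup normally lives in metal_device.m, and the
`metal_device` handle, queue name, and error handling here are
illustrative only.

    // Listener + queue as the HAL device would create them (see metal_device.m).
    dispatch_queue_t queue = dispatch_queue_create("example.queue", NULL);
    MTLSharedEventListener* listener =
        [[MTLSharedEventListener alloc] initWithDispatchQueue:queue];

    iree_hal_semaphore_t* semaphore = NULL;
    iree_status_t status = iree_hal_metal_shared_event_create(
        metal_device, /*initial_value=*/0ull, listener,
        iree_allocator_system(), &semaphore);

    // Host-side signal of the timeline to 1, then wait until it reaches 1.
    if (iree_status_is_ok(status)) {
      status = iree_hal_semaphore_signal(semaphore, 1ull);
    }
    if (iree_status_is_ok(status)) {
      status = iree_hal_semaphore_wait(semaphore, 1ull,
                                       iree_infinite_timeout());
    }

    uint64_t value = 0;
    if (iree_status_is_ok(status)) {
      // Reads back the signaled payload; expect value == 1 here.
      status = iree_hal_semaphore_query(semaphore, &value);
    }

    iree_hal_semaphore_release(semaphore);
    iree_status_ignore(status);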