[metal] Use Metal shared event to implement IREE semaphore APIs

Metal shared events can be used to synchronize changes to resources
across multiple Metal device objects, across processes, or between
a device object and CPU access to resources. It's a direct match
for IREE semaphores.
diff --git a/experimental/metal/CMakeLists.txt b/experimental/metal/CMakeLists.txt
index 3248c51..ec60332 100644
--- a/experimental/metal/CMakeLists.txt
+++ b/experimental/metal/CMakeLists.txt
@@ -23,6 +23,8 @@
     "metal_buffer.m"
     "metal_device.m"
     "metal_driver.m"
+    "metal_shared_event.h"
+    "metal_shared_event.m"
   DEPS
     iree::base
     iree::base::core_headers
diff --git a/experimental/metal/cts/CMakeLists.txt b/experimental/metal/cts/CMakeLists.txt
index 1fa66b7..6f06e17 100644
--- a/experimental/metal/cts/CMakeLists.txt
+++ b/experimental/metal/cts/CMakeLists.txt
@@ -21,5 +21,6 @@
     "allocator"
     "buffer_mapping"
     "driver"
+    "semaphore"
 )
 
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index dd19ffe..8099dbf 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -7,6 +7,7 @@
 #include "experimental/metal/metal_device.h"
 
 #include "experimental/metal/direct_allocator.h"
+#include "experimental/metal/metal_shared_event.h"
 #include "iree/base/api.h"
 #include "iree/base/tracing.h"
 #include "iree/hal/api.h"
@@ -25,6 +26,11 @@
   iree_hal_allocator_t* device_allocator;
 
   id<MTLDevice> device;
+
+  // A dispatch queue and associated event listener for running Objective-C blocks to signal
+  // semaphores and wake up threads.
+  dispatch_queue_t semaphore_notification_queue;
+  MTLSharedEventListener* event_listener;
 } iree_hal_metal_device_t;
 
 static const iree_hal_device_vtable_t iree_hal_metal_device_vtable;
@@ -54,6 +60,9 @@
     iree_hal_driver_retain(device->driver);
     device->host_allocator = host_allocator;
     device->device = [metal_device retain];  // +1
+    device->semaphore_notification_queue = dispatch_queue_create("dev.iree.queue.metal", NULL);
+    device->event_listener = [[MTLSharedEventListener alloc]
+        initWithDispatchQueue:device->semaphore_notification_queue];  // +1
 
     *out_device = (iree_hal_device_t*)device;
   } else {
@@ -80,6 +89,9 @@
   iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
   IREE_TRACE_ZONE_BEGIN(z0);
 
+  [device->event_listener release];  // -1
+  dispatch_release(device->semaphore_notification_queue);
+
   iree_hal_allocator_release(device->device_allocator);
   [device->device release];  // -1
   iree_hal_driver_release(device->driver);
@@ -165,12 +177,22 @@
 static iree_status_t iree_hal_metal_device_create_semaphore(iree_hal_device_t* base_device,
                                                             uint64_t initial_value,
                                                             iree_hal_semaphore_t** out_semaphore) {
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented semaphore create");
+  iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
+  return iree_hal_metal_shared_event_create(device->device, initial_value, device->event_listener,
+                                            device->host_allocator, out_semaphore);
 }
 
 static iree_hal_semaphore_compatibility_t iree_hal_metal_device_query_semaphore_compatibility(
     iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
-  return IREE_HAL_SEMAPHORE_COMPATIBILITY_NONE;
+  if (iree_hal_metal_shared_event_isa(semaphore)) {
+    // Fast-path for semaphores related to this device.
+    // TODO(benvanik): ensure the creating devices are compatible in cases where
+    // multiple devices are used.
+    return IREE_HAL_SEMAPHORE_COMPATIBILITY_ALL;
+  }
+  // TODO(benvanik): semaphore APIs for querying allowed export formats. We
+  // can check device caps to see what external semaphore types are supported.
+  return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY;
 }
 
 static iree_status_t iree_hal_metal_device_queue_alloca(
@@ -205,7 +227,7 @@
 static iree_status_t iree_hal_metal_device_wait_semaphores(
     iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
     const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented semaphore wait");
+  return iree_hal_metal_shared_event_multi_wait(wait_mode, &semaphore_list, timeout);
 }
 
 static iree_status_t iree_hal_metal_device_profiling_begin(
diff --git a/experimental/metal/metal_shared_event.h b/experimental/metal/metal_shared_event.h
new file mode 100644
index 0000000..55f627d
--- /dev/null
+++ b/experimental/metal/metal_shared_event.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_METAL_METAL_SHARED_EVENT_H_
+#define IREE_EXPERIMENTAL_METAL_METAL_SHARED_EVENT_H_
+
+#import <Metal/Metal.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a Metal shared event with the given |initial_value| to implement an
+// IREE semaphore.
+//
+// |listener| is used for dispatching notifications for async execution.
+//
+// |out_semaphore| must be released by the caller (see
+// iree_hal_semaphore_release).
+iree_status_t iree_hal_metal_shared_event_create(
+    id<MTLDevice> device, uint64_t initial_value,
+    MTLSharedEventListener* listener, iree_allocator_t host_allocator,
+    iree_hal_semaphore_t** out_semaphore);
+
+// Returns true if |semaphore| is a Metal shared event.
+bool iree_hal_metal_shared_event_isa(iree_hal_semaphore_t* semaphore);
+
+// Returns the underlying Metal shared event handle for the given |semaphore|.
+id<MTLSharedEvent> iree_hal_metal_shared_event_handle(
+    iree_hal_semaphore_t* base_semaphore);
+
+// Waits on the shared events in the given |semaphore_list| according to the
+// |wait_mode| before |timeout|.
+iree_status_t iree_hal_metal_shared_event_multi_wait(
+    iree_hal_wait_mode_t wait_mode,
+    const iree_hal_semaphore_list_t* semaphore_list, iree_timeout_t timeout);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_METAL_METAL_SHARED_EVENT_H_
diff --git a/experimental/metal/metal_shared_event.m b/experimental/metal/metal_shared_event.m
new file mode 100644
index 0000000..ce1f688
--- /dev/null
+++ b/experimental/metal/metal_shared_event.m
@@ -0,0 +1,276 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/metal/metal_shared_event.h"
+
+#import <Metal/Metal.h>
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+
+typedef struct iree_hal_metal_shared_event_t {
+  // Abstract resource used for injecting reference counting and vtable; must be at offset 0.
+  iree_hal_resource_t resource;
+
+  id<MTLSharedEvent> shared_event;  // Backing Metal event; created (+1) at semaphore creation.
+  // A listener object used for dispatching notifications; owned by the device.
+  MTLSharedEventListener* event_listener;
+
+  iree_allocator_t host_allocator;  // Allocator used to allocate and free this struct.
+
+  // Permanent failure state of the semaphore; set by fail(), which also signals UINT64_MAX.
+  iree_status_t failure_state;
+  // Mutex guarding access to the failure state.
+  iree_slim_mutex_t state_mutex;
+} iree_hal_metal_shared_event_t;
+
+static const iree_hal_semaphore_vtable_t iree_hal_metal_shared_event_vtable;
+
+static iree_hal_metal_shared_event_t* iree_hal_metal_shared_event_cast(
+    iree_hal_semaphore_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_shared_event_vtable);  // Must be our type.
+  return (iree_hal_metal_shared_event_t*)base_value;  // |resource| sits at offset 0.
+}
+
+bool iree_hal_metal_shared_event_isa(iree_hal_semaphore_t* semaphore) {
+  return iree_hal_resource_is(semaphore, &iree_hal_metal_shared_event_vtable);  // By vtable.
+}
+
+id<MTLSharedEvent> iree_hal_metal_shared_event_handle(iree_hal_semaphore_t* base_semaphore) {
+  iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+  return semaphore->shared_event;  // Borrowed (+0): no retain is taken on behalf of the caller.
+}
+
+iree_status_t iree_hal_metal_shared_event_create(id<MTLDevice> device, uint64_t initial_value,
+                                                 MTLSharedEventListener* listener,
+                                                 iree_allocator_t host_allocator,
+                                                 iree_hal_semaphore_t** out_semaphore) {
+  IREE_ASSERT_ARGUMENT(out_semaphore);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // Allocate the wrapper up front; the remaining setup is not status-checked.
+  iree_hal_metal_shared_event_t* event = NULL;
+  iree_status_t status = iree_allocator_malloc(host_allocator, sizeof(*event), (void**)&event);
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+  iree_hal_resource_initialize(&iree_hal_metal_shared_event_vtable, &event->resource);
+  event->host_allocator = host_allocator;
+  event->event_listener = listener;  // Stored as-is; no retain is taken here.
+  event->failure_state = iree_ok_status();
+  iree_slim_mutex_initialize(&event->state_mutex);
+  event->shared_event = [device newSharedEvent];  // +1
+  event->shared_event.signaledValue = initial_value;
+  *out_semaphore = (iree_hal_semaphore_t*)event;
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_metal_shared_event_destroy(iree_hal_semaphore_t* base_semaphore) {
+  iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // fail() takes ownership of the stored status; free it here (no-op while it is still OK).
+  iree_status_ignore(semaphore->failure_state);
+  [semaphore->shared_event release];  // -1
+  iree_slim_mutex_deinitialize(&semaphore->state_mutex);
+  iree_allocator_free(semaphore->host_allocator, semaphore);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_metal_shared_event_query(iree_hal_semaphore_t* base_semaphore,
+                                                       uint64_t* out_value) {
+  iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+  uint64_t value = semaphore->shared_event.signaledValue;
+  if (IREE_UNLIKELY(value == UINT64_MAX)) {
+    // Failed: hand back a clone so the semaphore keeps owning the stored status.
+    iree_slim_mutex_lock(&semaphore->state_mutex);
+    iree_status_t status = iree_status_clone(semaphore->failure_state);
+    iree_slim_mutex_unlock(&semaphore->state_mutex);
+    return status;
+  }
+  *out_value = value;
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_metal_shared_event_signal(iree_hal_semaphore_t* base_semaphore,
+                                                        uint64_t new_value) {
+  iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+  uint64_t value = semaphore->shared_event.signaledValue;
+  if (IREE_UNLIKELY(value == UINT64_MAX)) {
+    // Already failed: refuse to signal and report a clone of the stored failure.
+    iree_slim_mutex_lock(&semaphore->state_mutex);
+    iree_status_t status = iree_status_clone(semaphore->failure_state);
+    iree_slim_mutex_unlock(&semaphore->state_mutex);
+    return status;
+  }
+  semaphore->shared_event.signaledValue = new_value;
+  return iree_ok_status();
+}
+
+static void iree_hal_metal_shared_event_fail(iree_hal_semaphore_t* base_semaphore,
+                                             iree_status_t status) {
+  iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_slim_mutex_lock(&semaphore->state_mutex);
+  // Takes ownership of |status|: keep the first failure, free any later one.
+  if (iree_status_is_ok(semaphore->failure_state)) semaphore->failure_state = status;
+  else iree_status_ignore(status);
+  semaphore->shared_event.signaledValue = UINT64_MAX;
+  iree_slim_mutex_unlock(&semaphore->state_mutex);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Blocks until |base_semaphore| reaches |value|, is failed, or |timeout| elapses.
+static iree_status_t iree_hal_metal_shared_event_wait(iree_hal_semaphore_t* base_semaphore,
+                                                      uint64_t value, iree_timeout_t timeout) {
+  iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore);
+
+  iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+  uint64_t timeout_ns;
+  dispatch_time_t apple_timeout_ns;
+  if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+    timeout_ns = UINT64_MAX;
+    apple_timeout_ns = DISPATCH_TIME_FOREVER;
+  } else if (deadline_ns == IREE_TIME_INFINITE_PAST) {
+    timeout_ns = 0;
+    apple_timeout_ns = DISPATCH_TIME_NOW;
+  } else {
+    iree_time_t now_ns = iree_time_now();
+    if (deadline_ns < now_ns) {
+      return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+    }
+    timeout_ns = (uint64_t)(deadline_ns - now_ns);
+    apple_timeout_ns = dispatch_time(DISPATCH_TIME_NOW, timeout_ns);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Quick path for impatient waiting to avoid all the overhead of dispatch queues and semaphores.
+  if (timeout_ns == 0) {
+    uint64_t current_value = 0;
+    iree_status_t status = iree_hal_metal_shared_event_query(base_semaphore, &current_value);
+    if (iree_status_is_ok(status) && current_value < value) {
+      status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+    }
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Capture |work_done| by value, not as __block: without ARC a __block object is not retained
+  // by the block, so a notification arriving after we time out and drop our reference below
+  // would signal a freed semaphore. A by-value capture is retained when the block is copied.
+  dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+  __block bool did_fail = false;
+
+  // Use a listener to the MTLSharedEvent to notify us when the work is done on GPU by signaling a
+  // semaphore. The signaling will happen in a new dispatch queue; the current thread will wait on
+  // the semaphore.
+  [semaphore->shared_event notifyListener:semaphore->event_listener
+                                  atValue:value
+                                    block:^(id<MTLSharedEvent> se, uint64_t v) {
+                                      if (v == UINT64_MAX) did_fail = true;
+
+                                      dispatch_semaphore_signal(work_done);
+                                    }];
+
+  // If the work is not done immediately, dispatch_semaphore_wait decreases the semaphore value to
+  // less than zero first and then puts the current thread into wait state.
+  intptr_t timed_out = dispatch_semaphore_wait(work_done, apple_timeout_ns);
+  dispatch_release(work_done);
+
+  IREE_TRACE_ZONE_END(z0);
+  if (IREE_UNLIKELY(did_fail)) return iree_status_from_code(IREE_STATUS_ABORTED);
+  if (timed_out) return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  return iree_ok_status();
+}
+
+// Waits for ALL or ANY (per |wait_mode|) of the listed semaphores to reach their payload values.
+// Every entry must be a Metal shared event; returns ABORTED if an observed one has failed.
+iree_status_t iree_hal_metal_shared_event_multi_wait(
+    iree_hal_wait_mode_t wait_mode, const iree_hal_semaphore_list_t* semaphore_list,
+    iree_timeout_t timeout) {
+  if (semaphore_list->count == 0) return iree_ok_status();
+  // If there is only one semaphore, just wait on it.
+  if (semaphore_list->count == 1) {
+    return iree_hal_metal_shared_event_wait(semaphore_list->semaphores[0],
+                                            semaphore_list->payload_values[0], timeout);
+  }
+
+  iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+  dispatch_time_t apple_timeout_ns;
+  if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+    apple_timeout_ns = DISPATCH_TIME_FOREVER;
+  } else if (deadline_ns == IREE_TIME_INFINITE_PAST) {
+    apple_timeout_ns = DISPATCH_TIME_NOW;
+  } else {
+    iree_time_t now_ns = iree_time_now();
+    if (deadline_ns < now_ns) {
+      return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+    }
+    apple_timeout_ns = dispatch_time(DISPATCH_TIME_NOW, deadline_ns - now_ns);
+  }
+  // NOTE(review): notifications are delivered asynchronously on the listener queue, so a zero
+  // timeout may report DEADLINE_EXCEEDED here even if the values have already been reached.
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Create an atomic to count how many semaphores have signaled. Mark it as `__block` so different
+  // threads are sharing the same data via reference.
+  __block iree_atomic_int32_t wait_count;
+  iree_atomic_store_int32(&wait_count, 0, iree_memory_order_release);
+  // The total count we are expecting to see.
+  iree_host_size_t total_count = (wait_mode == IREE_HAL_WAIT_MODE_ALL) ? semaphore_list->count : 1;
+  // Capture |work_done| by value, not as __block: without ARC a __block object is not retained
+  // by the blocks, so a notification arriving after we time out and drop our reference below
+  // would signal a freed semaphore. A by-value capture is retained when each block is copied.
+  dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+
+  __block bool did_fail = false;
+
+  for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) {
+    // Use a listener to the MTLSharedEvent to notify us when the work is done on GPU by signaling a
+    // semaphore. The signaling will happen in a new dispatch queue; the current thread will wait on
+    // the semaphore.
+    iree_hal_metal_shared_event_t* semaphore =
+        iree_hal_metal_shared_event_cast(semaphore_list->semaphores[i]);
+    [semaphore->shared_event notifyListener:semaphore->event_listener
+                                    atValue:semaphore_list->payload_values[i]
+                                      block:^(id<MTLSharedEvent> se, uint64_t v) {
+                                        // Fail as a whole if any participating semaphore failed.
+                                        if (v == UINT64_MAX) did_fail = true;
+
+                                        int32_t old_value = iree_atomic_fetch_add_int32(
+                                            &wait_count, 1, iree_memory_order_release);
+                                        // The last signaled semaphore sends out the notification.
+                                        // Atomic fetch add returns the old value, so need to +1.
+                                        if (old_value + 1 == total_count) {
+                                          dispatch_semaphore_signal(work_done);
+                                        }
+                                      }];
+  }
+
+  // If the work is not done immediately, dispatch_semaphore_wait decreases the semaphore value by
+  // one first and then puts the current thread into wait state.
+  intptr_t timed_out = dispatch_semaphore_wait(work_done, apple_timeout_ns);
+  dispatch_release(work_done);
+
+  IREE_TRACE_ZONE_END(z0);
+  if (IREE_UNLIKELY(did_fail)) return iree_status_from_code(IREE_STATUS_ABORTED);
+  if (timed_out) return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  return iree_ok_status();
+}
+
+static const iree_hal_semaphore_vtable_t iree_hal_metal_shared_event_vtable = {
+    .destroy = iree_hal_metal_shared_event_destroy,
+    .query = iree_hal_metal_shared_event_query,
+    .signal = iree_hal_metal_shared_event_signal,
+    .fail = iree_hal_metal_shared_event_fail,
+    .wait = iree_hal_metal_shared_event_wait,  // Single semaphore; lists use ..._multi_wait.
+};