Implement coroutine semaphore/fence awaits.

Progress on #8093.
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index eada4af..939a38a 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -1375,43 +1375,202 @@
   return iree_ok_status();
 }
 
+// Removes entries in |fences| if they have been reached.
+// Returns failure if one or more fences have failed.
+static iree_status_t iree_hal_module_fence_elide_reached(
+    iree_host_size_t* fence_count, iree_hal_fence_t** fences) {
+  iree_host_size_t new_count = *fence_count;
+  for (iree_host_size_t i = 0; i < new_count;) {
+    iree_status_t status = iree_hal_fence_query(fences[i]);
+    if (iree_status_is_ok(status)) {
+      // Has been reached; shift the list down.
+      memmove(&fences[i], &fences[i + 1],
+              (new_count - i - 1) * sizeof(iree_hal_fence_t*));
+      fences[new_count - 1] = NULL;
+      --new_count;
+    } else if (iree_status_is_deferred(status)) {
+      // Still waiting.
+      iree_status_ignore(status);
+      ++i;  // next
+    } else {
+      // Failed; propagate failure.
+      *fence_count = new_count;
+      return status;
+    }
+  }
+  *fence_count = new_count;
+  return iree_ok_status();
+}
+
+// Enters a wait frame for all timepoints in all |fences|.
+// Returns an |out_wait_status| of OK if all fences have been reached or
+// IREE_STATUS_DEFERRED if one or more fences are still pending and a wait
+// frame was entered.
+static iree_status_t iree_hal_module_fence_await_begin(
+    iree_vm_stack_t* stack, iree_host_size_t fence_count,
+    iree_hal_fence_t** fences, iree_timeout_t timeout, iree_zone_id_t zone_id,
+    iree_status_t* out_wait_status) {
+  // To avoid additional allocations when waiting on multiple fences we enter
+  // the wait frame with the maximum required wait source capacity and perform
+  // a simple deduplication when building the list. Ideally this helps get us on
+// fast paths of single semaphore waits. The common case is a single fence, in
+// which case the merge loop below never runs and only the direct insert is hit.
+  iree_host_size_t total_timepoint_capacity = 0;
+  for (iree_host_size_t i = 0; i < fence_count; ++i) {
+    total_timepoint_capacity += iree_hal_fence_timepoint_count(fences[i]);
+  }
+
+  // Fast-path for no semaphores (empty/immediate fences).
+  if (total_timepoint_capacity == 0) {
+    *out_wait_status = iree_ok_status();
+    IREE_TRACE_ZONE_END(zone_id);
+    return iree_ok_status();
+  }
+
+  // Reserve storage as if all timepoints from all fences were unique.
+  iree_vm_wait_frame_t* wait_frame = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_stack_wait_enter(stack, IREE_VM_WAIT_ALL,
+                                                total_timepoint_capacity,
+                                                timeout, zone_id, &wait_frame));
+
+  // Insert the first set of timepoints - they're already deduplicated.
+  iree_host_size_t unique_timepoint_count = 0;
+  if (fence_count >= 1) {
+    iree_hal_semaphore_list_t semaphore_list =
+        iree_hal_fence_semaphore_list(fences[0]);
+    for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+      iree_wait_source_t wait_source = iree_hal_semaphore_await(
+          semaphore_list.semaphores[i], semaphore_list.payload_values[i]);
+      wait_frame->wait_sources[unique_timepoint_count++] = wait_source;
+    }
+  }
+
+  // TODO(benvanik): simplify this; it may not be worth the complexity. We'll
+  // need more real workloads using multi-fence joins to see how useful this is.
+
+  // Insert remaining fence timepoints by performing merging as we go.
+  for (iree_host_size_t i = 1; i < fence_count; ++i) {
+    iree_hal_semaphore_list_t semaphore_list =
+        iree_hal_fence_semaphore_list(fences[i]);
+    for (iree_host_size_t j = 0; j < semaphore_list.count; ++j) {
+      // O(n^2) set insertion - relying on this being rare and the total count
+      // being low. The savings of a small linear scan here relative to an
+      // additional syscall are always worth it but we may want to go further.
+      iree_wait_source_t wait_source = iree_hal_semaphore_await(
+          semaphore_list.semaphores[j], semaphore_list.payload_values[j]);
+      bool found_existing = false;
+      for (iree_host_size_t k = 0; k < unique_timepoint_count; ++k) {
+        if (wait_frame->wait_sources[k].ctl == wait_source.ctl &&
+            wait_frame->wait_sources[k].self == wait_source.self) {
+          // Found existing; use max of both.
+          wait_frame->wait_sources[k].data =
+              iree_max(wait_frame->wait_sources[k].data, wait_source.data);
+          found_existing = true;
+          break;
+        }
+      }
+      if (!found_existing) {
+        wait_frame->wait_sources[unique_timepoint_count++] = wait_source;
+      }
+    }
+  }
+
+  // Update frame with the actual number of timepoints in the wait operation.
+  wait_frame->count = unique_timepoint_count;
+
+  *out_wait_status = iree_status_from_code(IREE_STATUS_DEFERRED);
+  return iree_ok_status();
+}
+
+// PC for iree_hal_module_fence_await.
+enum iree_hal_module_fence_await_pc_e {
+  // Initial entry point that will try to either wait inline or yield to the
+  // scheduler with a wait-all operation.
+  IREE_HAL_MODULE_FENCE_AWAIT_PC_BEGIN = 0,
+  // Resume entry point after the scheduler wait has resolved (successfully or
+  // otherwise).
+  IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME,
+};
+
 IREE_VM_ABI_EXPORT(iree_hal_module_fence_await,  //
                    iree_hal_module_state_t,      //
                    iCrD, i) {
-  uint32_t timeout_millis = (uint32_t)args->i0;
-  iree_host_size_t fence_count = 0;
-  iree_hal_fence_t** fences = NULL;
-  IREE_VM_ABI_VLA_STACK_DEREF(args, a1_count, a1, iree_hal_fence, 32,
-                              &fence_count, &fences);
+  // On entry we either perform the wait or begin a coroutine yield operation.
+  // After resuming we check to see if the fence has been reached and propagate
+  // the result.
+  iree_vm_stack_frame_t* current_frame = iree_vm_stack_top(stack);
+  iree_zone_id_t zone_id = 0;
+  iree_status_t wait_status = iree_ok_status();
+  if (current_frame->pc == IREE_HAL_MODULE_FENCE_AWAIT_PC_BEGIN) {
+    uint32_t timeout_millis = (uint32_t)args->i0;
+    iree_host_size_t fence_count = 0;
+    iree_hal_fence_t** fences = NULL;
+    IREE_VM_ABI_VLA_STACK_DEREF(args, a1_count, a1, iree_hal_fence, 32,
+                                &fence_count, &fences);
 
-  // Capture absolute timeout so that regardless of how long it takes us to wait
-  // the user-perceived wait time remains the same.
-  iree_timeout_t timeout = iree_make_timeout_ms(timeout_millis);
-  iree_convert_timeout_to_absolute(&timeout);
+    IREE_TRACE_ZONE_BEGIN(z0);
+    zone_id = z0;
 
-  // Wait on each fence in-turn.
-  // TODO(benvanik): use a stack wait frame and expand all fences into their
-  // individual timepoint wait sources. This will allow the loop to perform a
-  // multi-wait without needing to materialize intermediate wait primitives
-  // which may not be possible across devices.
-  iree_status_t status = iree_ok_status();
-  for (iree_host_size_t i = 0; i < fence_count; ++i) {
-    status = iree_hal_fence_wait(fences[i], timeout);
-    if (!iree_status_is_ok(status)) break;
+    // Capture absolute timeout so that regardless of how long it takes us to
+    // wait the user-perceived wait time remains the same.
+    iree_timeout_t timeout = timeout_millis == UINT32_MAX
+                                 ? iree_infinite_timeout()
+                                 : iree_make_timeout_ms(timeout_millis);
+    iree_convert_timeout_to_absolute(&timeout);
+
+    // Remove any fences that have been reached and check for failure.
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        zone_id, iree_hal_module_fence_elide_reached(&fence_count, fences));
+
+    // If all fences have been reached we can exit early as if we waited
+    // successfully.
+    if (fence_count > 0) {
+      if (iree_all_bits_set(state->flags, IREE_HAL_MODULE_FLAG_SYNCHRONOUS)) {
+        // Block the native thread until the fence is reached or the deadline is
+        // exceeded.
+        for (iree_host_size_t i = 0; i < fence_count; ++i) {
+          wait_status = iree_hal_fence_wait(fences[i], timeout);
+          if (!iree_status_is_ok(wait_status)) break;
+        }
+      } else {
+        IREE_RETURN_AND_END_ZONE_IF_ERROR(
+            zone_id,
+            iree_hal_module_fence_await_begin(stack, fence_count, fences,
+                                              timeout, zone_id, &wait_status));
+        current_frame->pc = IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME;
+        if (iree_status_is_deferred(wait_status)) {
+          zone_id = 0;  // ownership transferred to wait frame
+        }
+      }
+    }
+  } else {
+    // Resume by leaving the wait frame and storing the result.
+    iree_vm_wait_result_t wait_result;
+    IREE_RETURN_IF_ERROR(iree_vm_stack_wait_leave(stack, &wait_result));
+    wait_status = wait_result.status;
+    IREE_TRACE(zone_id = wait_result.trace_zone);
   }
 
-  if (iree_status_is_ok(status)) {
+  iree_status_t status = iree_ok_status();
+  if (iree_status_is_ok(wait_status)) {
     // Successful wait.
     rets->i0 = 0;
-    return iree_ok_status();
-  } else if (iree_status_is_deadline_exceeded(status)) {
+  } else if (iree_status_is_deferred(wait_status)) {
+    // Yielding; resume required.
+    // NOTE: zone not ended as it's reserved on the stack.
+    status = wait_status;
+  } else if (iree_status_is_deadline_exceeded(wait_status)) {
     // Propagate deadline exceeded back to the VM.
-    rets->i0 = (int32_t)iree_status_consume_code(status);
-    iree_status_ignore(status);
-    return iree_ok_status();
+    rets->i0 = (int32_t)iree_status_consume_code(wait_status);
+    iree_status_ignore(wait_status);
+  } else {
+    // Fail the invocation.
+    status = wait_status;
   }
 
-  // Fail the invocation.
+  IREE_TRACE({
+    if (zone_id) IREE_TRACE_ZONE_END(zone_id);
+  });
   return status;
 }
 
@@ -1425,7 +1584,6 @@
   iree_hal_device_t* device = NULL;
   IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
   uint64_t initial_value = (uint64_t)args->i1;
-
   iree_hal_semaphore_t* semaphore = NULL;
   IREE_RETURN_IF_ERROR(
       iree_hal_semaphore_create(device, initial_value, &semaphore));
@@ -1438,11 +1596,12 @@
                    r, iI) {
   iree_hal_semaphore_t* semaphore = NULL;
   IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
-
-  uint64_t value = 0;
-  iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
+  uint64_t current_value = 0;
+  iree_status_t query_status =
+      iree_hal_semaphore_query(semaphore, &current_value);
   rets->i0 = iree_status_consume_code(query_status);
-  rets->i1 = value;
+  rets->i1 = current_value;
+  iree_status_ignore(query_status);
   return iree_ok_status();
 }
 
@@ -1452,7 +1611,6 @@
   iree_hal_semaphore_t* semaphore = NULL;
   IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
   uint64_t new_value = (uint64_t)args->i1;
-
   return iree_hal_semaphore_signal(semaphore, new_value);
 }
 
@@ -1463,34 +1621,102 @@
   IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
   iree_status_code_t status_code =
       (iree_status_code_t)(args->i1 & IREE_STATUS_CODE_MASK);
-
   iree_hal_semaphore_fail(semaphore, iree_make_status(status_code));
   return iree_ok_status();
 }
 
+// PC for iree_hal_module_semaphore_await.
+enum iree_hal_module_semaphore_await_pc_e {
+  // Initial entry point that will try to either wait inline or yield to the
+  // scheduler with a wait-all operation.
+  IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_BEGIN = 0,
+  // Resume entry point after the scheduler wait has resolved (successfully or
+  // otherwise).
+  IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_RESUME,
+};
+
 IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_await,  //
                    iree_hal_module_state_t,          //
                    rI, i) {
-  iree_hal_semaphore_t* semaphore = NULL;
-  IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
-  uint64_t new_value = (uint64_t)args->i1;
+  // On entry we either perform the wait or begin a coroutine yield operation.
+  // After resuming we check to see if the timepoint has been reached and
+  // propagate the result.
+  iree_vm_stack_frame_t* current_frame = iree_vm_stack_top(stack);
+  iree_zone_id_t zone_id = 0;
+  iree_status_t wait_status = iree_ok_status();
+  if (current_frame->pc == IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_BEGIN) {
+    iree_hal_semaphore_t* semaphore = NULL;
+    IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore));
+    uint64_t new_value = (uint64_t)args->i1;
 
-  // TODO(benvanik): coroutine magic.
-  iree_status_t status =
-      iree_hal_semaphore_wait(semaphore, new_value, iree_infinite_timeout());
+    IREE_TRACE_ZONE_BEGIN(z0);
+    zone_id = z0;
 
-  if (iree_status_is_ok(status)) {
-    // Successful wait.
-    rets->i0 = 0;
-    return iree_ok_status();
-  } else if (iree_status_is_deadline_exceeded(status)) {
-    // Propagate deadline exceeded back to the VM.
-    rets->i0 = (int32_t)iree_status_consume_code(status);
-    iree_status_ignore(status);
-    return iree_ok_status();
+    // TODO(benvanik): take timeout as an argument.
+    // Capture absolute timeout so that regardless of how long it takes us to
+    // wait the user-perceived wait time remains the same.
+    iree_timeout_t timeout = iree_infinite_timeout();
+    iree_convert_timeout_to_absolute(&timeout);
+
+    if (iree_all_bits_set(state->flags, IREE_HAL_MODULE_FLAG_SYNCHRONOUS)) {
+      // Block the native thread until the semaphore payload is reached or the
+      // deadline is exceeded.
+      wait_status = iree_hal_semaphore_wait(semaphore, new_value, timeout);
+    } else {
+      // Quick check inline before yielding to the scheduler. This avoids a
+      // round-trip through the scheduling stack for cases where we complete
+      // synchronously.
+      //
+      // The query may fail to indicate that the semaphore is in a failure
+      // state and we propagate the failure status to the waiter.
+      //
+      // It's possible to race here if we get back an older value and then
+      // before we wait the target is reached but that's ok: the wait will
+      // always be correctly ordered.
+      uint64_t current_value = 0ull;
+      wait_status = iree_hal_semaphore_query(semaphore, &current_value);
+      if (iree_status_is_ok(wait_status) && current_value < new_value) {
+        // Enter a wait frame and yield execution back to the scheduler.
+        // When the wait handle resolves we'll resume at the RESUME PC.
+        iree_vm_wait_frame_t* wait_frame = NULL;
+        IREE_RETURN_AND_END_ZONE_IF_ERROR(
+            zone_id, iree_vm_stack_wait_enter(stack, IREE_VM_WAIT_ALL, 1,
+                                              timeout, zone_id, &wait_frame));
+        wait_frame->wait_sources[0] =
+            iree_hal_semaphore_await(semaphore, new_value);
+        current_frame->pc = IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_RESUME;
+        wait_status = iree_status_from_code(IREE_STATUS_DEFERRED);
+        zone_id = 0;  // ownership transferred to wait frame
+      }
+    }
+  } else {
+    // Resume by leaving the wait frame and storing the result.
+    iree_vm_wait_result_t wait_result;
+    IREE_RETURN_IF_ERROR(iree_vm_stack_wait_leave(stack, &wait_result));
+    wait_status = wait_result.status;
+    IREE_TRACE(zone_id = wait_result.trace_zone);
   }
 
-  // Fail the invocation.
+  iree_status_t status = iree_ok_status();
+  if (iree_status_is_ok(wait_status)) {
+    // Successful wait.
+    rets->i0 = 0;
+  } else if (iree_status_is_deferred(wait_status)) {
+    // Yielding; resume required.
+    // NOTE: zone not ended as it's reserved on the stack.
+    status = wait_status;
+  } else if (iree_status_is_deadline_exceeded(wait_status)) {
+    // Propagate deadline exceeded back to the VM.
+    rets->i0 = (int32_t)iree_status_consume_code(wait_status);
+    iree_status_ignore(wait_status);
+  } else {
+    // Fail the invocation.
+    status = wait_status;
+  }
+
+  IREE_TRACE({
+    if (zone_id) IREE_TRACE_ZONE_END(zone_id);
+  });
   return status;
 }