Implementing coroutine semaphore/fence awaits. Progress on #8093.
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index eada4af..939a38a 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c
@@ -1375,43 +1375,202 @@ return iree_ok_status(); } +// Removes entries in |fences| if they have been reached. +// Returns failure if one or more fences have failed. +static iree_status_t iree_hal_module_fence_elide_reached( + iree_host_size_t* fence_count, iree_hal_fence_t** fences) { + iree_host_size_t new_count = *fence_count; + for (iree_host_size_t i = 0; i < new_count;) { + iree_status_t status = iree_hal_fence_query(fences[i]); + if (iree_status_is_ok(status)) { + // Has been reached; shift the list down. + memmove(&fences[i], &fences[i + 1], + (new_count - i - 1) * sizeof(iree_hal_fence_t*)); + fences[new_count - 1] = NULL; + --new_count; + } else if (iree_status_is_deferred(status)) { + // Still waiting. + iree_status_ignore(status); + ++i; // next + } else { + // Failed; propagate failure. + *fence_count = new_count; + return status; + } + } + *fence_count = new_count; + return iree_ok_status(); +} + +// Enters a wait frame for all timepoints in all |fences|. +// Returns an |out_wait_status| of OK if all fences have been reached or +// IREE_STATUS_DEFERRED if one or more fences are still pending and a wait +// frame was entered. +static iree_status_t iree_hal_module_fence_await_begin( + iree_vm_stack_t* stack, iree_host_size_t fence_count, + iree_hal_fence_t** fences, iree_timeout_t timeout, iree_zone_id_t zone_id, + iree_status_t* out_wait_status) { + // To avoid additional allocations when waiting on multiple fences we enter + // the wait frame with the maximum required wait source capacity and perform + // a simple deduplication when building the list. Ideally this helps get us on + // fast paths of single semaphore waits. The common case is a single fence in + // which case this is all exceptional. + iree_host_size_t total_timepoint_capacity = 0; + for (iree_host_size_t i = 0; i < fence_count; ++i) { + total_timepoint_capacity += iree_hal_fence_timepoint_count(fences[i]); + } + + // Fast-path for no semaphores (empty/immediate fences). 
+ if (total_timepoint_capacity == 0) { + *out_wait_status = iree_ok_status(); + IREE_TRACE_ZONE_END(zone_id); + return iree_ok_status(); + } + + // Reserve storage as if all timepoints from all fences were unique. + iree_vm_wait_frame_t* wait_frame = NULL; + IREE_RETURN_IF_ERROR(iree_vm_stack_wait_enter(stack, IREE_VM_WAIT_ALL, + total_timepoint_capacity, + timeout, zone_id, &wait_frame)); + + // Insert the first set of timepoints - they're already deduplicated. + iree_host_size_t unique_timepoint_count = 0; + if (fence_count >= 1) { + iree_hal_semaphore_list_t semaphore_list = + iree_hal_fence_semaphore_list(fences[0]); + for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) { + iree_wait_source_t wait_source = iree_hal_semaphore_await( + semaphore_list.semaphores[i], semaphore_list.payload_values[i]); + wait_frame->wait_sources[unique_timepoint_count++] = wait_source; + } + } + + // TODO(benvanik): simplify this; it may not be worth the complexity. We'll + // need more real workloads using multi-fence joins to see how useful this is. + + // Insert remaining fence timepoints by performing merging as we go. + for (iree_host_size_t i = 1; i < fence_count; ++i) { + iree_hal_semaphore_list_t semaphore_list = + iree_hal_fence_semaphore_list(fences[i]); + for (iree_host_size_t j = 0; j < semaphore_list.count; ++j) { + // O(n^2) set insertion - relying on this being rare and the total count + // being low. The savings of a small linear scan here relative to an + // additional syscall are always worth it but we may want to go further. + iree_wait_source_t wait_source = iree_hal_semaphore_await( + semaphore_list.semaphores[j], semaphore_list.payload_values[j]); + bool found_existing = false; + for (iree_host_size_t k = 0; k < unique_timepoint_count; ++k) { + if (wait_frame->wait_sources[k].ctl == wait_source.ctl && + wait_frame->wait_sources[k].self == wait_source.self) { + // Found existing; use max of both. 
+ wait_frame->wait_sources[k].data = + iree_max(wait_frame->wait_sources[k].data, wait_source.data); + found_existing = true; + break; + } + } + if (!found_existing) { + wait_frame->wait_sources[unique_timepoint_count++] = wait_source; + } + } + } + + // Update frame with the actual number of timepoints in the wait operation. + wait_frame->count = unique_timepoint_count; + + *out_wait_status = iree_status_from_code(IREE_STATUS_DEFERRED); + return iree_ok_status(); +} + +// PC for iree_hal_module_fence_await. +enum iree_hal_module_fence_await_pc_e { + // Initial entry point that will try to either wait inline or yield to the + // scheduler with a wait-all operation. + IREE_HAL_MODULE_FENCE_AWAIT_PC_BEGIN = 0, + // Resume entry point after the scheduler wait has resolved (successfully or + // otherwise). + IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME, +}; + IREE_VM_ABI_EXPORT(iree_hal_module_fence_await, // iree_hal_module_state_t, // iCrD, i) { - uint32_t timeout_millis = (uint32_t)args->i0; - iree_host_size_t fence_count = 0; - iree_hal_fence_t** fences = NULL; - IREE_VM_ABI_VLA_STACK_DEREF(args, a1_count, a1, iree_hal_fence, 32, - &fence_count, &fences); + // On entry we either perform the wait or begin a coroutine yield operation. + // After resuming we check to see if the fence has been reached and propagate + // the result. + iree_vm_stack_frame_t* current_frame = iree_vm_stack_top(stack); + iree_zone_id_t zone_id = 0; + iree_status_t wait_status = iree_ok_status(); + if (current_frame->pc == IREE_HAL_MODULE_FENCE_AWAIT_PC_BEGIN) { + uint32_t timeout_millis = (uint32_t)args->i0; + iree_host_size_t fence_count = 0; + iree_hal_fence_t** fences = NULL; + IREE_VM_ABI_VLA_STACK_DEREF(args, a1_count, a1, iree_hal_fence, 32, + &fence_count, &fences); - // Capture absolute timeout so that regardless of how long it takes us to wait - // the user-perceived wait time remains the same. 
- iree_timeout_t timeout = iree_make_timeout_ms(timeout_millis); - iree_convert_timeout_to_absolute(&timeout); + IREE_TRACE_ZONE_BEGIN(z0); + zone_id = z0; - // Wait on each fence in-turn. - // TODO(benvanik): use a stack wait frame and expand all fences into their - // individual timepoint wait sources. This will allow the loop to perform a - // multi-wait without needing to materialize intermediate wait primitives - // which may not be possible across devices. - iree_status_t status = iree_ok_status(); - for (iree_host_size_t i = 0; i < fence_count; ++i) { - status = iree_hal_fence_wait(fences[i], timeout); - if (!iree_status_is_ok(status)) break; + // Capture absolute timeout so that regardless of how long it takes us to + // wait the user-perceived wait time remains the same. + iree_timeout_t timeout = timeout_millis == UINT32_MAX + ? iree_infinite_timeout() + : iree_make_timeout_ms(timeout_millis); + iree_convert_timeout_to_absolute(&timeout); + + // Remove any fences that have been reached and check for failure. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + zone_id, iree_hal_module_fence_elide_reached(&fence_count, fences)); + + // If all fences have been reached we can exit early as if we waited + // successfully. + if (fence_count > 0) { + if (iree_all_bits_set(state->flags, IREE_HAL_MODULE_FLAG_SYNCHRONOUS)) { + // Block the native thread until the fence is reached or the deadline is + // exceeded. + for (iree_host_size_t i = 0; i < fence_count; ++i) { + wait_status = iree_hal_fence_wait(fences[i], timeout); + if (!iree_status_is_ok(wait_status)) break; + } + } else { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + zone_id, + iree_hal_module_fence_await_begin(stack, fence_count, fences, + timeout, zone_id, &wait_status)); + current_frame->pc = IREE_HAL_MODULE_FENCE_AWAIT_PC_RESUME; + if (iree_status_is_deferred(wait_status)) { + zone_id = 0; // ownership transferred to wait frame + } + } + } + } else { + // Resume by leaving the wait frame and storing the result. 
+ iree_vm_wait_result_t wait_result; + IREE_RETURN_IF_ERROR(iree_vm_stack_wait_leave(stack, &wait_result)); + wait_status = wait_result.status; + IREE_TRACE(zone_id = wait_result.trace_zone); } - if (iree_status_is_ok(status)) { + iree_status_t status = iree_ok_status(); + if (iree_status_is_ok(wait_status)) { // Successful wait. rets->i0 = 0; - return iree_ok_status(); - } else if (iree_status_is_deadline_exceeded(status)) { + } else if (iree_status_is_deferred(wait_status)) { + // Yielding; resume required. + // NOTE: zone not ended as it's reserved on the stack. + status = wait_status; + } else if (iree_status_is_deadline_exceeded(wait_status)) { // Propagate deadline exceeded back to the VM. - rets->i0 = (int32_t)iree_status_consume_code(status); - iree_status_ignore(status); - return iree_ok_status(); + rets->i0 = (int32_t)iree_status_consume_code(wait_status); + iree_status_ignore(wait_status); + } else { + // Fail the invocation. + status = wait_status; } - // Fail the invocation. 
+ IREE_TRACE({ + if (zone_id) IREE_TRACE_ZONE_END(zone_id); + }); return status; } @@ -1425,7 +1584,6 @@ iree_hal_device_t* device = NULL; IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); uint64_t initial_value = (uint64_t)args->i1; - iree_hal_semaphore_t* semaphore = NULL; IREE_RETURN_IF_ERROR( iree_hal_semaphore_create(device, initial_value, &semaphore)); @@ -1438,11 +1596,12 @@ r, iI) { iree_hal_semaphore_t* semaphore = NULL; IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore)); - - uint64_t value = 0; - iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value); + uint64_t current_value = 0; + iree_status_t query_status = + iree_hal_semaphore_query(semaphore, &current_value); rets->i0 = iree_status_consume_code(query_status); - rets->i1 = value; + rets->i1 = current_value; + iree_status_ignore(query_status); return iree_ok_status(); } @@ -1452,7 +1611,6 @@ iree_hal_semaphore_t* semaphore = NULL; IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore)); uint64_t new_value = (uint64_t)args->i1; - return iree_hal_semaphore_signal(semaphore, new_value); } @@ -1463,34 +1621,102 @@ IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore)); iree_status_code_t status_code = (iree_status_code_t)(args->i1 & IREE_STATUS_CODE_MASK); - iree_hal_semaphore_fail(semaphore, iree_make_status(status_code)); return iree_ok_status(); } +// PC for iree_hal_module_semaphore_await. +enum iree_hal_module_semaphore_await_pc_e { + // Initial entry point that will try to either wait inline or yield to the + // scheduler with a wait-all operation. + IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_BEGIN = 0, + // Resume entry point after the scheduler wait has resolved (successfully or + // otherwise). 
+ IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_RESUME, +}; + IREE_VM_ABI_EXPORT(iree_hal_module_semaphore_await, // iree_hal_module_state_t, // rI, i) { - iree_hal_semaphore_t* semaphore = NULL; - IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore)); - uint64_t new_value = (uint64_t)args->i1; + // On entry we either perform the wait or begin a coroutine yield operation. + // After resuming we check to see if the timepoint has been reached and + // propagate the result. + iree_vm_stack_frame_t* current_frame = iree_vm_stack_top(stack); + iree_zone_id_t zone_id = 0; + iree_status_t wait_status = iree_ok_status(); + if (current_frame->pc == IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_BEGIN) { + iree_hal_semaphore_t* semaphore = NULL; + IREE_RETURN_IF_ERROR(iree_hal_semaphore_check_deref(args->r0, &semaphore)); + uint64_t new_value = (uint64_t)args->i1; - // TODO(benvanik): coroutine magic. - iree_status_t status = - iree_hal_semaphore_wait(semaphore, new_value, iree_infinite_timeout()); + IREE_TRACE_ZONE_BEGIN(z0); + zone_id = z0; - if (iree_status_is_ok(status)) { - // Successful wait. - rets->i0 = 0; - return iree_ok_status(); - } else if (iree_status_is_deadline_exceeded(status)) { - // Propagate deadline exceeded back to the VM. - rets->i0 = (int32_t)iree_status_consume_code(status); - iree_status_ignore(status); - return iree_ok_status(); + // TODO(benvanik): take timeout as an argument. + // Capture absolute timeout so that regardless of how long it takes us to + // wait the user-perceived wait time remains the same. + iree_timeout_t timeout = iree_infinite_timeout(); + iree_convert_timeout_to_absolute(&timeout); + + if (iree_all_bits_set(state->flags, IREE_HAL_MODULE_FLAG_SYNCHRONOUS)) { + // Block the native thread until the fence is reached or the deadline is + // exceeded. + wait_status = iree_hal_semaphore_wait(semaphore, new_value, timeout); + } else { + // Quick check inline before yielding to the scheduler. 
This avoids a + // round-trip through the scheduling stack for cases where we complete + // synchronously. + // + // The query may fail to indicate that the semaphore is in a failure + // state and we propagate the failure status to the waiter. + // + // It's possible to race here if we get back an older value and then + // before we wait the target is reached but that's ok: the wait will + // always be correctly ordered. + uint64_t current_value = 0ull; + wait_status = iree_hal_semaphore_query(semaphore, &current_value); + if (iree_status_is_ok(wait_status) && current_value < new_value) { + // Enter a wait frame and yield execution back to the scheduler. + // When the wait handle resolves we'll resume at the RESUME PC. + iree_vm_wait_frame_t* wait_frame = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + zone_id, iree_vm_stack_wait_enter(stack, IREE_VM_WAIT_ALL, 1, + timeout, zone_id, &wait_frame)); + wait_frame->wait_sources[0] = + iree_hal_semaphore_await(semaphore, new_value); + current_frame->pc = IREE_HAL_MODULE_SEMAPHORE_AWAIT_PC_RESUME; + wait_status = iree_status_from_code(IREE_STATUS_DEFERRED); + zone_id = 0; // ownership transferred to wait frame + } + } + } else { + // Resume by leaving the wait frame and storing the result. + iree_vm_wait_result_t wait_result; + IREE_RETURN_IF_ERROR(iree_vm_stack_wait_leave(stack, &wait_result)); + wait_status = wait_result.status; + IREE_TRACE(zone_id = wait_result.trace_zone); } - // Fail the invocation. + iree_status_t status = iree_ok_status(); + if (iree_status_is_ok(wait_status)) { + // Successful wait. + rets->i0 = 0; + } else if (iree_status_is_deferred(wait_status)) { + // Yielding; resume required. + // NOTE: zone not ended as it's reserved on the stack. + status = wait_status; + } else if (iree_status_is_deadline_exceeded(wait_status)) { + // Propagate deadline exceeded back to the VM. + rets->i0 = (int32_t)iree_status_consume_code(wait_status); + iree_status_ignore(wait_status); + } else { + // Fail the invocation. 
+ status = wait_status; + } + + IREE_TRACE({ + if (zone_id) IREE_TRACE_ZONE_END(zone_id); + }); return status; }