| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/hal/local/task_semaphore.h" |
| |
| #include <inttypes.h> |
| |
| #include "iree/base/synchronization.h" |
| #include "iree/base/tracing.h" |
| #include "iree/base/wait_handle.h" |
| |
// Sentinel value used to indicate the semaphore has failed and an error
// status is set.
| #define IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE UINT64_MAX |
| |
| //===----------------------------------------------------------------------===// |
| // iree_hal_task_timepoint_t |
| //===----------------------------------------------------------------------===// |
| |
// Represents a point in the timeline that someone is waiting on to be reached.
| // When the semaphore is signaled to at least the specified value then the |
| // given event will be signaled and the timepoint discarded. |
| // |
| // Instances are owned and retained by the caller that requested them - usually |
| // in the arena associated with the submission, but could be on the stack of a |
| // synchronously waiting thread. |
| typedef struct iree_hal_task_timepoint_s { |
| struct iree_hal_task_timepoint_s* next; |
| struct iree_hal_task_timepoint_s* prev; |
| uint64_t payload_value; |
| iree_event_t event; |
| } iree_hal_task_timepoint_t; |
| |
| // A doubly-linked FIFO list of timepoints. |
| // The order of the timepoints does *not* match increasing payload values but |
| // instead the order they were added to the list. |
| // |
| // Note that the timepoints are not owned by the list - this just nicely |
| // stitches together timepoints for the semaphore. |
| typedef struct { |
| iree_hal_task_timepoint_t* head; |
| iree_hal_task_timepoint_t* tail; |
| } iree_hal_task_timepoint_list_t; |
| |
| static void iree_hal_task_timepoint_list_initialize( |
| iree_hal_task_timepoint_list_t* out_list) { |
| memset(out_list, 0, sizeof(*out_list)); |
| } |
| |
| // Moves |source_list| into |out_target_list|. |
| // |source_list| will be reset and the prior contents of |out_target_list| will |
| // be discarded. |
| static void iree_hal_task_timepoint_list_move( |
| iree_hal_task_timepoint_list_t* source_list, |
| iree_hal_task_timepoint_list_t* out_target_list) { |
| memcpy(out_target_list, source_list, sizeof(*out_target_list)); |
| memset(source_list, 0, sizeof(*source_list)); |
| } |
| |
| // Appends a timepoint to the end of the timepoint list. |
| static void iree_hal_task_timepoint_list_append( |
| iree_hal_task_timepoint_list_t* list, |
| iree_hal_task_timepoint_t* timepoint) { |
| timepoint->next = NULL; |
| timepoint->prev = list->tail; |
| if (list->tail != NULL) { |
| list->tail->next = timepoint; |
| list->tail = timepoint; |
| } else { |
| list->head = timepoint; |
| list->tail = timepoint; |
| } |
| } |
| |
| // Erases a timepoint from the list. |
| static void iree_hal_task_timepoint_list_erase( |
| iree_hal_task_timepoint_list_t* list, |
| iree_hal_task_timepoint_t* timepoint) { |
  if (timepoint->prev != NULL) timepoint->prev->next = timepoint->next;
  if (timepoint->next != NULL) timepoint->next->prev = timepoint->prev;
| if (timepoint == list->head) list->head = timepoint->next; |
| if (timepoint == list->tail) list->tail = timepoint->prev; |
| timepoint->prev = NULL; |
| timepoint->next = NULL; |
| } |
| |
| // Scans the |pending_list| for all timepoints that are satisfied by the |
| // timeline having reached |payload_value|. Each satisfied timepoint will be |
| // moved to |out_ready_list|. |
| static void iree_hal_task_timepoint_list_take_ready( |
| iree_hal_task_timepoint_list_t* pending_list, uint64_t payload_value, |
| iree_hal_task_timepoint_list_t* out_ready_list) { |
| iree_hal_task_timepoint_list_initialize(out_ready_list); |
| iree_hal_task_timepoint_t* next = pending_list->head; |
| while (next != NULL) { |
| iree_hal_task_timepoint_t* timepoint = next; |
| next = timepoint->next; |
| bool is_satisfied = timepoint->payload_value <= payload_value; |
| if (!is_satisfied) continue; |
| |
| // Remove from pending list. |
| iree_hal_task_timepoint_list_erase(pending_list, timepoint); |
| |
| // Add to ready list. |
| iree_hal_task_timepoint_list_append(out_ready_list, timepoint); |
| } |
| } |
| |
| // Notifies all of the timepoints in the |ready_list| that their condition has |
| // been satisfied. |ready_list| will be reset as ownership of the events is |
| // held by the originator. |
| static void iree_hal_task_timepoint_list_notify_ready( |
| iree_hal_task_timepoint_list_t* ready_list) { |
| iree_hal_task_timepoint_t* next = ready_list->head; |
| while (next != NULL) { |
| iree_hal_task_timepoint_t* timepoint = next; |
| next = timepoint->next; |
| timepoint->next = NULL; |
| timepoint->prev = NULL; |
| iree_event_set(&timepoint->event); |
| } |
| iree_hal_task_timepoint_list_initialize(ready_list); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // iree_hal_task_semaphore_t |
| //===----------------------------------------------------------------------===// |
| |
| typedef struct { |
| iree_hal_resource_t resource; |
| iree_allocator_t host_allocator; |
| iree_hal_local_event_pool_t* event_pool; |
| |
| // Guards all mutable fields. We expect low contention on semaphores and since |
| // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler |
| // than trying to make the entire structure lock-free. |
| iree_slim_mutex_t mutex; |
| |
| // Current signaled value. May be IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE to |
| // indicate that the semaphore has been signaled for failure and |
| // |failure_status| contains the error. |
| uint64_t current_value; |
| |
| // OK or the status passed to iree_hal_semaphore_fail. Owned by the semaphore. |
| iree_status_t failure_status; |
| |
| // In-process notification signaled when the semaphore value changes. This is |
| // used exclusively for wait-ones to avoid going to the kernel for a full wait |
| // handle operation. |
| iree_notification_t notification; |
| |
| // A list of all reserved timepoints waiting for the semaphore to reach a |
| // certain payload value. |
| iree_hal_task_timepoint_list_t timepoint_list; |
| } iree_hal_task_semaphore_t; |
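
// Example (illustrative sketch of the timeline semantics, going through the
// public iree_hal_semaphore_* wrappers that dispatch into this implementation
// via the vtable below):
//   uint64_t value = 0;
//   IREE_IGNORE_ERROR(iree_hal_semaphore_query(semaphore, &value));  // -> 0
//   IREE_IGNORE_ERROR(iree_hal_semaphore_signal(semaphore, 2));
//   // Any waiter blocked on a payload value <= 2 is now released:
//   IREE_IGNORE_ERROR(iree_hal_semaphore_wait_with_deadline(
//       semaphore, /*value=*/2ull, IREE_TIME_INFINITE_FUTURE));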
| |
| static const iree_hal_semaphore_vtable_t iree_hal_task_semaphore_vtable; |
| |
| static iree_hal_task_semaphore_t* iree_hal_task_semaphore_cast( |
| iree_hal_semaphore_t* base_value) { |
| IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_task_semaphore_vtable); |
| return (iree_hal_task_semaphore_t*)base_value; |
| } |
| |
| iree_status_t iree_hal_task_semaphore_create( |
| iree_hal_local_event_pool_t* event_pool, uint64_t initial_value, |
| iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { |
| IREE_ASSERT_ARGUMENT(event_pool); |
| IREE_ASSERT_ARGUMENT(out_semaphore); |
| *out_semaphore = NULL; |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_hal_task_semaphore_t* semaphore = NULL; |
| iree_status_t status = iree_allocator_malloc( |
| host_allocator, sizeof(*semaphore), (void**)&semaphore); |
| if (iree_status_is_ok(status)) { |
| iree_hal_resource_initialize(&iree_hal_task_semaphore_vtable, |
| &semaphore->resource); |
| semaphore->host_allocator = host_allocator; |
| semaphore->event_pool = event_pool; |
| |
| iree_slim_mutex_initialize(&semaphore->mutex); |
| semaphore->current_value = initial_value; |
| semaphore->failure_status = iree_ok_status(); |
| iree_notification_initialize(&semaphore->notification); |
| iree_hal_task_timepoint_list_initialize(&semaphore->timepoint_list); |
| |
| *out_semaphore = (iree_hal_semaphore_t*)semaphore; |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static void iree_hal_task_semaphore_destroy( |
| iree_hal_semaphore_t* base_semaphore) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(base_semaphore); |
| iree_allocator_t host_allocator = semaphore->host_allocator; |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_status_free(semaphore->failure_status); |
  iree_notification_deinitialize(&semaphore->notification);
  iree_slim_mutex_deinitialize(&semaphore->mutex);
| iree_allocator_free(host_allocator, semaphore); |
| |
| IREE_TRACE_ZONE_END(z0); |
| } |
| |
| static iree_status_t iree_hal_task_semaphore_query( |
| iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(base_semaphore); |
| |
| iree_slim_mutex_lock(&semaphore->mutex); |
| |
| *out_value = semaphore->current_value; |
| |
| iree_status_t status = iree_ok_status(); |
| if (*out_value >= IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE) { |
| status = iree_status_clone(semaphore->failure_status); |
| } |
| |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| |
| return status; |
| } |
| |
| static iree_status_t iree_hal_task_semaphore_signal( |
| iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(base_semaphore); |
| |
| iree_slim_mutex_lock(&semaphore->mutex); |
| |
| if (new_value <= semaphore->current_value) { |
| uint64_t current_value = semaphore->current_value; |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| return iree_make_status(IREE_STATUS_OUT_OF_RANGE, |
| "semaphore values must be monotonically " |
| "increasing; current_value=%" PRIu64 |
| ", new_value=%" PRIu64, |
| current_value, new_value); |
| } |
| |
| semaphore->current_value = new_value; |
| |
| // Scan for all timepoints that are now satisfied and move them to our local |
| // ready list. This way we can notify them without needing to continue holding |
| // the semaphore lock. |
| iree_hal_task_timepoint_list_t ready_list; |
| iree_hal_task_timepoint_list_take_ready(&semaphore->timepoint_list, new_value, |
| &ready_list); |
| |
| iree_notification_post(&semaphore->notification, IREE_ALL_WAITERS); |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| |
| // Notify all waiters - note that this must happen outside the lock. |
| iree_hal_task_timepoint_list_notify_ready(&ready_list); |
| |
| return iree_ok_status(); |
| } |
| |
| static void iree_hal_task_semaphore_fail(iree_hal_semaphore_t* base_semaphore, |
| iree_status_t status) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(base_semaphore); |
| |
| iree_slim_mutex_lock(&semaphore->mutex); |
| |
| // Try to set our local status - we only preserve the first failure so only |
| // do this if we are going from a valid semaphore to a failed one. |
| if (!iree_status_is_ok(semaphore->failure_status)) { |
| // Previous status was not OK; drop our new status. |
| IREE_IGNORE_ERROR(status); |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| return; |
| } |
| |
| // Signal to our failure sentinel value. |
| semaphore->current_value = IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE; |
| semaphore->failure_status = status; |
| |
| // Take the whole timepoint list as we'll be signaling all of them. Since |
| // we hold the lock no other timepoints can be created while we are cleaning |
| // up. |
| iree_hal_task_timepoint_list_t ready_list; |
| iree_hal_task_timepoint_list_move(&semaphore->timepoint_list, &ready_list); |
| |
| iree_notification_post(&semaphore->notification, IREE_ALL_WAITERS); |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| |
| // Notify all waiters - note that this must happen outside the lock. |
| iree_hal_task_timepoint_list_notify_ready(&ready_list); |
| } |
| |
| // Acquires a timepoint waiting for the given value. |
| // |out_timepoint| is owned by the caller and must be kept live until the |
| // timepoint has been reached (or it is cancelled by the caller). |
| static iree_status_t iree_hal_task_semaphore_acquire_timepoint( |
| iree_hal_task_semaphore_t* semaphore, uint64_t minimum_value, |
| iree_hal_task_timepoint_t* out_timepoint) { |
| memset(out_timepoint, 0, sizeof(*out_timepoint)); |
| out_timepoint->payload_value = minimum_value; |
| IREE_RETURN_IF_ERROR(iree_hal_local_event_pool_acquire( |
| semaphore->event_pool, 1, &out_timepoint->event)); |
| iree_hal_task_timepoint_list_append(&semaphore->timepoint_list, |
| out_timepoint); |
| return iree_ok_status(); |
| } |
| |
| typedef struct { |
| iree_task_wait_t task; |
| iree_hal_task_semaphore_t* semaphore; |
| iree_hal_task_timepoint_t timepoint; |
| } iree_hal_task_semaphore_wait_cmd_t; |
| |
| // Cleans up a wait task by returning the event used to the pool and - if the |
| // task failed - ensuring we scrub it from the timepoint list. |
| static void iree_hal_task_semaphore_wait_cmd_cleanup(iree_task_t* task, |
| iree_status_t status) { |
| iree_hal_task_semaphore_wait_cmd_t* cmd = |
| (iree_hal_task_semaphore_wait_cmd_t*)task; |
| iree_hal_local_event_pool_release(cmd->semaphore->event_pool, 1, |
| &cmd->timepoint.event); |
| if (IREE_UNLIKELY(!iree_status_is_ok(status))) { |
| // Abort the timepoint. Note that this is not designed to be fast as |
| // semaphore failure is an exceptional case. |
| iree_slim_mutex_lock(&cmd->semaphore->mutex); |
| iree_hal_task_timepoint_list_erase(&cmd->semaphore->timepoint_list, |
| &cmd->timepoint); |
| iree_slim_mutex_unlock(&cmd->semaphore->mutex); |
| } |
| } |
| |
| iree_status_t iree_hal_task_semaphore_enqueue_timepoint( |
| iree_hal_semaphore_t* base_semaphore, uint64_t minimum_value, |
| iree_task_t* issue_task, iree_arena_allocator_t* arena, |
| iree_task_submission_t* submission) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(base_semaphore); |
| |
| iree_slim_mutex_lock(&semaphore->mutex); |
| |
| iree_status_t status = iree_ok_status(); |
| if (semaphore->current_value >= minimum_value) { |
| // Fast path: already satisfied. |
| } else { |
| // Slow path: acquire a system wait handle and perform a full wait. |
| iree_hal_task_semaphore_wait_cmd_t* cmd = NULL; |
| status = iree_arena_allocate(arena, sizeof(*cmd), (void**)&cmd); |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_task_semaphore_acquire_timepoint( |
| semaphore, minimum_value, &cmd->timepoint); |
| } |
| if (iree_status_is_ok(status)) { |
| iree_task_wait_initialize(issue_task->scope, cmd->timepoint.event, |
| &cmd->task); |
| iree_task_set_cleanup_fn(&cmd->task.header, |
| iree_hal_task_semaphore_wait_cmd_cleanup); |
| iree_task_set_completion_task(&cmd->task.header, issue_task); |
| cmd->semaphore = semaphore; |
| iree_task_submission_enqueue(submission, &cmd->task.header); |
| } |
| } |
| |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| return status; |
| } |
| |
| static iree_status_t iree_hal_task_semaphore_wait_with_deadline( |
| iree_hal_semaphore_t* base_semaphore, uint64_t value, |
| iree_time_t deadline_ns) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(base_semaphore); |
| |
| iree_slim_mutex_lock(&semaphore->mutex); |
| |
| if (semaphore->current_value >= value) { |
| // Fast path: already satisfied. |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| return iree_ok_status(); |
| } else if (deadline_ns == IREE_TIME_INFINITE_PAST) { |
| // Not satisfied but a poll, so can avoid the expensive wait handle work. |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); |
| } |
| |
| // Slow path: acquire a timepoint while we hold the lock. |
| iree_hal_task_timepoint_t timepoint; |
| iree_status_t status = |
| iree_hal_task_semaphore_acquire_timepoint(semaphore, value, &timepoint); |
| |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| if (IREE_UNLIKELY(!iree_status_is_ok(status))) return status; |
| |
  // Wait until the timepoint resolves.
  // If satisfied the timepoint is automatically cleaned up and we are done.
  // If the deadline is reached before the timepoint is satisfied then we have
  // to clean it up ourselves.
| status = iree_wait_one(&timepoint.event, deadline_ns); |
| if (!iree_status_is_ok(status)) { |
| iree_slim_mutex_lock(&semaphore->mutex); |
| iree_hal_task_timepoint_list_erase(&semaphore->timepoint_list, &timepoint); |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| } |
| iree_hal_local_event_pool_release(semaphore->event_pool, 1, &timepoint.event); |
| return status; |
| } |
| |
| static iree_status_t iree_hal_task_semaphore_wait_with_timeout( |
| iree_hal_semaphore_t* base_semaphore, uint64_t value, |
| iree_duration_t timeout_ns) { |
| return iree_hal_task_semaphore_wait_with_deadline( |
| base_semaphore, value, iree_relative_timeout_to_deadline_ns(timeout_ns)); |
| } |
| |
| iree_status_t iree_hal_task_semaphore_multi_wait( |
| iree_hal_wait_mode_t wait_mode, |
| const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, |
| iree_hal_local_event_pool_t* event_pool, |
| iree_arena_block_pool_t* block_pool) { |
| IREE_ASSERT_ARGUMENT(semaphore_list); |
| if (semaphore_list->count == 0) { |
| return iree_ok_status(); |
| } else if (semaphore_list->count == 1) { |
| // Fast-path for a single semaphore. |
| return iree_hal_semaphore_wait_with_deadline( |
| semaphore_list->semaphores[0], semaphore_list->payload_values[0], |
| deadline_ns); |
| } |
| |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| // Avoid heap allocations by using the device block pool for the wait set. |
| iree_arena_allocator_t arena; |
| iree_arena_initialize(block_pool, &arena); |
| iree_wait_set_t* wait_set = NULL; |
| iree_status_t status = iree_wait_set_allocate( |
| semaphore_list->count, iree_arena_allocator(&arena), &wait_set); |
| |
| // Acquire a wait handle for each semaphore timepoint we are to wait on. |
| // TODO(benvanik): flip this API around so we can batch request events from |
// the event pool. We should be acquiring all required timepoints in one
// call.
| iree_host_size_t timepoint_count = 0; |
| iree_hal_task_timepoint_t* timepoints = NULL; |
| iree_host_size_t total_timepoint_size = |
| semaphore_list->count * sizeof(timepoints[0]); |
  if (iree_status_is_ok(status)) {
    status = iree_arena_allocate(&arena, total_timepoint_size,
                                 (void**)&timepoints);
  }
| if (iree_status_is_ok(status)) { |
| memset(timepoints, 0, total_timepoint_size); |
| for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) { |
| iree_hal_task_semaphore_t* semaphore = |
| iree_hal_task_semaphore_cast(semaphore_list->semaphores[i]); |
| iree_slim_mutex_lock(&semaphore->mutex); |
| if (semaphore->current_value >= semaphore_list->payload_values[i]) { |
| // Fast path: already satisfied. |
| } else { |
| // Slow path: get a native wait handle for the timepoint. |
| iree_hal_task_timepoint_t* timepoint = &timepoints[timepoint_count++]; |
| status = iree_hal_task_semaphore_acquire_timepoint( |
| semaphore, semaphore_list->payload_values[i], timepoint); |
| if (iree_status_is_ok(status)) { |
| status = iree_wait_set_insert(wait_set, timepoint->event); |
| } |
| } |
| iree_slim_mutex_unlock(&semaphore->mutex); |
| if (!iree_status_is_ok(status)) break; |
| } |
| } |
| |
| // Perform the wait. |
| if (iree_status_is_ok(status)) { |
| if (wait_mode == IREE_HAL_WAIT_MODE_ANY) { |
| status = iree_wait_any(wait_set, deadline_ns, /*out_wake_handle=*/NULL); |
| } else { |
| status = iree_wait_all(wait_set, deadline_ns); |
| } |
| } |
| |
| if (timepoints != NULL) { |
| // TODO(benvanik): if we flip the API to multi-acquire events from the pool |
| // above then we can multi-release here too. |
| for (iree_host_size_t i = 0; i < timepoint_count; ++i) { |
| iree_hal_local_event_pool_release(event_pool, 1, &timepoints[i].event); |
| } |
| } |
| iree_wait_set_free(wait_set); |
| iree_arena_deinitialize(&arena); |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static const iree_hal_semaphore_vtable_t iree_hal_task_semaphore_vtable = { |
| .destroy = iree_hal_task_semaphore_destroy, |
| .query = iree_hal_task_semaphore_query, |
| .signal = iree_hal_task_semaphore_signal, |
| .fail = iree_hal_task_semaphore_fail, |
| .wait_with_deadline = iree_hal_task_semaphore_wait_with_deadline, |
| .wait_with_timeout = iree_hal_task_semaphore_wait_with_timeout, |
| }; |