blob: a65adbd6d7f49cb299b4b6bb981841d62907e254 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/fence.h"
#include <stddef.h>
//===----------------------------------------------------------------------===//
// iree_hal_fence_t
//===----------------------------------------------------------------------===//
// Allocates an empty fence with room for up to |capacity| unique semaphore
// timepoints. The semaphore pointer array and the payload value array live in
// the same allocation as the fence header, each aligned for its element type.
// On success the new fence is stored in |out_fence|; on failure |out_fence| is
// left as NULL.
IREE_API_EXPORT iree_status_t iree_hal_fence_create(
    iree_host_size_t capacity, iree_allocator_t host_allocator,
    iree_hal_fence_t** out_fence) {
  IREE_ASSERT_ARGUMENT(out_fence);
  *out_fence = NULL;

  // The capacity is narrowed to uint16_t when stored below, so anything at or
  // beyond UINT16_MAX is rejected up front.
  if (IREE_UNLIKELY(capacity >= UINT16_MAX)) {
    return iree_make_status(
        IREE_STATUS_RESOURCE_EXHAUSTED,
        "capacity %" PRIhsz " is too large for fence storage", capacity);
  }

  IREE_TRACE_ZONE_BEGIN(z0);

  // Layout: [header][pad][semaphore pointers][pad][payload values].
  const iree_host_size_t semaphores_offset = iree_host_align(
      sizeof(iree_hal_fence_t), iree_alignof(iree_hal_semaphore_t*));
  const iree_host_size_t semaphores_size =
      capacity * sizeof(iree_hal_semaphore_t*);
  const iree_host_size_t values_offset = iree_host_align(
      semaphores_offset + semaphores_size, iree_alignof(uint64_t));
  const iree_host_size_t values_size = capacity * sizeof(uint64_t);
  const iree_host_size_t allocation_size = values_offset + values_size;

  iree_hal_fence_t* new_fence = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, allocation_size,
                                (void**)&new_fence));

  // The allocator is remembered so that destruction can free with it.
  new_fence->host_allocator = host_allocator;
  new_fence->capacity = (uint16_t)capacity;
  new_fence->count = 0;
  iree_atomic_ref_count_init(&new_fence->ref_count);

  *out_fence = new_fence;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Creates a fence holding exactly one timepoint: |semaphore| at |value|.
// The fence is created with a capacity of 1. If inserting the timepoint fails
// the partially constructed fence is released and |out_fence| stays NULL.
IREE_API_EXPORT iree_status_t iree_hal_fence_create_at(
    iree_hal_semaphore_t* semaphore, uint64_t value,
    iree_allocator_t host_allocator, iree_hal_fence_t** out_fence) {
  IREE_ASSERT_ARGUMENT(semaphore);
  IREE_ASSERT_ARGUMENT(out_fence);
  *out_fence = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_fence_t* single_fence = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_fence_create(1, host_allocator, &single_fence));

  iree_status_t insert_status =
      iree_hal_fence_insert(single_fence, semaphore, value);
  if (!iree_status_is_ok(insert_status)) {
    // Drop our only reference so the fence is not leaked.
    iree_hal_fence_release(single_fence);
    IREE_TRACE_ZONE_END(z0);
    return insert_status;
  }

  *out_fence = single_fence;
  IREE_TRACE_ZONE_END(z0);
  return insert_status;
}
// TODO(benvanik): actually join efficiently. Today we just create a fence that
// can hold the worst-case sum of all fence timepoints and then insert but it
// could be made much better. In most cases the joined fences have a near
// perfect overlap of semaphores and we are wasting memory.
// Joins the timepoints of |fence_count| fences in |fences| into a single new
// fence stored in |out_fence|.
//
// NULL entries in |fences| are skipped. When the inputs contain no timepoints
// at all, the call succeeds and |out_fence| is left as NULL. Otherwise a fence
// sized for the sum of all input timepoint counts is created and every
// timepoint is inserted into it; iree_hal_fence_insert merges repeated
// semaphores by keeping the larger payload value.
//
// On failure the partially built fence is released and |out_fence| stays NULL.
IREE_API_EXPORT iree_status_t iree_hal_fence_join(
    iree_host_size_t fence_count, iree_hal_fence_t** fences,
    iree_allocator_t host_allocator, iree_hal_fence_t** out_fence) {
  IREE_ASSERT_ARGUMENT(out_fence);
  *out_fence = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  // Find the maximum required timepoint capacity.
  iree_host_size_t total_count = 0;
  for (iree_host_size_t i = 0; i < fence_count; ++i) {
    if (fences[i]) total_count += fences[i]->count;
  }

  // Empty list -> NULL fence with an OK status.
  // This returns a status value rather than a raw NULL so the return matches
  // the function's iree_status_t return type.
  if (!total_count) {
    IREE_TRACE_ZONE_END(z0);
    return iree_ok_status();
  }

  // Create the fence with the maximum capacity.
  iree_hal_fence_t* fence = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_fence_create(total_count, host_allocator, &fence));

  // Insert all timepoints from all fences, stopping at the first failure.
  iree_status_t status = iree_ok_status();
  for (iree_host_size_t i = 0; i < fence_count; ++i) {
    // A NULL fence yields an empty list and contributes nothing.
    iree_hal_semaphore_list_t source_list =
        iree_hal_fence_semaphore_list(fences[i]);
    for (iree_host_size_t j = 0; j < source_list.count; ++j) {
      status = iree_hal_fence_insert(fence, source_list.semaphores[j],
                                     source_list.payload_values[j]);
      if (!iree_status_is_ok(status)) break;
    }
    if (!iree_status_is_ok(status)) break;
  }

  if (iree_status_is_ok(status)) {
    *out_fence = fence;
  } else {
    iree_hal_fence_release(fence);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Tears down |fence|: releases every semaphore currently stored in it and then
// frees the fence memory with the allocator recorded in the fence.
IREE_API_EXPORT void iree_hal_fence_destroy(iree_hal_fence_t* fence) {
  IREE_ASSERT_ARGUMENT(fence);
  IREE_ASSERT_REF_COUNT_ZERO(&fence->ref_count);
  IREE_TRACE_ZONE_BEGIN(z0);

  // Copy the allocator out first; the fence memory is gone after the free.
  iree_allocator_t host_allocator = fence->host_allocator;

  iree_hal_semaphore_list_t timepoints = iree_hal_fence_semaphore_list(fence);
  iree_host_size_t index = 0;
  while (index < timepoints.count) {
    iree_hal_semaphore_release(timepoints.semaphores[index]);
    ++index;
  }

  iree_allocator_free(host_allocator, fence);
  IREE_TRACE_ZONE_END(z0);
}
// Adds a reference to |fence|. A NULL |fence| is ignored.
IREE_API_EXPORT void iree_hal_fence_retain(iree_hal_fence_t* fence) {
  if (IREE_UNLIKELY(!fence)) return;
  iree_atomic_ref_count_inc(&fence->ref_count);
}
// Drops a reference to |fence| and destroys it when the decrement reports a
// prior count of 1. A NULL |fence| is ignored.
IREE_API_EXPORT void iree_hal_fence_release(iree_hal_fence_t* fence) {
  if (IREE_UNLIKELY(!fence)) return;
  if (iree_atomic_ref_count_dec(&fence->ref_count) == 1) {
    iree_hal_fence_destroy(fence);
  }
}
// Returns a view of the timepoints stored in |fence| as a semaphore list.
// The returned pointers reference the storage that trails the fence header in
// the same allocation; nothing is copied. A NULL |fence| yields an empty list
// with NULL arrays.
IREE_API_EXPORT iree_hal_semaphore_list_t
iree_hal_fence_semaphore_list(iree_hal_fence_t* fence) {
  iree_hal_semaphore_list_t list = {
      .count = 0,
      .semaphores = NULL,
      .payload_values = NULL,
  };
  if (fence) {
    // Array offsets depend on the fence capacity (not the current count), so
    // they are recomputed from it here.
    uint8_t* base_ptr = (uint8_t*)fence;
    const iree_host_size_t semaphores_offset = iree_host_align(
        sizeof(iree_hal_fence_t), iree_alignof(iree_hal_semaphore_t*));
    const iree_host_size_t values_offset = iree_host_align(
        semaphores_offset + fence->capacity * sizeof(iree_hal_semaphore_t*),
        iree_alignof(uint64_t));
    list.count = fence->count;
    list.semaphores = (iree_hal_semaphore_t**)(base_ptr + semaphores_offset);
    list.payload_values = (uint64_t*)(base_ptr + values_offset);
  }
  return list;
}
// Returns the number of timepoints currently stored in |fence|, or 0 when
// |fence| is NULL.
IREE_API_EXPORT iree_host_size_t
iree_hal_fence_timepoint_count(const iree_hal_fence_t* fence) {
  return fence ? fence->count : 0;
}
// Inserts the timepoint (|semaphore|, |value|) into |fence|.
//
// If |semaphore| is already present its payload is raised to |value| when
// |value| is larger; no new slot is used and no extra reference is taken.
// Otherwise the semaphore is retained and appended. Fails with
// RESOURCE_EXHAUSTED when a new slot is needed and the fence is at capacity.
IREE_API_EXPORT iree_status_t iree_hal_fence_insert(
    iree_hal_fence_t* fence, iree_hal_semaphore_t* semaphore, uint64_t value) {
  IREE_ASSERT_ARGUMENT(fence);
  IREE_ASSERT_ARGUMENT(semaphore);
  iree_hal_semaphore_list_t timepoints = iree_hal_fence_semaphore_list(fence);

  // Linear scan for a slot already holding this semaphore.
  iree_host_size_t slot = 0;
  while (slot < timepoints.count && timepoints.semaphores[slot] != semaphore) {
    ++slot;
  }

  if (slot < timepoints.count) {
    // Existing entry: keep the larger of the two payload values.
    if (timepoints.payload_values[slot] < value) {
      timepoints.payload_values[slot] = value;
    }
    return iree_ok_status();
  }

  // Not found; |slot| now equals the current count, which is where a new
  // entry would go if there is still room.
  if (slot >= fence->capacity) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "fence unique semaphore capacity %u reached",
                            fence->capacity);
  }

  // The fence holds its own reference to each stored semaphore.
  iree_hal_semaphore_retain(semaphore);
  timepoints.semaphores[slot] = semaphore;
  timepoints.payload_values[slot] = value;
  ++fence->count;
  return iree_ok_status();
}
// Inserts every timepoint of |from_fence| into |into_fence|, in order.
// Stops and returns the error from the first insertion that fails; timepoints
// inserted before the failure remain in |into_fence|.
IREE_API_EXPORT iree_status_t iree_hal_fence_extend(
    iree_hal_fence_t* into_fence, iree_hal_fence_t* from_fence) {
  IREE_ASSERT_ARGUMENT(into_fence);
  IREE_ASSERT_ARGUMENT(from_fence);

  iree_hal_semaphore_list_t source = iree_hal_fence_semaphore_list(from_fence);
  iree_host_size_t index = 0;
  while (index < source.count) {
    IREE_RETURN_IF_ERROR(iree_hal_fence_insert(
        into_fence, source.semaphores[index], source.payload_values[index]));
    ++index;
  }
  return iree_ok_status();
}
// Checks, without blocking in this function, whether every timepoint in
// |fence| has been reached.
//
// Returns OK when |fence| is NULL or when each semaphore's queried value is at
// least its stored payload value. Returns IREE_STATUS_DEFERRED at the first
// semaphore that has not yet reached its payload, and propagates any error
// from querying a semaphore.
IREE_API_EXPORT iree_status_t iree_hal_fence_query(iree_hal_fence_t* fence) {
  if (!fence) return iree_ok_status();

  iree_hal_semaphore_list_t timepoints = iree_hal_fence_semaphore_list(fence);
  for (iree_host_size_t index = 0; index < timepoints.count; ++index) {
    uint64_t observed_value = 0;
    IREE_RETURN_IF_ERROR(iree_hal_semaphore_query(timepoints.semaphores[index],
                                                  &observed_value));
    const uint64_t required_value = timepoints.payload_values[index];
    if (required_value > observed_value) {
      return iree_status_from_code(IREE_STATUS_DEFERRED);
    }
  }
  return iree_ok_status();
}
// Signals the timepoints of |fence| by passing its semaphore list to
// iree_hal_semaphore_list_signal and returning that call's status.
IREE_API_EXPORT iree_status_t iree_hal_fence_signal(iree_hal_fence_t* fence) {
  IREE_TRACE_ZONE_BEGIN(z0);
  iree_hal_semaphore_list_t timepoints = iree_hal_fence_semaphore_list(fence);
  iree_status_t signal_result = iree_hal_semaphore_list_signal(timepoints);
  IREE_TRACE_ZONE_END(z0);
  return signal_result;
}
// Forwards |signal_status| together with the semaphore list of |fence| to
// iree_hal_semaphore_list_fail.
IREE_API_EXPORT void iree_hal_fence_fail(iree_hal_fence_t* fence,
                                         iree_status_t signal_status) {
  IREE_TRACE_ZONE_BEGIN(z0);
  iree_hal_semaphore_list_t timepoints = iree_hal_fence_semaphore_list(fence);
  iree_hal_semaphore_list_fail(timepoints, signal_status);
  IREE_TRACE_ZONE_END(z0);
}
// Waits on the timepoints of |fence| by passing its semaphore list and
// |timeout| to iree_hal_semaphore_list_wait. A NULL fence or a fence with no
// timepoints returns OK immediately without waiting.
IREE_API_EXPORT iree_status_t iree_hal_fence_wait(iree_hal_fence_t* fence,
                                                  iree_timeout_t timeout) {
  if (fence == NULL) return iree_ok_status();
  if (fence->count == 0) return iree_ok_status();

  IREE_TRACE_ZONE_BEGIN(z0);
  iree_hal_semaphore_list_t timepoints = iree_hal_fence_semaphore_list(fence);
  iree_status_t wait_result = iree_hal_semaphore_list_wait(timepoints, timeout);
  IREE_TRACE_ZONE_END(z0);
  return wait_result;
}
// Control function backing wait sources produced for fences. The fence is
// recovered from |wait_source.self| and |command| selects the operation:
//   QUERY:    runs iree_hal_fence_query and writes the resulting status code
//             through |inout_ptr|; the call itself returns OK.
//   WAIT_ONE: waits on the fence with the timeout carried in |params|.
//   EXPORT:   not supported; zeroes the output primitive and returns
//             UNAVAILABLE.
// Any other command returns UNIMPLEMENTED.
iree_status_t iree_hal_fence_wait_source_ctl(iree_wait_source_t wait_source,
                                             iree_wait_source_command_t command,
                                             const void* params,
                                             void** inout_ptr) {
  iree_hal_fence_t* fence = (iree_hal_fence_t*)wait_source.self;

  if (command == IREE_WAIT_SOURCE_COMMAND_QUERY) {
    iree_status_code_t* out_wait_status_code = (iree_status_code_t*)inout_ptr;
    iree_status_t query_status = iree_hal_fence_query(fence);
    if (iree_status_is_ok(query_status)) {
      *out_wait_status_code = IREE_STATUS_OK;
    } else {
      // Only the code is reported to the caller; the status itself is
      // discarded here.
      *out_wait_status_code = iree_status_code(query_status);
      iree_status_ignore(query_status);
    }
    return iree_ok_status();
  }

  if (command == IREE_WAIT_SOURCE_COMMAND_WAIT_ONE) {
    const iree_wait_source_wait_params_t* wait_params =
        (const iree_wait_source_wait_params_t*)params;
    const iree_timeout_t timeout = wait_params->timeout;
    return iree_hal_fence_wait(fence, timeout);
  }

  if (command == IREE_WAIT_SOURCE_COMMAND_EXPORT) {
    const iree_wait_source_export_params_t* export_params =
        (const iree_wait_source_export_params_t*)params;
    const iree_wait_primitive_type_t target_type = export_params->target_type;
    // TODO(benvanik): support exporting fences to real wait handles.
    iree_wait_primitive_t* out_wait_primitive =
        (iree_wait_primitive_t*)inout_ptr;
    memset(out_wait_primitive, 0, sizeof(*out_wait_primitive));
    (void)target_type;
    return iree_make_status(IREE_STATUS_UNAVAILABLE,
                            "requested wait primitive type %d is unavailable",
                            (int)target_type);
  }

  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "unimplemented wait_source command");
}
// Returns a wait source for |fence| whose control function is
// iree_hal_fence_wait_source_ctl. A NULL |fence| yields the immediate wait
// source. The fence pointer is stored as-is; no reference is taken here.
IREE_API_EXPORT iree_wait_source_t
iree_hal_fence_await(iree_hal_fence_t* fence) {
  if (fence) {
    iree_wait_source_t fence_source = {
        .self = fence,
        .data = 0,
        .ctl = iree_hal_fence_wait_source_ctl,
    };
    return fence_source;
  }
  return iree_wait_source_immediate();
}