blob: 9582821fa5cc8cd5eb6f5a80dc12240b3a6626a4 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/io/parameter_provider.h"
// Adds a reference to |provider|. A NULL |provider| is accepted and ignored.
IREE_API_EXPORT void iree_io_parameter_provider_retain(
    iree_io_parameter_provider_t* provider) {
  if (!provider) {
    return;
  }
  iree_atomic_ref_count_inc(&provider->ref_count);
}
// Drops a reference to |provider|. When the dropped reference was the last one,
// the implementation's destroy hook is invoked. A NULL |provider| is accepted
// and ignored.
IREE_API_EXPORT void iree_io_parameter_provider_release(
    iree_io_parameter_provider_t* provider) {
  if (!provider) {
    return;
  }
  // The comparison against 1 relies on iree_atomic_ref_count_dec reporting the
  // count as it was before this decrement (i.e. this was the final reference).
  if (iree_atomic_ref_count_dec(&provider->ref_count) != 1) {
    return;
  }
  provider->vtable->destroy(provider);
}
// Forwards |signal| to the provider implementation's notify hook and returns
// its status. When tracing is enabled the signal's name (or "(unknown)") is
// attached to the trace zone.
IREE_API_EXPORT iree_status_t
iree_io_parameter_provider_notify(iree_io_parameter_provider_t* provider,
                                  iree_io_parameter_provider_signal_t signal) {
  IREE_ASSERT_ARGUMENT(provider);
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE({
    const char* signal_name = "(unknown)";
    switch (signal) {
      case IREE_IO_PARAMETER_PROVIDER_SIGNAL_RESUME:
        signal_name = "RESUME";
        break;
      case IREE_IO_PARAMETER_PROVIDER_SIGNAL_SUSPEND:
        signal_name = "SUSPEND";
        break;
      case IREE_IO_PARAMETER_PROVIDER_SIGNAL_LOW_MEMORY:
        signal_name = "LOW_MEMORY";
        break;
      default:
        // Keep the "(unknown)" fallback for values not listed above.
        break;
    }
    IREE_TRACE_ZONE_APPEND_TEXT(z0, signal_name);
  });
  iree_status_t notify_status = provider->vtable->notify(provider, signal);
  IREE_TRACE_ZONE_END(z0);
  return notify_status;
}
// Asks the provider implementation whether it supports |scope| and returns its
// answer unchanged.
IREE_API_EXPORT bool iree_io_parameter_provider_query_support(
    iree_io_parameter_provider_t* provider, iree_string_view_t scope) {
  IREE_ASSERT_ARGUMENT(provider);
  const bool supports_scope = provider->vtable->query_support(provider, scope);
  return supports_scope;
}
// Enqueues a load of |count| parameters, described by |enumerator|, from
// |source_scope| using the provider implementation. |target_params| and
// |emitter| are forwarded unchanged to the implementation's load hook.
//
// When |count| is 0, the implementation is not called. Instead, a queue
// barrier is submitted so the wait/signal semaphore timeline is preserved.
// This matches the zero-count handling in iree_io_parameter_provider_gather
// and iree_io_parameter_provider_scatter.
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_load(
    iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
    iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_string_view_t source_scope, iree_hal_buffer_params_t target_params,
    iree_host_size_t count, iree_io_parameter_enumerator_t enumerator,
    iree_io_parameter_emitter_t emitter) {
  IREE_ASSERT_ARGUMENT(provider);
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);
  if (count == 0) {
    // Preserve the timeline when there's no work to do.
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_device_queue_barrier(
                device, queue_affinity, wait_semaphore_list,
                signal_semaphore_list, IREE_HAL_EXECUTE_FLAG_NONE));
  } else {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, provider->vtable->load(provider, device, queue_affinity,
                                   wait_semaphore_list, signal_semaphore_list,
                                   source_scope, target_params, count,
                                   enumerator, emitter));
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// State behind iree_io_parameter_provider_single_enumerator. It holds the one
// (key, span) entry used when a single-parameter read or write is routed
// through the provider's batched gather/scatter vtable entries.
typedef struct {
  // Key of the parameter being transferred (source key for reads, target key
  // for writes).
  iree_string_view_t key;
  // Parameter offset, buffer offset, and length of the single transfer.
  iree_io_parameter_span_t span;
} iree_io_parameter_provider_single_enumerator_state_t;
// Enumerator callback for the single-parameter read/write helpers. It reports
// the one key/span pair stored in |user_data|, which must point at an
// iree_io_parameter_provider_single_enumerator_state_t.
//
// Only index 0 is meaningful. This is enforced solely by IREE_ASSERT_EQ; in
// builds where that assertion is compiled out, any index yields the same
// entry.
static iree_status_t iree_io_parameter_provider_single_enumerator(
    void* user_data, iree_host_size_t i, iree_string_view_t* out_key,
    iree_io_parameter_span_t* out_span) {
  IREE_ASSERT_EQ(i, 0);
  const iree_io_parameter_provider_single_enumerator_state_t* entry =
      (const iree_io_parameter_provider_single_enumerator_state_t*)user_data;
  *out_key = entry->key;
  *out_span = entry->span;
  return iree_ok_status();
}
// Routes a single-parameter read through the provider's gather hook.
//
// The transferred range is |length| units of |source_key| in |source_scope|,
// starting at |source_offset|. It lands in |target_buffer| at
// |target_offset|. The implementation is always called with a one-entry
// enumerator.
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_read(
    iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
    iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_string_view_t source_scope, iree_string_view_t source_key,
    uint64_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  IREE_ASSERT_ARGUMENT(provider);
  IREE_ASSERT_ARGUMENT(target_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // For a read, the parameter side is the source; the buffer side is the
  // target.
  const iree_io_parameter_span_t span = {
      .parameter_offset = source_offset,
      .buffer_offset = target_offset,
      .length = length,
  };
  iree_io_parameter_provider_single_enumerator_state_t single_entry = {
      .key = source_key,
      .span = span,
  };

  // NOTE: |single_entry| lives on this stack frame. This relies on the
  // implementation consuming the enumerator before gather returns.
  const iree_io_parameter_enumerator_t enumerator = {
      .fn = iree_io_parameter_provider_single_enumerator,
      .user_data = &single_entry,
  };
  const iree_status_t gather_status = provider->vtable->gather(
      provider, device, queue_affinity, wait_semaphore_list,
      signal_semaphore_list, source_scope, target_buffer, /*count=*/1,
      enumerator);

  IREE_TRACE_ZONE_END(z0);
  return gather_status;
}
// Routes a single-parameter write through the provider's scatter hook.
//
// The transferred range is |length| units of |source_buffer| starting at
// |source_offset|. It goes to |target_key| in |target_scope| at
// |target_offset|. The implementation is always called with a one-entry
// enumerator.
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_write(
    iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
    iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
    iree_string_view_t target_scope, iree_string_view_t target_key,
    uint64_t target_offset, iree_device_size_t length) {
  IREE_ASSERT_ARGUMENT(provider);
  IREE_ASSERT_ARGUMENT(source_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // For a write, the buffer side is the source; the parameter side is the
  // target.
  const iree_io_parameter_span_t span = {
      .parameter_offset = target_offset,
      .buffer_offset = source_offset,
      .length = length,
  };
  iree_io_parameter_provider_single_enumerator_state_t single_entry = {
      .key = target_key,
      .span = span,
  };

  // NOTE: |single_entry| lives on this stack frame. This relies on the
  // implementation consuming the enumerator before scatter returns.
  const iree_io_parameter_enumerator_t enumerator = {
      .fn = iree_io_parameter_provider_single_enumerator,
      .user_data = &single_entry,
  };
  const iree_status_t scatter_status = provider->vtable->scatter(
      provider, device, queue_affinity, wait_semaphore_list,
      signal_semaphore_list, source_buffer, target_scope, /*count=*/1,
      enumerator);

  IREE_TRACE_ZONE_END(z0);
  return scatter_status;
}
// Gathers the |count| entries produced by |enumerator| from |source_scope|
// into |target_buffer| via the provider implementation.
//
// With nothing to gather, only a queue barrier is submitted so the
// wait/signal semaphore timeline is still honored. The first failing
// submission's status is returned.
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_gather(
    iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
    iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_string_view_t source_scope, iree_hal_buffer_t* target_buffer,
    iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
  IREE_ASSERT_ARGUMENT(provider);
  IREE_ASSERT_ARGUMENT(target_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);

  iree_status_t submit_status;
  if (count > 0) {
    submit_status = provider->vtable->gather(
        provider, device, queue_affinity, wait_semaphore_list,
        signal_semaphore_list, source_scope, target_buffer, count, enumerator);
  } else {
    // Preserve the timeline when there's no work to do.
    submit_status = iree_hal_device_queue_barrier(
        device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
        IREE_HAL_EXECUTE_FLAG_NONE);
  }
  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, submit_status);

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Scatters the |count| entries produced by |enumerator| from |source_buffer|
// into |target_scope| via the provider implementation.
//
// With nothing to scatter, only a queue barrier is submitted so the
// wait/signal semaphore timeline is still honored. The first failing
// submission's status is returned.
IREE_API_EXPORT iree_status_t iree_io_parameter_provider_scatter(
    iree_io_parameter_provider_t* provider, iree_hal_device_t* device,
    iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_buffer_t* source_buffer, iree_string_view_t target_scope,
    iree_host_size_t count, iree_io_parameter_enumerator_t enumerator) {
  IREE_ASSERT_ARGUMENT(provider);
  IREE_ASSERT_ARGUMENT(source_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, count);

  iree_status_t submit_status;
  if (count > 0) {
    submit_status = provider->vtable->scatter(
        provider, device, queue_affinity, wait_semaphore_list,
        signal_semaphore_list, source_buffer, target_scope, count, enumerator);
  } else {
    // Preserve the timeline when there's no work to do.
    submit_status = iree_hal_device_queue_barrier(
        device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
        IREE_HAL_EXECUTE_FLAG_NONE);
  }
  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, submit_status);

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}