blob: a5e878857e804353cbf3dac40669551fb969e223 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/cuda2/cuda_device.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "experimental/cuda2/cuda_allocator.h"
#include "experimental/cuda2/cuda_buffer.h"
#include "experimental/cuda2/cuda_dynamic_symbols.h"
#include "experimental/cuda2/cuda_status_util.h"
#include "experimental/cuda2/event_pool.h"
#include "experimental/cuda2/event_semaphore.h"
#include "experimental/cuda2/graph_command_buffer.h"
#include "experimental/cuda2/memory_pools.h"
#include "experimental/cuda2/nccl_channel.h"
#include "experimental/cuda2/nccl_dynamic_symbols.h"
#include "experimental/cuda2/nop_executable_cache.h"
#include "experimental/cuda2/pending_queue_actions.h"
#include "experimental/cuda2/pipeline_layout.h"
#include "experimental/cuda2/timepoint_pool.h"
#include "experimental/cuda2/tracing.h"
#include "iree/base/internal/arena.h"
#include "iree/base/internal/event_pool.h"
#include "iree/base/internal/math.h"
#include "iree/hal/utils/buffer_transfer.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"
//===----------------------------------------------------------------------===//
// iree_hal_cuda2_device_t
//===----------------------------------------------------------------------===//
typedef struct iree_hal_cuda2_device_t {
// Abstract resource used for injecting reference counting and vtable;
// must be at offset 0.
iree_hal_resource_t resource;
iree_string_view_t identifier;
// Block pool used for command buffers with a larger block size (as command
// buffers can contain inlined data uploads).
iree_arena_block_pool_t block_pool;
// Optional driver that owns the CUDA symbols. We retain it for our lifetime
// to ensure the symbols remains valid.
iree_hal_driver_t* driver;
const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols;
const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols;
// Parameters used to control device behavior.
iree_hal_cuda2_device_params_t params;
CUcontext cu_context;
CUdevice cu_device;
// TODO: Support multiple device streams.
// The CUstream used to issue device kernels and allocations.
CUstream dispatch_cu_stream;
// The CUstream used to issue host callback functions.
CUstream callback_cu_stream;
iree_hal_cuda2_tracing_context_t* tracing_context;
iree_allocator_t host_allocator;
// Host/device event pools, used for backing semaphore timepoints.
iree_event_pool_t* host_event_pool;
iree_hal_cuda2_event_pool_t* device_event_pool;
// Timepoint pools, shared by various semaphores.
iree_hal_cuda2_timepoint_pool_t* timepoint_pool;
// A queue to order device workloads and relase to the GPU when constraints
// are met. It buffers submissions and allocations internally before they
// are ready. This queue couples with HAL semaphores backed by iree_event_t
// and CUevent objects.
iree_hal_cuda2_pending_queue_actions_t* pending_queue_actions;
// Device memory pools and allocators.
bool supports_memory_pools;
iree_hal_cuda2_memory_pools_t memory_pools;
iree_hal_allocator_t* device_allocator;
// Optional provider used for creating/configuring collective channels.
iree_hal_channel_provider_t* channel_provider;
} iree_hal_cuda2_device_t;
static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable;
static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast(
iree_hal_device_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_device_vtable);
return (iree_hal_cuda2_device_t*)base_value;
}
static iree_hal_cuda2_device_t* iree_hal_cuda2_device_cast_unsafe(
iree_hal_device_t* base_value) {
return (iree_hal_cuda2_device_t*)base_value;
}
IREE_API_EXPORT void iree_hal_cuda2_device_params_initialize(
iree_hal_cuda2_device_params_t* out_params) {
memset(out_params, 0, sizeof(*out_params));
out_params->arena_block_size = 32 * 1024;
out_params->event_pool_capacity = 32;
out_params->queue_count = 1;
out_params->stream_tracing = false;
out_params->async_allocations = true;
}
static iree_status_t iree_hal_cuda2_device_check_params(
const iree_hal_cuda2_device_params_t* params) {
if (params->arena_block_size < 4096) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"arena block size too small (< 4096 bytes)");
}
if (params->queue_count == 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"at least one queue is required");
}
return iree_ok_status();
}
static iree_status_t iree_hal_cuda2_device_create_internal(
iree_hal_driver_t* driver, iree_string_view_t identifier,
const iree_hal_cuda2_device_params_t* params, CUdevice cu_device,
CUstream dispatch_stream, CUstream callback_stream, CUcontext context,
const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols,
iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
iree_hal_cuda2_device_t* device = NULL;
iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(host_allocator, total_size, (void**)&device));
memset(device, 0, total_size);
iree_hal_resource_initialize(&iree_hal_cuda2_device_vtable,
&device->resource);
iree_string_view_append_to_buffer(
identifier, &device->identifier,
(char*)device + iree_sizeof_struct(*device));
iree_arena_block_pool_initialize(params->arena_block_size, host_allocator,
&device->block_pool);
device->driver = driver;
iree_hal_driver_retain(device->driver);
device->cuda_symbols = cuda_symbols;
device->nccl_symbols = nccl_symbols;
device->params = *params;
device->cu_context = context;
device->cu_device = cu_device;
device->dispatch_cu_stream = dispatch_stream;
device->callback_cu_stream = callback_stream;
device->host_allocator = host_allocator;
iree_status_t status = iree_hal_cuda2_pending_queue_actions_create(
cuda_symbols, &device->block_pool, host_allocator,
&device->pending_queue_actions);
// Enable tracing for the (currently only) stream - no-op if disabled.
if (iree_status_is_ok(status) && device->params.stream_tracing) {
status = iree_hal_cuda2_tracing_context_allocate(
device->cuda_symbols, device->identifier, dispatch_stream,
&device->block_pool, host_allocator, &device->tracing_context);
}
// Memory pool support is conditional.
if (iree_status_is_ok(status) && params->async_allocations) {
int supports_memory_pools = 0;
status = IREE_CURESULT_TO_STATUS(
cuda_symbols,
cuDeviceGetAttribute(&supports_memory_pools,
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
cu_device),
"cuDeviceGetAttribute");
device->supports_memory_pools = supports_memory_pools != 0;
}
// Create memory pools first so that we can share them with the allocator.
if (iree_status_is_ok(status) && device->supports_memory_pools) {
status = iree_hal_cuda2_memory_pools_initialize(
cuda_symbols, cu_device, &params->memory_pools, host_allocator,
&device->memory_pools);
}
if (iree_status_is_ok(status)) {
status = iree_hal_cuda2_allocator_create(
cuda_symbols, cu_device, dispatch_stream,
device->supports_memory_pools ? &device->memory_pools : NULL,
host_allocator, &device->device_allocator);
}
if (iree_status_is_ok(status)) {
*out_device = (iree_hal_device_t*)device;
} else {
iree_hal_device_release((iree_hal_device_t*)device);
}
return status;
}
iree_status_t iree_hal_cuda2_device_create(
iree_hal_driver_t* driver, iree_string_view_t identifier,
const iree_hal_cuda2_device_params_t* params,
const iree_hal_cuda2_dynamic_symbols_t* cuda_symbols,
const iree_hal_cuda2_nccl_dynamic_symbols_t* nccl_symbols, CUdevice device,
iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
IREE_ASSERT_ARGUMENT(driver);
IREE_ASSERT_ARGUMENT(params);
IREE_ASSERT_ARGUMENT(cuda_symbols);
IREE_ASSERT_ARGUMENT(out_device);
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = iree_hal_cuda2_device_check_params(params);
// Get the main context for the device.
CUcontext context = NULL;
if (iree_status_is_ok(status)) {
status = IREE_CURESULT_TO_STATUS(
cuda_symbols, cuDevicePrimaryCtxRetain(&context, device));
}
if (iree_status_is_ok(status)) {
status = IREE_CURESULT_TO_STATUS(cuda_symbols, cuCtxSetCurrent(context));
}
// Create the default dispatch stream for the device.
CUstream dispatch_stream = NULL;
if (iree_status_is_ok(status)) {
status = IREE_CURESULT_TO_STATUS(
cuda_symbols, cuStreamCreate(&dispatch_stream, CU_STREAM_NON_BLOCKING));
}
// Create the default callback stream for the device.
CUstream callback_stream = NULL;
if (iree_status_is_ok(status)) {
status = IREE_CURESULT_TO_STATUS(
cuda_symbols, cuStreamCreate(&callback_stream, CU_STREAM_NON_BLOCKING));
}
if (iree_status_is_ok(status)) {
status = iree_hal_cuda2_device_create_internal(
driver, identifier, params, device, dispatch_stream, callback_stream,
context, cuda_symbols, nccl_symbols, host_allocator, out_device);
} else {
// Release resources we have accquired thus far.
if (callback_stream) cuda_symbols->cuStreamDestroy(callback_stream);
if (dispatch_stream) cuda_symbols->cuStreamDestroy(dispatch_stream);
if (context) cuda_symbols->cuDevicePrimaryCtxRelease(device);
}
iree_event_pool_t* host_event_pool = NULL;
if (iree_status_is_ok(status)) {
status = iree_event_pool_allocate(params->event_pool_capacity,
host_allocator, &host_event_pool);
}
iree_hal_cuda2_event_pool_t* device_event_pool = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_cuda2_event_pool_allocate(
*out_device, cuda_symbols, params->event_pool_capacity, host_allocator,
&device_event_pool);
}
iree_hal_cuda2_timepoint_pool_t* timepoint_pool = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_cuda2_timepoint_pool_allocate(
host_event_pool, device_event_pool, params->event_pool_capacity,
host_allocator, &timepoint_pool);
}
if (iree_status_is_ok(status)) {
iree_hal_cuda2_device_t* cuda_device =
iree_hal_cuda2_device_cast(*out_device);
cuda_device->host_event_pool = host_event_pool;
cuda_device->device_event_pool = device_event_pool;
cuda_device->timepoint_pool = timepoint_pool;
} else {
// Release resources we have accquired after HAL device creation.
if (timepoint_pool) iree_hal_cuda2_timepoint_pool_free(timepoint_pool);
if (device_event_pool) iree_hal_cuda2_event_pool_release(device_event_pool);
if (host_event_pool) iree_event_pool_free(host_event_pool);
// Release other resources via the HAL device.
iree_hal_device_release(*out_device);
}
IREE_TRACE_ZONE_END(z0);
return status;
}
CUcontext iree_hal_cuda2_device_context(iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device =
iree_hal_cuda2_device_cast_unsafe(base_device);
return device->cu_context;
}
const iree_hal_cuda2_dynamic_symbols_t* iree_hal_cuda2_device_dynamic_symbols(
iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device =
iree_hal_cuda2_device_cast_unsafe(base_device);
return device->cuda_symbols;
}
static void iree_hal_cuda2_device_destroy(iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
const iree_hal_cuda2_dynamic_symbols_t* symbols = device->cuda_symbols;
IREE_TRACE_ZONE_BEGIN(z0);
// Destroy the pending workload queue.
iree_hal_cuda2_pending_queue_actions_destroy(
(iree_hal_resource_t*)device->pending_queue_actions);
// There should be no more buffers live that use the allocator.
iree_hal_allocator_release(device->device_allocator);
// Buffers may have been retaining collective resources.
iree_hal_channel_provider_release(device->channel_provider);
// Destroy memory pools that hold on to reserved memory.
iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);
iree_hal_cuda2_tracing_context_free(device->tracing_context);
// Destroy various pools for synchronization.
if (device->timepoint_pool) {
iree_hal_cuda2_timepoint_pool_free(device->timepoint_pool);
}
if (device->device_event_pool) {
iree_hal_cuda2_event_pool_release(device->device_event_pool);
}
if (device->host_event_pool) iree_event_pool_free(device->host_event_pool);
IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->dispatch_cu_stream));
IREE_CUDA_IGNORE_ERROR(symbols, cuStreamDestroy(device->callback_cu_stream));
IREE_CUDA_IGNORE_ERROR(symbols, cuDevicePrimaryCtxRelease(device->cu_device));
iree_arena_block_pool_deinitialize(&device->block_pool);
// Finally, destroy the device.
iree_hal_driver_release(device->driver);
iree_allocator_free(host_allocator, device);
IREE_TRACE_ZONE_END(z0);
}
static iree_string_view_t iree_hal_cuda2_device_id(
iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return device->identifier;
}
static iree_allocator_t iree_hal_cuda2_device_host_allocator(
iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return device->host_allocator;
}
static iree_hal_allocator_t* iree_hal_cuda2_device_allocator(
iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return device->device_allocator;
}
static void iree_hal_cuda2_replace_device_allocator(
iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
iree_hal_allocator_retain(new_allocator);
iree_hal_allocator_release(device->device_allocator);
device->device_allocator = new_allocator;
}
static void iree_hal_cuda2_replace_channel_provider(
iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
iree_hal_channel_provider_retain(new_provider);
iree_hal_channel_provider_release(device->channel_provider);
device->channel_provider = new_provider;
}
static iree_status_t iree_hal_cuda2_device_trim(
iree_hal_device_t* base_device) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
iree_arena_block_pool_trim(&device->block_pool);
IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator));
if (device->supports_memory_pools) {
IREE_RETURN_IF_ERROR(iree_hal_cuda2_memory_pools_trim(
&device->memory_pools, &device->params.memory_pools));
}
return iree_ok_status();
}
static iree_status_t iree_hal_cuda2_device_query_attribute(
iree_hal_cuda2_device_t* device, CUdevice_attribute attribute,
int64_t* out_value) {
int value = 0;
IREE_CUDA_RETURN_IF_ERROR(
device->cuda_symbols,
cuDeviceGetAttribute(&value, attribute, device->cu_device),
"cuDeviceGetAttribute");
*out_value = value;
return iree_ok_status();
}
static iree_status_t iree_hal_cuda2_device_query_i64(
iree_hal_device_t* base_device, iree_string_view_t category,
iree_string_view_t key, int64_t* out_value) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
*out_value = 0;
if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
*out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0;
return iree_ok_status();
}
if (iree_string_view_equal(category, IREE_SV("cuda.device"))) {
if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) {
return iree_hal_cuda2_device_query_attribute(
device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value);
} else if (iree_string_view_equal(key,
IREE_SV("compute_capability_minor"))) {
return iree_hal_cuda2_device_query_attribute(
device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value);
}
}
return iree_make_status(
IREE_STATUS_NOT_FOUND,
"unknown device configuration key value '%.*s :: %.*s'",
(int)category.size, category.data, (int)key.size, key.data);
}
static iree_status_t iree_hal_cuda2_device_create_channel(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
if (!device->nccl_symbols || !device->nccl_symbols->dylib) {
return iree_make_status(
IREE_STATUS_UNAVAILABLE,
"NCCL runtime library (%d.%d.%d) not available; ensure installed and "
"the shared library is on your PATH/LD_LIBRARY_PATH "
"(nccl.dll/libnccl.so)",
NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH);
}
// Today we only allow a single logical device per channel.
// We could multiplex channels but it'd be better to surface that to the
// compiler so that it can emit the right rank math.
int requested_count = iree_math_count_ones_u64(queue_affinity);
// TODO(#12206): properly assign affinity in the compiler.
if (requested_count != 64 && requested_count != 1) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"exactly one participant is allowed in a "
"channel but %d were specified",
requested_count);
}
// Ask the channel provider (if configured) for the default rank and count
// if the user did not set them.
if (device->channel_provider &&
(params.rank == IREE_HAL_CHANNEL_RANK_DEFAULT ||
params.count == IREE_HAL_CHANNEL_COUNT_DEFAULT)) {
IREE_RETURN_IF_ERROR(
iree_hal_channel_provider_query_default_rank_and_count(
device->channel_provider, &params.rank, &params.count),
"querying default collective group rank and count");
}
// An ID is required to initialize NCCL. On the root it'll be the local ID and
// on all other participants it'll be the root ID.
iree_hal_cuda2_nccl_id_t id;
memset(&id, 0, sizeof(id));
if (iree_const_byte_span_is_empty(params.id)) {
// User wants the default ID.
if (!device->channel_provider) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"default collective channel ID requested but no channel provider has "
"been set on the device to provide it");
}
if (params.rank == 0) {
// Bootstrap NCCL to get the root ID.
IREE_RETURN_IF_ERROR(
iree_hal_cuda2_nccl_get_unique_id(device->nccl_symbols, &id),
"bootstrapping NCCL root");
}
// Exchange NCCL ID with all participants.
IREE_RETURN_IF_ERROR(iree_hal_channel_provider_exchange_default_id(
device->channel_provider,
iree_make_byte_span((void*)&id, sizeof(id))),
"exchanging NCCL ID with other participants");
} else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) {
// User provided something but it's not what we expect.
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"NCCL ID must be %zu bytes matching the "
"ncclUniqueId struct but caller provided %zu bytes",
IREE_ARRAYSIZE(id.data), sizeof(id));
} else {
// User provided the ID - we treat it as opaque here and let NCCL validate.
memcpy(id.data, params.id.data, IREE_ARRAYSIZE(id.data));
}
if (iree_hal_cuda2_nccl_id_is_empty(&id)) {
// TODO: maybe this is ok? a localhost alias or something?
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"no default NCCL ID specified (all zeros)");
}
// TODO: when we support multiple logical devices we'll want to pass in the
// context of the device mapped to the queue_affinity. For now since this
// implementation only supports one device we pass in the only one we have.
return iree_hal_cuda2_nccl_channel_create(
device->cuda_symbols, device->nccl_symbols, &id, params.rank,
params.count, device->host_allocator, out_channel);
}
static iree_status_t iree_hal_cuda2_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
iree_hal_command_buffer_t** out_command_buffer) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return iree_hal_cuda2_graph_command_buffer_create(
base_device, device->cuda_symbols, device->cu_context, mode,
command_categories, queue_affinity, binding_capacity, &device->block_pool,
device->host_allocator, out_command_buffer);
}
static iree_status_t iree_hal_cuda2_device_create_descriptor_set_layout(
iree_hal_device_t* base_device,
iree_hal_descriptor_set_layout_flags_t flags,
iree_host_size_t binding_count,
const iree_hal_descriptor_set_layout_binding_t* bindings,
iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return iree_hal_cuda2_descriptor_set_layout_create(
flags, binding_count, bindings, device->host_allocator,
out_descriptor_set_layout);
}
static iree_status_t iree_hal_cuda2_device_create_event(
iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"event not yet implmeneted");
}
static iree_status_t iree_hal_cuda2_device_create_executable_cache(
iree_hal_device_t* base_device, iree_string_view_t identifier,
iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return iree_hal_cuda2_nop_executable_cache_create(
identifier, device->cuda_symbols, device->cu_device,
device->host_allocator, out_executable_cache);
}
static iree_status_t iree_hal_cuda2_device_import_file(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
iree_hal_memory_access_t access, iree_io_file_handle_t* handle,
iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) {
if (iree_io_file_handle_type(handle) !=
IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
return iree_make_status(
IREE_STATUS_UNAVAILABLE,
"implementation does not support the external file type");
}
return iree_hal_memory_file_wrap(
queue_affinity, access, handle, iree_hal_device_allocator(base_device),
iree_hal_device_host_allocator(base_device), out_file);
}
static iree_status_t iree_hal_cuda2_device_create_pipeline_layout(
iree_hal_device_t* base_device, iree_host_size_t push_constants,
iree_host_size_t set_layout_count,
iree_hal_descriptor_set_layout_t* const* set_layouts,
iree_hal_pipeline_layout_t** out_pipeline_layout) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return iree_hal_cuda2_pipeline_layout_create(
set_layout_count, set_layouts, push_constants, device->host_allocator,
out_pipeline_layout);
}
static iree_status_t iree_hal_cuda2_device_create_semaphore(
iree_hal_device_t* base_device, uint64_t initial_value,
iree_hal_semaphore_t** out_semaphore) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
return iree_hal_cuda2_event_semaphore_create(
initial_value, device->cuda_symbols, device->timepoint_pool,
device->pending_queue_actions, device->host_allocator, out_semaphore);
}
static iree_hal_semaphore_compatibility_t
iree_hal_cuda2_device_query_semaphore_compatibility(
iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) {
// TODO: implement CUDA semaphores.
return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY;
}
// TODO: implement multiple streams; today we only have one and queue_affinity
// is ignored.
// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
// the barrier here.
static iree_status_t iree_hal_cuda2_device_queue_alloca(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
iree_device_size_t allocation_size,
iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
// NOTE: block on the semaphores here; we could avoid this by properly
// sequencing device work with semaphores. The CUDA HAL is not currently
// asynchronous.
IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
iree_infinite_timeout()));
// Allocate from the pool; likely to fail in cases of virtual memory
// exhaustion but the error may be deferred until a later synchronization.
// If pools are not supported we allocate a buffer as normal from whatever
// allocator is set on the device.
iree_status_t status = iree_ok_status();
if (device->supports_memory_pools &&
!iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
status = iree_hal_cuda2_memory_pools_alloca(
&device->memory_pools, device->dispatch_cu_stream, pool, params,
allocation_size, out_buffer);
} else {
status = iree_hal_allocator_allocate_buffer(
iree_hal_device_allocator(base_device), params, allocation_size,
out_buffer);
}
// Only signal if not returning a synchronous error - synchronous failure
// indicates that the stream is unchanged (it's not really since we waited
// above, but we at least won't deadlock like this).
if (iree_status_is_ok(status)) {
status = iree_hal_semaphore_list_signal(signal_semaphore_list);
}
return status;
}
// TODO: implement multiple streams; today we only have one and queue_affinity
// is ignored.
// TODO: implement proper semaphores in CUDA to ensure ordering and avoid
// the barrier here.
static iree_status_t iree_hal_cuda2_device_queue_dealloca(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_hal_buffer_t* buffer) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
// NOTE: block on the semaphores here; we could avoid this by properly
// sequencing device work with semaphores. The CUDA HAL is not currently
// asynchronous.
IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
iree_infinite_timeout()));
// Schedule the buffer deallocation if we got it from a pool and otherwise
// drop it on the floor and let it be freed when the buffer is released.
iree_status_t status = iree_ok_status();
if (device->supports_memory_pools) {
status = iree_hal_cuda2_memory_pools_dealloca(
&device->memory_pools, device->dispatch_cu_stream, buffer);
}
// Only signal if not returning a synchronous error - synchronous failure
// indicates that the stream is unchanged (it's not really since we waited
// above, but we at least won't deadlock like this).
if (iree_status_is_ok(status)) {
status = iree_hal_semaphore_list_signal(signal_semaphore_list);
}
return status;
}
static iree_status_t iree_hal_cuda2_device_queue_read(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_hal_file_t* source_file, uint64_t source_offset,
iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
iree_device_size_t length, uint32_t flags) {
// TODO: expose streaming chunk count/size options.
iree_status_t loop_status = iree_ok_status();
iree_hal_file_transfer_options_t options = {
.loop = iree_loop_inline(&loop_status),
.chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
.chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
};
IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming(
base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
source_file, source_offset, target_buffer, target_offset, length, flags,
options));
return loop_status;
}
static iree_status_t iree_hal_cuda2_device_queue_write(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
iree_hal_file_t* target_file, uint64_t target_offset,
iree_device_size_t length, uint32_t flags) {
// TODO: expose streaming chunk count/size options.
iree_status_t loop_status = iree_ok_status();
iree_hal_file_transfer_options_t options = {
.loop = iree_loop_inline(&loop_status),
.chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
.chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
};
IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming(
base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
source_buffer, source_offset, target_file, target_offset, length, flags,
options));
return loop_status;
}
static iree_status_t iree_hal_cuda2_device_queue_execute(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
iree_hal_command_buffer_t* const* command_buffers) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = iree_hal_cuda2_pending_queue_actions_enqueue_execution(
device->dispatch_cu_stream, device->callback_cu_stream,
device->pending_queue_actions, wait_semaphore_list, signal_semaphore_list,
command_buffer_count, command_buffers);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda2_pending_queue_actions_issue(
device->pending_queue_actions);
}
iree_hal_cuda2_tracing_context_collect(device->tracing_context);
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_cuda2_device_queue_flush(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
iree_hal_cuda2_device_t* device = iree_hal_cuda2_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
// Try to advance the pending workload queue.
iree_status_t status =
iree_hal_cuda2_pending_queue_actions_issue(device->pending_queue_actions);
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_cuda2_device_wait_semaphores(
iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"waiting multiple semaphores not yet implemented");
}
static iree_status_t iree_hal_cuda2_device_profiling_begin(
iree_hal_device_t* base_device,
const iree_hal_device_profiling_options_t* options) {
// Unimplemented (and that's ok).
// We could hook in to CUPTI here or use the much simpler cuProfilerStart API.
return iree_ok_status();
}
static iree_status_t iree_hal_cuda2_device_profiling_flush(
iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
static iree_status_t iree_hal_cuda2_device_profiling_end(
iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = {
.destroy = iree_hal_cuda2_device_destroy,
.id = iree_hal_cuda2_device_id,
.host_allocator = iree_hal_cuda2_device_host_allocator,
.device_allocator = iree_hal_cuda2_device_allocator,
.replace_device_allocator = iree_hal_cuda2_replace_device_allocator,
.replace_channel_provider = iree_hal_cuda2_replace_channel_provider,
.trim = iree_hal_cuda2_device_trim,
.query_i64 = iree_hal_cuda2_device_query_i64,
.create_channel = iree_hal_cuda2_device_create_channel,
.create_command_buffer = iree_hal_cuda2_device_create_command_buffer,
.create_descriptor_set_layout =
iree_hal_cuda2_device_create_descriptor_set_layout,
.create_event = iree_hal_cuda2_device_create_event,
.create_executable_cache = iree_hal_cuda2_device_create_executable_cache,
.import_file = iree_hal_cuda2_device_import_file,
.create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout,
.create_semaphore = iree_hal_cuda2_device_create_semaphore,
.query_semaphore_compatibility =
iree_hal_cuda2_device_query_semaphore_compatibility,
.transfer_range = iree_hal_device_submit_transfer_range_and_wait,
.queue_alloca = iree_hal_cuda2_device_queue_alloca,
.queue_dealloca = iree_hal_cuda2_device_queue_dealloca,
.queue_read = iree_hal_cuda2_device_queue_read,
.queue_write = iree_hal_cuda2_device_queue_write,
.queue_execute = iree_hal_cuda2_device_queue_execute,
.queue_flush = iree_hal_cuda2_device_queue_flush,
.wait_semaphores = iree_hal_cuda2_device_wait_semaphores,
.profiling_begin = iree_hal_cuda2_device_profiling_begin,
.profiling_flush = iree_hal_cuda2_device_profiling_flush,
.profiling_end = iree_hal_cuda2_device_profiling_end,
};