[cuda] Port over tracing utilities and use them in the NCCL channel (#14063)
The main change is removing the context wrapper and including CUDA
dynamic symbols directly.
Progress towards https://github.com/openxla/iree/issues/13245
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index 36bf8ab..e51d326 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -19,6 +19,7 @@
#include "experimental/cuda2/nccl_dynamic_symbols.h"
#include "experimental/cuda2/nop_executable_cache.h"
#include "experimental/cuda2/pipeline_layout.h"
+#include "experimental/cuda2/tracing.h"
#include "iree/base/internal/arena.h"
#include "iree/base/internal/math.h"
#include "iree/hal/utils/buffer_transfer.h"
@@ -53,6 +54,8 @@
// TODO: support multiple streams.
CUstream cu_stream;
+ iree_hal_cuda2_tracing_context_t* tracing_context;
+
iree_allocator_t host_allocator;
// Device memory pools and allocators.
@@ -82,6 +85,7 @@
memset(out_params, 0, sizeof(*out_params));
out_params->arena_block_size = 32 * 1024;
out_params->queue_count = 1;
+ out_params->stream_tracing = false;
out_params->async_allocations = true;
}
@@ -128,7 +132,13 @@
device->cu_stream = stream;
device->host_allocator = host_allocator;
+ // Enable tracing for the (currently only) stream - no-op if disabled.
iree_status_t status = iree_ok_status();
+ if (device->params.stream_tracing) {
+ status = iree_hal_cuda2_tracing_context_allocate(
+ device->cuda_symbols, device->identifier, stream, &device->block_pool,
+ host_allocator, &device->tracing_context);
+ }
// Memory pool support is conditional.
if (iree_status_is_ok(status) && params->async_allocations) {
@@ -237,6 +247,7 @@
iree_hal_cuda2_memory_pools_deinitialize(&device->memory_pools);
// TODO: support multiple streams.
+ iree_hal_cuda2_tracing_context_free(device->tracing_context);
IREE_CUDA_IGNORE_ERROR(device->cuda_symbols,
cuStreamDestroy(device->cu_stream));