Integrate NCCL (#11481)
This integration enables basic NCCL features in the CUDA runtime. This
enables a minimum test to run. Many more things should be done on top of
this.
To enable a build with NCCL, use `-DIREE_HAL_DRIVER_CUDA_NCCL=ON` for
your cmake command. The same string is used for the macro to guard the C
source code.
Two environmental variables are introduced to set the number of
processes and process ID.
1. `IREE_SPMD_NPROCS`
2. `IREE_SPMD_PROCID`
The GPU ID can be set using `--device=cuda://<index>` or
`--device=cuda://GPU-<uuid>` for `iree-run-module`.
The NCCL dynamic library is loaded when users set `IREE_SPMD_NPROCS` to
>= 1.
There are many things to be done based on this work. We need
1. a full set of E2E tests from stream async ops to the runtime,
2. to decide where to host the modified NCCL source code,
3. to setup a CI test flow and more.
Here is a sample allgather test:
```mlir
func.func @main() -> !hal.buffer_view {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%input_cst = stream.tensor.constant : tensor<2xi32> in !stream.resource<constant> =
dense<[101, 102]> : tensor<2xi32>
%input = stream.async.transfer %input_cst : !stream.resource<constant>{%c8} -> !stream.resource<*>{%c8}
%fill_val = arith.constant -1 : i32
%output = stream.tensor.splat %fill_val :
i32 -> tensor<2x2xi32> in !stream.resource<*>{%c16}
%channel = stream.channel.default on(#hal.affinity.queue<[0]>) : !stream.channel
%0 = stream.async.collective<all_gather : si32>[%c2]
on(#hal.affinity.queue<[0]>) channel(%channel)
%input[%c0 to %c8 for %c8],
%output[%c0 to %c16 for %c16] :
!stream.resource<*>{%c8} -> %output as !stream.resource<*>{%c16}
%1 = stream.async.transfer %0 : !stream.resource<*>{%c16} -> !stream.resource<external>{%c16}
%result = stream.tensor.export %1 :
tensor<2x2xi32> in !stream.resource<external>{%c16} -> !hal.buffer_view
return %result : !hal.buffer_view
}
```
A sample command to build is:
```zsh
iree-compile --iree-hal-cuda-llvm-target-arch=sm_86 --iree-hal-target-backends=cuda -o allgather.vmfb allgather.mlir
```
Here is a sample command line for a host with two CUDA devices and the
result.
```zsh
IREE_SPMD_NPROCS=2 NCCL_COMM_ID=127.0.0.1:8000 IREE_SPMD_PROCID=0 iree-run-module --device=cuda://0 --module_file=allgather.vmfb --entry_function=main & \
IREE_SPMD_NPROCS=2 NCCL_COMM_ID=127.0.0.1:8000 IREE_SPMD_PROCID=1 iree-run-module --device=cuda://1 --module_file=allgather.vmfb --entry_function=main
EXEC @main
EXEC @main
result[0]: hal.buffer_view
2x2xi32=[101 102][101 102]
result[0]: hal.buffer_view
2x2xi32=[101 102][101 102]
```
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bc71b6c..e6cfa37 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -109,7 +109,7 @@
option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL drivers" ON)
# CUDA support must be explicitly enabled.
set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
-
+set(IREE_HAL_DRVIER_CUDA_NCCL_DEFAULT OFF)
# Vulkan is not natively supported on Apple platforms.
# Metal should generally be used instead, though MoltenVK may also work.
if(APPLE)
@@ -119,6 +119,7 @@
endif()
option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
+option(IREE_HAL_DRIVER_CUDA_NCCL "Enables the 'nccl' runtime with CUDA" ${IREE_HAL_DRIVER_CUDA_NCCL_DEFAULT})
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 573c089..7891fe2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -140,7 +140,8 @@
auto bits = IREE::HAL::CommandCategoryBitfield::None;
for (auto &block : region) {
for (auto &op : block) {
- if (isa<IREE::Stream::CmdDispatchOp>(op)) {
+ if (isa<IREE::Stream::CmdDispatchOp>(op) ||
+ isa<IREE::Stream::CmdCollectiveOp>(op)) {
bits = bits | IREE::HAL::CommandCategoryBitfield::Dispatch;
} else {
bits = bits | IREE::HAL::CommandCategoryBitfield::Transfer;
diff --git a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
index 85b2d2b..1422075 100644
--- a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
@@ -10,6 +10,12 @@
iree_add_all_subdirs()
+if(IREE_HAL_DRIVER_CUDA_NCCL)
+ set(IREE_HAL_DRIVER_CUDA_NCCL_VAL 1)
+else()
+ set(IREE_HAL_DRIVER_CUDA_NCCL_VAL 0)
+endif()
+
iree_cc_library(
NAME
cuda
@@ -59,6 +65,8 @@
iree::hal::utils::resource_set
iree::hal::utils::semaphore_base
iree::schemas::cuda_executable_def_c_fbs
+ DEFINES
+ "IREE_HAL_DRIVER_CUDA_NCCL=${IREE_HAL_DRIVER_CUDA_NCCL_VAL}"
PUBLIC
)
@@ -78,6 +86,8 @@
iree::base::core_headers
iree::base::internal::dynamic_library
iree::base::tracing
+ DEFINES
+ "IREE_HAL_DRIVER_CUDA_NCCL=${IREE_HAL_DRIVER_CUDA_NCCL_VAL}"
PUBLIC
)
@@ -94,3 +104,19 @@
LABELS
"driver=cuda"
)
+
+if(IREE_HAL_DRIVER_CUDA_NCCL)
+iree_cc_test(
+ NAME
+ dynamic_symbols_test_for_nccl
+ SRCS
+ "dynamic_symbols_test_for_nccl.cc"
+ DEPS
+ ::dynamic_symbols
+ iree::base
+ iree::testing::gtest
+ iree::testing::gtest_main
+ LABELS
+ "driver=cuda"
+)
+endif()
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c
index 2bd5506..c4d44b3 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c
@@ -146,8 +146,8 @@
// Buffers can only be used on the queue if they are device visible.
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
- if (iree_all_bits_set(params->usage,
- IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
+ if (iree_any_bit_set(params->usage,
+ IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
}
}
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 19a9fc8..98a9839 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -250,16 +250,6 @@
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
- // TODO(#9580): check if nccl symbols are available - if not then we fail
- // here and have the error propagated up to users. If we wanted to delay load
- // NCCL we'd want to take a lock here, load it, and merge the symbols into the
- // dynamic symbol table.
- if (true) {
- return iree_make_status(
- IREE_STATUS_UNIMPLEMENTED,
- "NCCL unavailable and collective operations cannot be performed");
- }
-
// Try to use the ID specified in the parameters and fall back to the default.
iree_hal_cuda_nccl_id_t id;
if (iree_const_byte_span_is_empty(params.id)) {
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
index 4d113d6..21ab104 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
@@ -4,6 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <iree/base/status.h>
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include <nccl.h>
+#endif
#include <stdint.h>
#include <string.h>
@@ -48,6 +52,30 @@
out_options->default_device_index = 0;
}
+#if IREE_HAL_DRIVER_CUDA_NCCL
+
+static iree_status_t iree_hal_nccl_get_unique_id_from_env(
+ iree_hal_cuda_driver_t* driver) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ char* nccl_comm_id_str = getenv("NCCL_COMM_ID");
+ if (!nccl_comm_id_str) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_INTERNAL);
+ }
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, NCCL_RESULT_TO_STATUS(
+ &driver->syms,
+ ncclGetUniqueId(
+ (ncclUniqueId*)&driver->default_params.nccl_default_id),
+ "ncclGetUniqueId"));
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
+
static iree_status_t iree_hal_cuda_driver_create_internal(
iree_string_view_t identifier,
const iree_hal_cuda_device_params_t* default_params,
@@ -69,11 +97,23 @@
iree_status_t status =
iree_hal_cuda_dynamic_symbols_initialize(host_allocator, &driver->syms);
- if (iree_status_is_ok(status)) {
- *out_driver = (iree_hal_driver_t*)driver;
- } else {
+ if (!iree_status_is_ok(status)) {
iree_hal_driver_release((iree_hal_driver_t*)driver);
+ return status;
}
+
+#if IREE_HAL_DRIVER_CUDA_NCCL
+ // Initialize NCCL if NPROCS is set.
+ if (driver->default_params.nccl_default_count > 0) {
+ // get a unique ID from the environmental variable
+ status = iree_hal_nccl_get_unique_id_from_env(driver);
+ if (!iree_status_is_ok(status)) {
+ iree_hal_driver_release((iree_hal_driver_t*)driver);
+ return status;
+ }
+ }
+#endif
+ *out_driver = (iree_hal_driver_t*)driver;
return status;
}
@@ -162,7 +202,7 @@
return iree_ok_status();
}
-// Return true if the device support all the extension required.
+// Return true if the device supports all the extension required.
static bool iree_hal_cuda_is_valid_device(iree_hal_cuda_driver_t* driver,
CUdevice device) {
return true;
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
index e2fd2b9..30633a3 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
@@ -8,5 +8,7 @@
#define IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
#include "cuda.h" // IWYU pragma: export
-
-#endif // IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include "nccl.h" // IWYU pragma: export
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
+#endif // IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
index 11c4f8b..2004d2f 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
@@ -57,3 +57,42 @@
CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, CUstream, void**, void**)
+
+// NCCL
+
+NCCL_PFN_DECL(ncclGetVersion, int *)
+NCCL_PFN_DECL(ncclGetUniqueId, ncclUniqueId *)
+NCCL_PFN_DECL(ncclCommInitRankConfig, ncclComm_t *, int, ncclUniqueId, int,
+ ncclConfig_t *)
+NCCL_PFN_DECL(ncclCommInitRank, ncclComm_t *, int, ncclUniqueId, int)
+NCCL_PFN_DECL(ncclCommInitAll, ncclComm_t *, int, const int *)
+NCCL_PFN_DECL(ncclCommFinalize, ncclComm_t)
+NCCL_PFN_DECL(ncclCommDestroy, ncclComm_t)
+NCCL_PFN_DECL(ncclCommAbort, ncclComm_t)
+NCCL_PFN_DECL_STR_RETURN(ncclGetErrorString, ncclResult_t)
+NCCL_PFN_DECL_STR_RETURN(ncclGetLastError, ncclComm_t)
+NCCL_PFN_DECL(ncclCommGetAsyncError, ncclComm_t, ncclResult_t *)
+NCCL_PFN_DECL(ncclCommCount, const ncclComm_t, int *)
+NCCL_PFN_DECL(ncclCommCuDevice, const ncclComm_t, int *)
+NCCL_PFN_DECL(ncclCommUserRank, const ncclComm_t, int *)
+NCCL_PFN_DECL(ncclRedOpCreatePreMulSum, ncclRedOp_t *, void *, ncclDataType_t,
+ ncclScalarResidence_t, ncclComm_t)
+NCCL_PFN_DECL(ncclRedOpDestroy, ncclRedOp_t, ncclComm_t)
+NCCL_PFN_DECL(ncclReduce, const void *, void *, size_t, ncclDataType_t,
+ ncclRedOp_t, int, ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclBcast, void *, size_t, ncclDataType_t, int, ncclComm_t,
+ cudaStream_t)
+NCCL_PFN_DECL(ncclBroadcast, const void *, void *, size_t, ncclDataType_t, int,
+ ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclAllReduce, const void *, void *, size_t, ncclDataType_t,
+ ncclRedOp_t, ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclReduceScatter, const void *, void *, size_t, ncclDataType_t,
+ ncclRedOp_t, ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclAllGather, const void *, void *, size_t, ncclDataType_t,
+ ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclSend, const void *, size_t, ncclDataType_t, int, ncclComm_t,
+ cudaStream_t)
+NCCL_PFN_DECL(ncclRecv, void *, size_t, ncclDataType_t, int, ncclComm_t,
+ cudaStream_t)
+NCCL_PFN_DECL(ncclGroupStart)
+NCCL_PFN_DECL(ncclGroupEnd)
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
index 9715a46..f199dcf 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
@@ -20,23 +20,52 @@
#endif
};
+#if IREE_HAL_DRIVER_CUDA_NCCL
+static const char* kNCCLLoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+ "nccl.dll",
+#else
+ "libnccl.so",
+#endif
+};
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
+
#define concat(A, B) A B
// Load CUDA entry points, prefer _v2 version if it exists.
static iree_status_t iree_hal_cuda_dynamic_symbols_resolve_all(
iree_hal_cuda_dynamic_symbols_t* syms) {
-#define CU_PFN_DECL(cudaSymbolName, ...) \
- { \
- static const char* kName = #cudaSymbolName; \
- IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
- syms->loader_library, kName, (void**)&syms->cudaSymbolName)); \
- static const char* kNameV2 = concat(#cudaSymbolName, "_v2"); \
- void* funV2; \
- iree_dynamic_library_lookup_symbol(syms->loader_library, kNameV2, &funV2); \
- if (funV2) syms->cudaSymbolName = funV2; \
+#define CU_PFN_DECL(cudaSymbolName, ...) \
+ { \
+ static const char* kName = #cudaSymbolName; \
+ IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
+ syms->cuda_library, kName, (void**)&syms->cudaSymbolName)); \
+ static const char* kNameV2 = concat(#cudaSymbolName, "_v2"); \
+ void* funV2; \
+ iree_dynamic_library_lookup_symbol(syms->cuda_library, kNameV2, &funV2); \
+ if (funV2) syms->cudaSymbolName = funV2; \
}
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#define NCCL_PFN_DECL(ncclSymbolName, ...) \
+ { \
+ static const char* kName = #ncclSymbolName; \
+ IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
+ syms->nccl_library, kName, (void**)&syms->ncclSymbolName)); \
+ }
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...) \
+ { \
+ static const char* kName = #ncclSymbolName; \
+ IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
+ syms->nccl_library, kName, (void**)&syms->ncclSymbolName)); \
+ }
+#else
+#define NCCL_PFN_DECL(ncclSymbolName, ...)
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#endif
#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h" // IWYU pragma: keep
#undef CU_PFN_DECL
+#undef NCCL_PFN_DECL
+#undef NCCL_PFN_DECL_STR_RETURN
return iree_ok_status();
}
@@ -47,14 +76,24 @@
memset(out_syms, 0, sizeof(*out_syms));
iree_status_t status = iree_dynamic_library_load_from_files(
IREE_ARRAYSIZE(kCUDALoaderSearchNames), kCUDALoaderSearchNames,
- IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator,
- &out_syms->loader_library);
+ IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->cuda_library);
if (iree_status_is_not_found(status)) {
iree_status_ignore(status);
return iree_make_status(
IREE_STATUS_UNAVAILABLE,
"CUDA runtime library not available; ensure installed and on path");
}
+#if IREE_HAL_DRIVER_CUDA_NCCL
+ status = iree_dynamic_library_load_from_files(
+ IREE_ARRAYSIZE(kNCCLLoaderSearchNames), kNCCLLoaderSearchNames,
+ IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->nccl_library);
+ if (iree_status_is_not_found(status)) {
+ iree_status_ignore(status);
+ return iree_make_status(
+ IREE_STATUS_UNAVAILABLE,
+ "NCCL runtime library not available; ensure installed and on path");
+ }
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
if (iree_status_is_ok(status)) {
status = iree_hal_cuda_dynamic_symbols_resolve_all(out_syms);
}
@@ -68,7 +107,10 @@
void iree_hal_cuda_dynamic_symbols_deinitialize(
iree_hal_cuda_dynamic_symbols_t* syms) {
IREE_TRACE_ZONE_BEGIN(z0);
- iree_dynamic_library_release(syms->loader_library);
+ iree_dynamic_library_release(syms->cuda_library);
+#if IREE_HAL_DRIVER_CUDA_NCCL
+ iree_dynamic_library_release(syms->nccl_library);
+#endif
memset(syms, 0, sizeof(*syms));
IREE_TRACE_ZONE_END(z0);
}
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
index 44750f7..ef524ff 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
@@ -7,6 +7,9 @@
#ifndef IREE_HAL_DRIVERS_CUDA_DYNAMIC_SYMBOLS_H_
#define IREE_HAL_DRIVERS_CUDA_DYNAMIC_SYMBOLS_H_
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include <nccl.h>
+#endif
#include "iree/base/api.h"
#include "iree/base/internal/dynamic_library.h"
#include "iree/hal/drivers/cuda/cuda_headers.h"
@@ -15,17 +18,29 @@
extern "C" {
#endif // __cplusplus
-// DynamicSymbols allow loading dynamically a subset of CUDA driver API. It
-// loads all the function declared in `dynamic_symbol_tables.def` and fail if
-// any of the symbol is not available. The functions signatures are matching
-// the declarations in `cuda.h`.
+// DynamicSymbols allow loading dynamically a subset of CUDA driver and NCCL
+// API. It loads all the function declared in `dynamic_symbol_tables.h` and fail
+// if any of the symbol is not available. The functions signatures are matching
+// the declarations in `cuda.h` and `nccl.h`.
typedef struct iree_hal_cuda_dynamic_symbols_t {
- iree_dynamic_library_t* loader_library;
+ iree_dynamic_library_t* cuda_library;
+ iree_dynamic_library_t* nccl_library;
#define CU_PFN_DECL(cudaSymbolName, ...) \
CUresult (*cudaSymbolName)(__VA_ARGS__);
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#define NCCL_PFN_DECL(ncclSymbolName, ...) \
+ ncclResult_t (*ncclSymbolName)(__VA_ARGS__);
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...) \
+ const char* (*ncclSymbolName)(__VA_ARGS__);
+#else
+#define NCCL_PFN_DECL(ncclSymbolName, ...)
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#endif
#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h" // IWYU pragma: export
#undef CU_PFN_DECL
+#undef NCCL_PFN_DECL
+#undef NCCL_PFN_DECL_STR_RETURN
} iree_hal_cuda_dynamic_symbols_t;
// Initializes |out_syms| in-place with dynamically loaded CUDA symbols.
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc
index 9c001de..e3b366d 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc
@@ -16,7 +16,7 @@
namespace cuda {
namespace {
-#define CUDE_CHECK_ERRORS(expr) \
+#define CUDA_CHECK_ERRORS(expr) \
{ \
CUresult status = expr; \
ASSERT_EQ(CUDA_SUCCESS, status); \
@@ -34,11 +34,11 @@
}
int device_count = 0;
- CUDE_CHECK_ERRORS(symbols.cuInit(0));
- CUDE_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
+ CUDA_CHECK_ERRORS(symbols.cuInit(0));
+ CUDA_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
if (device_count > 0) {
CUdevice device;
- CUDE_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
+ CUDA_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
}
iree_hal_cuda_dynamic_symbols_deinitialize(&symbols);
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test_for_nccl.cc b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test_for_nccl.cc
new file mode 100644
index 0000000..56be1ab
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test_for_nccl.cc
@@ -0,0 +1,46 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <nccl.h>
+
+#include <iostream>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/cuda/dynamic_symbols.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+namespace {
+
+#define NCCL_CHECK_ERRORS(expr) \
+ { \
+ ncclResult_t status = expr; \
+ ASSERT_EQ(ncclSuccess, status); \
+ }
+
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+ iree_hal_cuda_dynamic_symbols_t symbols;
+ iree_status_t status = iree_hal_cuda_dynamic_symbols_initialize(
+ iree_allocator_system(), &symbols);
+ if (!iree_status_is_ok(status)) {
+ iree_status_fprint(stderr, status);
+ iree_status_ignore(status);
+ std::cerr << "Symbols cannot be loaded, skipping test.";
+ GTEST_FAIL();
+ }
+
+ int nccl_version = 0;
+ NCCL_CHECK_ERRORS(symbols.ncclGetVersion(&nccl_version));
+ ASSERT_EQ(NCCL_VERSION_CODE, nccl_version);
+ iree_hal_cuda_dynamic_symbols_deinitialize(&symbols);
+}
+
+} // namespace
+} // namespace cuda
+} // namespace hal
+} // namespace iree
diff --git a/runtime/src/iree/hal/drivers/cuda/nccl_channel.c b/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
index 186064e..3255f32 100644
--- a/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
+++ b/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
@@ -6,10 +6,21 @@
#include "iree/hal/drivers/cuda/nccl_channel.h"
+#include <iree/base/config.h>
+#include <iree/base/status.h>
+#include <iree/hal/command_buffer.h>
+#include <iree/hal/utils/collective_batch.h>
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include <nccl.h>
+#endif
#include <stddef.h>
#include "iree/base/api.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/cuda/cuda_buffer.h"
+#include "iree/hal/drivers/cuda/status_util.h"
+
+#if IREE_HAL_DRIVER_CUDA_NCCL
// Returns the same value as NCCL's init.cc hashUniqueId.
// These magic constants were chosen by their implementation and unlikely to
@@ -67,24 +78,21 @@
IREE_TRACE_ZONE_APPEND_VALUE(z0, rank);
IREE_TRACE_ZONE_APPEND_VALUE(z0, count);
- // TODO(#9580): actually use nccl to create a communicator.
- // Something like:
- // ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
- // config.blocking = 0;
- // syms->ncclCommInitRankConfig(&comm, count, *id, rank, &config);
- // NOTE: CHECK ERRORS! we can safely return here as we haven't allocated the
- // channel wrapper yet.
ncclComm_t comm = NULL;
- if (!comm) {
+ ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
+ config.blocking = 1; // FIXME: use async to check a timeout
+ iree_status_t status = NCCL_RESULT_TO_STATUS(
+ context_wrapper->syms,
+ ncclCommInitRankConfig(&comm, count, *((const ncclUniqueId*)id), rank,
+ &config));
+ if (!iree_status_is_ok(status)) {
IREE_TRACE_ZONE_END(z0);
- return iree_make_status(
- IREE_STATUS_INTERNAL,
- "failed to create NCCL communicator for rank=%d count=%d", rank, count);
+ return status;
}
iree_hal_cuda_nccl_channel_t* channel = NULL;
- iree_status_t status = iree_allocator_malloc(
- context_wrapper->host_allocator, sizeof(*channel), (void**)&channel);
+ status = iree_allocator_malloc(context_wrapper->host_allocator,
+ sizeof(*channel), (void**)&channel);
if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_cuda_nccl_channel_vtable,
&channel->resource);
@@ -110,7 +118,7 @@
IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->rank);
IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->count);
- // TODO(#9580): tear down nccl - blocking if needed.
+ // TODO(#9580): support async tear down
// We could be smarter about starting finalization of all channels async and
// then waiting for them to complete but we aren't currently optimizing for
// lifetime performance. To do that we'd probably want to track each open
@@ -122,9 +130,11 @@
// syms->ncclCommDestroy(channel->comm)
// Should work the same (as we are doing a blocking teardown):
// syms->ncclCommDestroy(channel->comm)
-
+ NCCL_IGNORE_ERROR(channel->context_wrapper->syms,
+ ncclCommFinalize(channel->comm));
+ NCCL_IGNORE_ERROR(channel->context_wrapper->syms,
+ ncclCommDestroy(channel->comm));
iree_allocator_free(host_allocator, channel);
-
IREE_TRACE_ZONE_END(z0);
}
@@ -140,36 +150,252 @@
*out_count = channel->count;
}
-ncclComm_t iree_hal_cuda_nccl_channel_comm(iree_hal_channel_t* base_channel) {
+// Returns the NCCL communicator for the given |channel|, if available.
+static ncclComm_t iree_hal_cuda_nccl_channel_comm(
+ iree_hal_channel_t* base_channel) {
IREE_ASSERT_ARGUMENT(base_channel);
iree_hal_cuda_nccl_channel_t* channel =
iree_hal_cuda_nccl_channel_cast(base_channel);
return channel->comm;
}
+static iree_status_t get_nccl_data_type(iree_hal_collective_element_type_t in,
+ ncclDataType_t* out) {
+ switch (in) {
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8:
+ *out = ncclInt8;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8:
+ *out = ncclUint8;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16:
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "SINT16 is not supported for collective op");
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16:
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "UINT16 is not supported for collective op");
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32:
+ *out = ncclInt32;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32:
+ *out = ncclUint32;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64:
+ *out = ncclInt64;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64:
+ *out = ncclUint64;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16:
+ *out = ncclFloat16;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32:
+ *out = ncclFloat32;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64:
+ *out = ncclFloat64;
+ break;
+ case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16:
+ *out = ncclFloat64;
+ break;
+ }
+ return iree_ok_status();
+}
+
+static iree_status_t get_nccl_red_type(iree_hal_collective_reduction_t in,
+ ncclRedOp_t* out) {
+ switch (in) {
+ case IREE_HAL_COLLECTIVE_REDUCTION_SUM:
+ *out = ncclSum;
+ break;
+ case IREE_HAL_COLLECTIVE_REDUCTION_PRODUCT:
+ *out = ncclProd;
+ break;
+ case IREE_HAL_COLLECTIVE_REDUCTION_MINIMUM:
+ *out = ncclMin;
+ break;
+ case IREE_HAL_COLLECTIVE_REDUCTION_MAXIMUM:
+ *out = ncclMax;
+ break;
+ case IREE_HAL_COLLECTIVE_REDUCTION_AVERAGE:
+ *out = ncclAvg;
+ break;
+ }
+
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_nccl_submit_batch_entry(
+ const iree_hal_collective_batch_entry_t* entry, CUstream stream) {
+ IREE_ASSERT_ARGUMENT(entry);
+ IREE_ASSERT_ARGUMENT(stream);
+
+ iree_hal_cuda_nccl_channel_t* channel =
+ iree_hal_cuda_nccl_channel_cast(entry->channel);
+ iree_hal_cuda_dynamic_symbols_t* syms = channel->context_wrapper->syms;
+ ncclComm_t comm = iree_hal_cuda_nccl_channel_comm(entry->channel);
+ ncclDataType_t datatype;
+ IREE_RETURN_IF_ERROR(get_nccl_data_type(entry->op.element_type, &datatype));
+
+ switch (entry->op.kind) {
+ case IREE_HAL_COLLECTIVE_KIND_ALL_GATHER: {
+ CUdeviceptr sendbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+ entry->send_binding.offset;
+ CUdeviceptr recvbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+ entry->recv_binding.offset;
+ NCCL_RETURN_IF_ERROR(
+ syms,
+ ncclAllGather((const void*)sendbuff, (void*)recvbuff,
+ entry->element_count, datatype, comm, stream),
+ "ncclAllGather");
+ break;
+ }
+ case IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE: {
+ CUdeviceptr sendbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+ entry->send_binding.offset;
+ CUdeviceptr recvbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+ entry->recv_binding.offset;
+ ncclRedOp_t redop;
+ IREE_RETURN_IF_ERROR(get_nccl_red_type(entry->op.reduction, &redop));
+ NCCL_RETURN_IF_ERROR(
+ syms,
+ ncclAllReduce((const void*)sendbuff, (void*)recvbuff,
+ entry->element_count, datatype, redop, comm, stream),
+ "ncclAllReduce");
+ break;
+ }
+ case IREE_HAL_COLLECTIVE_KIND_BROADCAST: {
+ CUdeviceptr sendbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+ entry->send_binding.offset;
+ CUdeviceptr recvbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+ entry->recv_binding.offset;
+ NCCL_RETURN_IF_ERROR(syms,
+ ncclBroadcast((const void*)sendbuff, (void*)recvbuff,
+ entry->element_count, datatype,
+ entry->param, comm, stream),
+ "ncclBroadcast");
+ break;
+ }
+ case IREE_HAL_COLLECTIVE_KIND_REDUCE: {
+ CUdeviceptr sendbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+ entry->send_binding.offset;
+ CUdeviceptr recvbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+ entry->recv_binding.offset;
+ ncclRedOp_t redop;
+ IREE_RETURN_IF_ERROR(get_nccl_red_type(entry->op.reduction, &redop));
+ NCCL_RETURN_IF_ERROR(syms,
+ ncclReduce((const void*)sendbuff, (void*)recvbuff,
+ entry->element_count, datatype, redop,
+ entry->param, comm, stream),
+ "ncclReduce");
+ break;
+ }
+ case IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER: {
+ CUdeviceptr sendbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+ entry->send_binding.offset;
+ CUdeviceptr recvbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+ entry->recv_binding.offset;
+ ncclRedOp_t redop;
+ IREE_RETURN_IF_ERROR(get_nccl_red_type(entry->op.reduction, &redop));
+ NCCL_RETURN_IF_ERROR(
+ syms,
+ ncclReduceScatter((const void*)sendbuff, (void*)recvbuff,
+ entry->element_count, datatype, redop, comm,
+ stream),
+ "ncclReduceScatter");
+ break;
+ }
+ case IREE_HAL_COLLECTIVE_KIND_SEND: {
+ CUdeviceptr sendbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+ entry->send_binding.offset;
+ NCCL_RETURN_IF_ERROR(syms,
+ ncclSend((const void*)sendbuff, entry->element_count,
+ datatype, entry->param, comm, stream),
+ "ncclSend");
+ break;
+ }
+ case IREE_HAL_COLLECTIVE_KIND_RECV: {
+ CUdeviceptr recvbuff =
+ iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+ iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+ entry->recv_binding.offset;
+ NCCL_RETURN_IF_ERROR(syms,
+ ncclRecv((void*)recvbuff, entry->element_count,
+ datatype, entry->param, comm, stream),
+ "ncclRecv");
+ break;
+ }
+ } // switch
+ return iree_ok_status();
+}
+
iree_status_t iree_hal_cuda_nccl_submit_batch(
iree_hal_cuda_context_wrapper_t* context,
const iree_hal_collective_batch_t* batch, CUstream stream) {
IREE_ASSERT_ARGUMENT(context);
IREE_ASSERT_ARGUMENT(batch);
IREE_ASSERT_ARGUMENT(stream);
-
- // TODO(#9580): issue the operations in the batch. Note that the channel may
- // change between ops and the communicator should be retrieved from each.
- //
- // Something like:
- // make context->cu_context active (for when using multiple devices)
- // syms->ncclGroupStart();
- // for each entry in batch:
- // ncclComm_t comm = iree_hal_cuda_nccl_channel_comm(entry->channel);
- // syms->nccl*(comm, ...);
- // syms->ncclGroupEnd();
-
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "NCCL submission not yet implemented");
+ NCCL_RETURN_IF_ERROR(context->syms, ncclGroupStart(), "ncclGroupStart");
+ for (IREE_HOST_SIZE_T i = 0; i < batch->count; ++i) {
+ iree_hal_cuda_nccl_submit_batch_entry(&batch->entries[i], stream);
+ }
+ return NCCL_RESULT_TO_STATUS(context->syms, ncclGroupEnd(), "ncclGroupEnd");
}
static const iree_hal_channel_vtable_t iree_hal_cuda_nccl_channel_vtable = {
.destroy = iree_hal_cuda_nccl_channel_destroy,
.query_rank_and_count = iree_hal_cuda_nccl_channel_query_rank_and_count,
};
+
+#else // IREE_HAL_DRIVER_CUDA_NCCL
+
+iree_status_t iree_hal_cuda_nccl_channel_create(
+ iree_hal_cuda_context_wrapper_t* context_wrapper,
+ const iree_hal_cuda_nccl_id_t* id, int rank, int count,
+ iree_hal_channel_t** out_channel) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "iree_hal_cuda_nccl_channel_create()");
+}
+
+iree_status_t iree_hal_cuda_nccl_submit_batch(
+ iree_hal_cuda_context_wrapper_t* context,
+ const iree_hal_collective_batch_t* batch, CUstream stream) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "iree_hal_cuda_nccl_submit_batch()");
+}
+
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
diff --git a/runtime/src/iree/hal/drivers/cuda/nccl_channel.h b/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
index 51ffd54..8e65e59 100644
--- a/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
+++ b/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
@@ -18,17 +18,11 @@
extern "C" {
#endif // __cplusplus
-// Creates a new NCCL communicator channel.
-typedef struct ncclComm* ncclComm_t;
-
iree_status_t iree_hal_cuda_nccl_channel_create(
iree_hal_cuda_context_wrapper_t* context_wrapper,
const iree_hal_cuda_nccl_id_t* id, int rank, int count,
iree_hal_channel_t** out_channel);
-// Returns the NCCL communicator for the given |channel|, if available.
-ncclComm_t iree_hal_cuda_nccl_channel_comm(iree_hal_channel_t* channel);
-
// Performs a non-blocking submission of |batch| to |stream|.
// The backing storage of |batch| is dropped immediately but all resources
// referenced will be retained by the parent command buffer for its lifetime.
diff --git a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
index f609b6d..ab4164f 100644
--- a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
+++ b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
@@ -39,6 +39,34 @@
return iree_ok_status();
}
+#if IREE_HAL_DRIVER_CUDA_NCCL
+static iree_status_t iree_hal_cuda_init_nccl_rank_and_count(
+ iree_hal_cuda_device_params_t* params) {
+ char* nprocs_str = getenv("IREE_SPMD_NPROCS");
+ if (!nprocs_str) {
+ params->nccl_default_count = 0;
+ params->nccl_default_rank = 0;
+ return iree_status_from_code(IREE_STATUS_OK);
+ }
+
+ int nprocs = atoi(nprocs_str);
+ if (nprocs <= 0) return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
+ params->nccl_default_count = nprocs;
+
+ char* procid_str = getenv("IREE_SPMD_PROCID");
+ if (!procid_str) {
+ // Expected PROCID when NPROCS is set.
+ return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+ }
+ int procid = atoi(procid_str);
+ if (procid < 0 || procid >= nprocs) {
+ return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
+ }
+ params->nccl_default_rank = procid;
+ return iree_status_from_code(IREE_STATUS_OK);
+}
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
+
static iree_status_t iree_hal_cuda_driver_factory_try_create(
void* self, iree_string_view_t driver_name, iree_allocator_t host_allocator,
iree_hal_driver_t** out_driver) {
@@ -58,6 +86,12 @@
IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;
}
default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
+#if IREE_HAL_DRIVER_CUDA_NCCL
+ iree_hal_cuda_init_nccl_rank_and_count(&default_params);
+
+ // Note that nccl_default_id can't be initalized until the driver imports
+ // the NCCL symbols from the dynamic library.
+#endif
iree_hal_cuda_driver_options_t driver_options;
iree_hal_cuda_driver_options_initialize(&driver_options);
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.c b/runtime/src/iree/hal/drivers/cuda/status_util.c
index 7881d2c..14790c1 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.c
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.c
@@ -30,3 +30,18 @@
"CUDA driver error '%s' (%d): %s",
error_name, result, error_string);
}
+
+#if IREE_HAL_DRIVER_CUDA_NCCL
+iree_status_t iree_hal_nccl_result_to_status(
+ iree_hal_cuda_dynamic_symbols_t* syms, ncclResult_t result,
+ const char* file, uint32_t line) {
+ if (IREE_LIKELY(result == ncclSuccess)) {
+ return iree_ok_status();
+ }
+
+ const char* error_string = syms->ncclGetErrorString(result);
+ return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL,
+ "NCCL error %d: %s", result,
+ error_string);
+}
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.h b/runtime/src/iree/hal/drivers/cuda/status_util.h
index 9cb9d74..0d8409a 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.h
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.h
@@ -47,6 +47,40 @@
iree_hal_cuda_dynamic_symbols_t* syms, CUresult result, const char* file,
uint32_t line);
+#if IREE_HAL_DRIVER_CUDA_NCCL
+// Converts a ncclResult_t to an iree_status_t.
+//
+// Usage:
+// iree_status_t status = NCCL_RESULT_TO_STATUS(ncclDoThing(...));
+#define NCCL_RESULT_TO_STATUS(syms, expr, ...) \
+ iree_hal_nccl_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// Converts a ncclResult_t to a Status object.
+iree_status_t iree_hal_nccl_result_to_status(
+ iree_hal_cuda_dynamic_symbols_t* syms, ncclResult_t result,
+ const char* file, uint32_t line);
+
+// IREE_RETURN_IF_ERROR but implicitly converts the ncclResult_t return value to
+// a Status.
+//
+// Usage:
+// NCCL_RETURN_IF_ERROR(ncclDoThing(...), "message");
+#define NCCL_RETURN_IF_ERROR(syms, expr, ...) \
+ IREE_RETURN_IF_ERROR(iree_hal_nccl_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__), \
+ __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the ncclResult_t return value to a
+// Status.
+//
+// Usage:
+// NCCL_IGNORE_ERROR(ncclDoThing(...));
+#define NCCL_IGNORE_ERROR(syms, expr) \
+ IREE_IGNORE_ERROR(iree_hal_nccl_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__))
+
+#endif // IREE_HAL_DRIVER_CUDA_NCCL
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 8386b18..ec8d411 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -736,7 +736,7 @@
.length = iree_hal_cast_device_size(args->i9),
};
IREE_RETURN_IF_ERROR(
- iree_hal_buffer_check_deref_or_null(args->r7, &send_binding.buffer));
+ iree_hal_buffer_check_deref_or_null(args->r7, &recv_binding.buffer));
iree_device_size_t element_count = iree_hal_cast_device_size(args->i10);
return iree_hal_command_buffer_collective(command_buffer, channel, op, param,
send_binding, recv_binding,