Integrate NCCL (#11481)

This integration enables basic NCCL features in the CUDA runtime, enough
for a minimal test to run. Much more remains to be done on top of this
work.

To enable a build with NCCL, pass `-DIREE_HAL_DRIVER_CUDA_NCCL=ON` to
your cmake command. The same name is used as the preprocessor macro that
guards the NCCL-specific C source code.

Two environment variables are introduced to set the number of
processes and the process ID:
1. `IREE_SPMD_NPROCS`
2. `IREE_SPMD_PROCID`

The GPU ID can be set using `--device=cuda://<index>` or
`--device=cuda://GPU-<uuid>` for `iree-run-module`.

The NCCL dynamic library is loaded when users set `IREE_SPMD_NPROCS`
to a value >= 1.

There are many things to be done based on this work. We need to:

1. add a full set of E2E tests from stream async ops to the runtime,
2. decide where to host the modified NCCL source code,
3. set up a CI test flow, and more.

Here is a sample allgather test:

```mlir
func.func @main() -> !hal.buffer_view {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %input_cst = stream.tensor.constant : tensor<2xi32> in !stream.resource<constant> =
    dense<[101, 102]> : tensor<2xi32>
  %input = stream.async.transfer %input_cst : !stream.resource<constant>{%c8} -> !stream.resource<*>{%c8}
  %fill_val = arith.constant -1 : i32
  %output = stream.tensor.splat %fill_val :
    i32 -> tensor<2x2xi32> in !stream.resource<*>{%c16}
  %channel = stream.channel.default on(#hal.affinity.queue<[0]>) : !stream.channel

  %0 = stream.async.collective<all_gather : si32>[%c2]
      on(#hal.affinity.queue<[0]>) channel(%channel)
      %input[%c0 to %c8 for %c8],
      %output[%c0 to %c16 for %c16] :
      !stream.resource<*>{%c8} -> %output as !stream.resource<*>{%c16}
  %1 = stream.async.transfer %0 : !stream.resource<*>{%c16} -> !stream.resource<external>{%c16}
  %result = stream.tensor.export %1 :
    tensor<2x2xi32> in !stream.resource<external>{%c16} -> !hal.buffer_view
  return %result : !hal.buffer_view
}
```

A sample command to build is:
```zsh
iree-compile --iree-hal-cuda-llvm-target-arch=sm_86 --iree-hal-target-backends=cuda -o allgather.vmfb allgather.mlir
```

Here is a sample command line for a host with two CUDA devices and the
result.
```zsh
IREE_SPMD_NPROCS=2 NCCL_COMM_ID=127.0.0.1:8000 IREE_SPMD_PROCID=0 iree-run-module --device=cuda://0 --module_file=allgather.vmfb --entry_function=main & \
IREE_SPMD_NPROCS=2 NCCL_COMM_ID=127.0.0.1:8000 IREE_SPMD_PROCID=1 iree-run-module --device=cuda://1 --module_file=allgather.vmfb --entry_function=main
EXEC @main
EXEC @main
result[0]: hal.buffer_view
2x2xi32=[101 102][101 102]
result[0]: hal.buffer_view
2x2xi32=[101 102][101 102]
```
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bc71b6c..e6cfa37 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -109,7 +109,7 @@
 option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL drivers" ON)
 # CUDA support must be explicitly enabled.
 set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
-
+set(IREE_HAL_DRIVER_CUDA_NCCL_DEFAULT OFF)
 # Vulkan is not natively supported on Apple platforms.
 # Metal should generally be used instead, though MoltenVK may also work.
 if(APPLE)
@@ -119,6 +119,7 @@
 endif()
 
 option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
+option(IREE_HAL_DRIVER_CUDA_NCCL "Enables the 'nccl' runtime with CUDA" ${IREE_HAL_DRIVER_CUDA_NCCL_DEFAULT})
 option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
 option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
 option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 573c089..7891fe2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -140,7 +140,8 @@
   auto bits = IREE::HAL::CommandCategoryBitfield::None;
   for (auto &block : region) {
     for (auto &op : block) {
-      if (isa<IREE::Stream::CmdDispatchOp>(op)) {
+      if (isa<IREE::Stream::CmdDispatchOp>(op) ||
+          isa<IREE::Stream::CmdCollectiveOp>(op)) {
         bits = bits | IREE::HAL::CommandCategoryBitfield::Dispatch;
       } else {
         bits = bits | IREE::HAL::CommandCategoryBitfield::Transfer;
diff --git a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
index 85b2d2b..1422075 100644
--- a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
@@ -10,6 +10,12 @@
 
 iree_add_all_subdirs()
 
+if(IREE_HAL_DRIVER_CUDA_NCCL)
+  set(IREE_HAL_DRIVER_CUDA_NCCL_VAL 1)
+else()
+  set(IREE_HAL_DRIVER_CUDA_NCCL_VAL 0)
+endif()
+
 iree_cc_library(
   NAME
     cuda
@@ -59,6 +65,8 @@
     iree::hal::utils::resource_set
     iree::hal::utils::semaphore_base
     iree::schemas::cuda_executable_def_c_fbs
+  DEFINES
+    "IREE_HAL_DRIVER_CUDA_NCCL=${IREE_HAL_DRIVER_CUDA_NCCL_VAL}"
   PUBLIC
 )
 
@@ -78,6 +86,8 @@
     iree::base::core_headers
     iree::base::internal::dynamic_library
     iree::base::tracing
+  DEFINES
+    "IREE_HAL_DRIVER_CUDA_NCCL=${IREE_HAL_DRIVER_CUDA_NCCL_VAL}"
   PUBLIC
 )
 
@@ -94,3 +104,19 @@
   LABELS
     "driver=cuda"
 )
+
+if(IREE_HAL_DRIVER_CUDA_NCCL)
+iree_cc_test(
+  NAME
+    dynamic_symbols_test_for_nccl
+  SRCS
+    "dynamic_symbols_test_for_nccl.cc"
+  DEPS
+    ::dynamic_symbols
+    iree::base
+    iree::testing::gtest
+    iree::testing::gtest_main
+  LABELS
+    "driver=cuda"
+)
+endif()
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c
index 2bd5506..c4d44b3 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c
@@ -146,8 +146,8 @@
 
   // Buffers can only be used on the queue if they are device visible.
   if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
-    if (iree_all_bits_set(params->usage,
-                          IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
+    if (iree_any_bit_set(params->usage,
+                         IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
       compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
     }
   }
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 19a9fc8..98a9839 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -250,16 +250,6 @@
     iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
 
-  // TODO(#9580): check if nccl symbols are available - if not then we fail
-  // here and have the error propagated up to users. If we wanted to delay load
-  // NCCL we'd want to take a lock here, load it, and merge the symbols into the
-  // dynamic symbol table.
-  if (true) {
-    return iree_make_status(
-        IREE_STATUS_UNIMPLEMENTED,
-        "NCCL unavailable and collective operations cannot be performed");
-  }
-
   // Try to use the ID specified in the parameters and fall back to the default.
   iree_hal_cuda_nccl_id_t id;
   if (iree_const_byte_span_is_empty(params.id)) {
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
index 4d113d6..21ab104 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
@@ -4,6 +4,10 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include <iree/base/status.h>
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include <nccl.h>
+#endif
 #include <stdint.h>
 #include <string.h>
 
@@ -48,6 +52,30 @@
   out_options->default_device_index = 0;
 }
 
+#if IREE_HAL_DRIVER_CUDA_NCCL
+
+// Fills the driver's default NCCL ID via ncclGetUniqueId; needs NCCL_COMM_ID.
+static iree_status_t iree_hal_nccl_get_unique_id_from_env(
+    iree_hal_cuda_driver_t* driver) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  char* nccl_comm_id_str = getenv("NCCL_COMM_ID");
+  if (!nccl_comm_id_str) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "NCCL_COMM_ID must be set to initialize NCCL");
+  }
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, NCCL_RESULT_TO_STATUS(
+              &driver->syms,
+              ncclGetUniqueId(
+                  (ncclUniqueId*)&driver->default_params.nccl_default_id),
+              "ncclGetUniqueId"));
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
+
 static iree_status_t iree_hal_cuda_driver_create_internal(
     iree_string_view_t identifier,
     const iree_hal_cuda_device_params_t* default_params,
@@ -69,11 +97,23 @@
 
   iree_status_t status =
       iree_hal_cuda_dynamic_symbols_initialize(host_allocator, &driver->syms);
-  if (iree_status_is_ok(status)) {
-    *out_driver = (iree_hal_driver_t*)driver;
-  } else {
+  if (!iree_status_is_ok(status)) {
     iree_hal_driver_release((iree_hal_driver_t*)driver);
+    return status;
   }
+
+#if IREE_HAL_DRIVER_CUDA_NCCL
+  // Initialize NCCL if NPROCS is set.
+  if (driver->default_params.nccl_default_count > 0) {
+    // get a unique ID from the environmental variable
+    status = iree_hal_nccl_get_unique_id_from_env(driver);
+    if (!iree_status_is_ok(status)) {
+      iree_hal_driver_release((iree_hal_driver_t*)driver);
+      return status;
+    }
+  }
+#endif
+  *out_driver = (iree_hal_driver_t*)driver;
   return status;
 }
 
@@ -162,7 +202,7 @@
   return iree_ok_status();
 }
 
-// Return true if the device support all the extension required.
+// Returns true if the device supports all the required extensions.
 static bool iree_hal_cuda_is_valid_device(iree_hal_cuda_driver_t* driver,
                                           CUdevice device) {
   return true;
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
index e2fd2b9..30633a3 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
@@ -8,5 +8,7 @@
 #define IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
 
 #include "cuda.h"  // IWYU pragma: export
-
-#endif  // IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include "nccl.h"  // IWYU pragma: export
+#endif             // IREE_HAL_DRIVER_CUDA_NCCL
+#endif             // IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
index 11c4f8b..2004d2f 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
@@ -57,3 +57,42 @@
 CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int,
             unsigned int, unsigned int, unsigned int, unsigned int,
             unsigned int, CUstream, void**, void**)
+
+// NCCL
+
+NCCL_PFN_DECL(ncclGetVersion, int *)
+NCCL_PFN_DECL(ncclGetUniqueId, ncclUniqueId *)
+NCCL_PFN_DECL(ncclCommInitRankConfig, ncclComm_t *, int, ncclUniqueId, int,
+              ncclConfig_t *)
+NCCL_PFN_DECL(ncclCommInitRank, ncclComm_t *, int, ncclUniqueId, int)
+NCCL_PFN_DECL(ncclCommInitAll, ncclComm_t *, int, const int *)
+NCCL_PFN_DECL(ncclCommFinalize, ncclComm_t)
+NCCL_PFN_DECL(ncclCommDestroy, ncclComm_t)
+NCCL_PFN_DECL(ncclCommAbort, ncclComm_t)
+NCCL_PFN_DECL_STR_RETURN(ncclGetErrorString, ncclResult_t)
+NCCL_PFN_DECL_STR_RETURN(ncclGetLastError, ncclComm_t)
+NCCL_PFN_DECL(ncclCommGetAsyncError, ncclComm_t, ncclResult_t *)
+NCCL_PFN_DECL(ncclCommCount, const ncclComm_t, int *)
+NCCL_PFN_DECL(ncclCommCuDevice, const ncclComm_t, int *)
+NCCL_PFN_DECL(ncclCommUserRank, const ncclComm_t, int *)
+NCCL_PFN_DECL(ncclRedOpCreatePreMulSum, ncclRedOp_t *, void *, ncclDataType_t,
+              ncclScalarResidence_t, ncclComm_t)
+NCCL_PFN_DECL(ncclRedOpDestroy, ncclRedOp_t, ncclComm_t)
+NCCL_PFN_DECL(ncclReduce, const void *, void *, size_t, ncclDataType_t,
+              ncclRedOp_t, int, ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclBcast, void *, size_t, ncclDataType_t, int, ncclComm_t,
+              cudaStream_t)
+NCCL_PFN_DECL(ncclBroadcast, const void *, void *, size_t, ncclDataType_t, int,
+              ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclAllReduce, const void *, void *, size_t, ncclDataType_t,
+              ncclRedOp_t, ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclReduceScatter, const void *, void *, size_t, ncclDataType_t,
+              ncclRedOp_t, ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclAllGather, const void *, void *, size_t, ncclDataType_t,
+              ncclComm_t, cudaStream_t)
+NCCL_PFN_DECL(ncclSend, const void *, size_t, ncclDataType_t, int, ncclComm_t,
+              cudaStream_t)
+NCCL_PFN_DECL(ncclRecv, void *, size_t, ncclDataType_t, int, ncclComm_t,
+              cudaStream_t)
+NCCL_PFN_DECL(ncclGroupStart)
+NCCL_PFN_DECL(ncclGroupEnd)
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
index 9715a46..f199dcf 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
@@ -20,23 +20,52 @@
 #endif
 };
 
+#if IREE_HAL_DRIVER_CUDA_NCCL
+static const char* kNCCLLoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "nccl.dll",
+#else
+    "libnccl.so",
+#endif
+};
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
+
 #define concat(A, B) A B
 
 // Load CUDA entry points, prefer _v2 version if it exists.
 static iree_status_t iree_hal_cuda_dynamic_symbols_resolve_all(
     iree_hal_cuda_dynamic_symbols_t* syms) {
-#define CU_PFN_DECL(cudaSymbolName, ...)                                       \
-  {                                                                            \
-    static const char* kName = #cudaSymbolName;                                \
-    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(                   \
-        syms->loader_library, kName, (void**)&syms->cudaSymbolName));          \
-    static const char* kNameV2 = concat(#cudaSymbolName, "_v2");               \
-    void* funV2;                                                               \
-    iree_dynamic_library_lookup_symbol(syms->loader_library, kNameV2, &funV2); \
-    if (funV2) syms->cudaSymbolName = funV2;                                   \
+#define CU_PFN_DECL(cudaSymbolName, ...)                                     \
+  {                                                                          \
+    static const char* kName = #cudaSymbolName;                              \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(                 \
+        syms->cuda_library, kName, (void**)&syms->cudaSymbolName));          \
+    static const char* kNameV2 = concat(#cudaSymbolName, "_v2");             \
+    void* funV2;                                                             \
+    iree_dynamic_library_lookup_symbol(syms->cuda_library, kNameV2, &funV2); \
+    if (funV2) syms->cudaSymbolName = funV2;                                 \
   }
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#define NCCL_PFN_DECL(ncclSymbolName, ...)                          \
+  {                                                                 \
+    static const char* kName = #ncclSymbolName;                     \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(        \
+        syms->nccl_library, kName, (void**)&syms->ncclSymbolName)); \
+  }
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)               \
+  {                                                                 \
+    static const char* kName = #ncclSymbolName;                     \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(        \
+        syms->nccl_library, kName, (void**)&syms->ncclSymbolName)); \
+  }
+#else
+#define NCCL_PFN_DECL(ncclSymbolName, ...)
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#endif
 #include "iree/hal/drivers/cuda/dynamic_symbol_tables.h"  // IWYU pragma: keep
 #undef CU_PFN_DECL
+#undef NCCL_PFN_DECL
+#undef NCCL_PFN_DECL_STR_RETURN
   return iree_ok_status();
 }
 
@@ -47,14 +76,24 @@
   memset(out_syms, 0, sizeof(*out_syms));
   iree_status_t status = iree_dynamic_library_load_from_files(
       IREE_ARRAYSIZE(kCUDALoaderSearchNames), kCUDALoaderSearchNames,
-      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator,
-      &out_syms->loader_library);
+      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->cuda_library);
   if (iree_status_is_not_found(status)) {
     iree_status_ignore(status);
     return iree_make_status(
         IREE_STATUS_UNAVAILABLE,
         "CUDA runtime library not available; ensure installed and on path");
   }
+#if IREE_HAL_DRIVER_CUDA_NCCL
+  status = iree_dynamic_library_load_from_files(
+      IREE_ARRAYSIZE(kNCCLLoaderSearchNames), kNCCLLoaderSearchNames,
+      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->nccl_library);
+  if (iree_status_is_not_found(status)) {
+    iree_status_ignore(status);
+    return iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "NCCL runtime library not available; ensure installed and on path");
+  }
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
   if (iree_status_is_ok(status)) {
     status = iree_hal_cuda_dynamic_symbols_resolve_all(out_syms);
   }
@@ -68,7 +107,10 @@
 void iree_hal_cuda_dynamic_symbols_deinitialize(
     iree_hal_cuda_dynamic_symbols_t* syms) {
   IREE_TRACE_ZONE_BEGIN(z0);
-  iree_dynamic_library_release(syms->loader_library);
+  iree_dynamic_library_release(syms->cuda_library);
+#if IREE_HAL_DRIVER_CUDA_NCCL
+  iree_dynamic_library_release(syms->nccl_library);
+#endif
   memset(syms, 0, sizeof(*syms));
   IREE_TRACE_ZONE_END(z0);
 }
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
index 44750f7..ef524ff 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
@@ -7,6 +7,9 @@
 #ifndef IREE_HAL_DRIVERS_CUDA_DYNAMIC_SYMBOLS_H_
 #define IREE_HAL_DRIVERS_CUDA_DYNAMIC_SYMBOLS_H_
 
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include <nccl.h>
+#endif
 #include "iree/base/api.h"
 #include "iree/base/internal/dynamic_library.h"
 #include "iree/hal/drivers/cuda/cuda_headers.h"
@@ -15,17 +18,29 @@
 extern "C" {
 #endif  // __cplusplus
 
-// DynamicSymbols allow loading dynamically a subset of CUDA driver API. It
-// loads all the function declared in `dynamic_symbol_tables.def` and fail if
-// any of the symbol is not available. The functions signatures are matching
-// the declarations in `cuda.h`.
+// DynamicSymbols allows dynamically loading a subset of the CUDA driver and
+// NCCL APIs. It loads all the functions declared in `dynamic_symbol_tables.h`
+// and fails if any symbol is not available. The function signatures match the
+// declarations in `cuda.h` and `nccl.h`.
 typedef struct iree_hal_cuda_dynamic_symbols_t {
-  iree_dynamic_library_t* loader_library;
+  iree_dynamic_library_t* cuda_library;
+  iree_dynamic_library_t* nccl_library;
 
 #define CU_PFN_DECL(cudaSymbolName, ...) \
   CUresult (*cudaSymbolName)(__VA_ARGS__);
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#define NCCL_PFN_DECL(ncclSymbolName, ...) \
+  ncclResult_t (*ncclSymbolName)(__VA_ARGS__);
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...) \
+  const char* (*ncclSymbolName)(__VA_ARGS__);
+#else
+#define NCCL_PFN_DECL(ncclSymbolName, ...)
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#endif
 #include "iree/hal/drivers/cuda/dynamic_symbol_tables.h"  // IWYU pragma: export
 #undef CU_PFN_DECL
+#undef NCCL_PFN_DECL
+#undef NCCL_PFN_DECL_STR_RETURN
 } iree_hal_cuda_dynamic_symbols_t;
 
 // Initializes |out_syms| in-place with dynamically loaded CUDA symbols.
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc
index 9c001de..e3b366d 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test.cc
@@ -16,7 +16,7 @@
 namespace cuda {
 namespace {
 
-#define CUDE_CHECK_ERRORS(expr)      \
+#define CUDA_CHECK_ERRORS(expr)      \
   {                                  \
     CUresult status = expr;          \
     ASSERT_EQ(CUDA_SUCCESS, status); \
@@ -34,11 +34,11 @@
   }
 
   int device_count = 0;
-  CUDE_CHECK_ERRORS(symbols.cuInit(0));
-  CUDE_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
+  CUDA_CHECK_ERRORS(symbols.cuInit(0));
+  CUDA_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
   if (device_count > 0) {
     CUdevice device;
-    CUDE_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
+    CUDA_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
   }
 
   iree_hal_cuda_dynamic_symbols_deinitialize(&symbols);
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test_for_nccl.cc b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test_for_nccl.cc
new file mode 100644
index 0000000..56be1ab
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols_test_for_nccl.cc
@@ -0,0 +1,46 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <nccl.h>
+
+#include <iostream>
+
+#include "iree/base/api.h"
+#include "iree/hal/drivers/cuda/dynamic_symbols.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+namespace {
+
+#define NCCL_CHECK_ERRORS(expr)     \
+  {                                 \
+    ncclResult_t status = expr;     \
+    ASSERT_EQ(ncclSuccess, status); \
+  }
+
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+  iree_hal_cuda_dynamic_symbols_t symbols;
+  iree_status_t status = iree_hal_cuda_dynamic_symbols_initialize(
+      iree_allocator_system(), &symbols);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_ignore(status);
+    std::cerr << "Symbols cannot be loaded, failing test.";
+    GTEST_FAIL();
+  }
+  // The loaded library must report the version of the nccl.h compiled against.
+  int nccl_version = 0;
+  EXPECT_EQ(ncclSuccess, symbols.ncclGetVersion(&nccl_version));
+  EXPECT_EQ(NCCL_VERSION_CODE, nccl_version);  // non-fatal so cleanup runs
+  iree_hal_cuda_dynamic_symbols_deinitialize(&symbols);
+}
+
+}  // namespace
+}  // namespace cuda
+}  // namespace hal
+}  // namespace iree
diff --git a/runtime/src/iree/hal/drivers/cuda/nccl_channel.c b/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
index 186064e..3255f32 100644
--- a/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
+++ b/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
@@ -6,10 +6,21 @@
 
 #include "iree/hal/drivers/cuda/nccl_channel.h"
 
+#include <iree/base/config.h>
+#include <iree/base/status.h>
+#include <iree/hal/command_buffer.h>
+#include <iree/hal/utils/collective_batch.h>
+#if IREE_HAL_DRIVER_CUDA_NCCL
+#include <nccl.h>
+#endif
 #include <stddef.h>
 
 #include "iree/base/api.h"
 #include "iree/base/tracing.h"
+#include "iree/hal/drivers/cuda/cuda_buffer.h"
+#include "iree/hal/drivers/cuda/status_util.h"
+
+#if IREE_HAL_DRIVER_CUDA_NCCL
 
 // Returns the same value as NCCL's init.cc hashUniqueId.
 // These magic constants were chosen by their implementation and unlikely to
@@ -67,24 +78,21 @@
   IREE_TRACE_ZONE_APPEND_VALUE(z0, rank);
   IREE_TRACE_ZONE_APPEND_VALUE(z0, count);
 
-  // TODO(#9580): actually use nccl to create a communicator.
-  // Something like:
-  //  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
-  //  config.blocking = 0;
-  //  syms->ncclCommInitRankConfig(&comm, count, *id, rank, &config);
-  // NOTE: CHECK ERRORS! we can safely return here as we haven't allocated the
-  // channel wrapper yet.
   ncclComm_t comm = NULL;
-  if (!comm) {
+  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
+  config.blocking = 1;  // FIXME: use async to check a timeout
+  iree_status_t status = NCCL_RESULT_TO_STATUS(
+      context_wrapper->syms,
+      ncclCommInitRankConfig(&comm, count, *((const ncclUniqueId*)id), rank,
+                             &config));
+  if (!iree_status_is_ok(status)) {
     IREE_TRACE_ZONE_END(z0);
-    return iree_make_status(
-        IREE_STATUS_INTERNAL,
-        "failed to create NCCL communicator for rank=%d count=%d", rank, count);
+    return status;
   }
 
   iree_hal_cuda_nccl_channel_t* channel = NULL;
-  iree_status_t status = iree_allocator_malloc(
-      context_wrapper->host_allocator, sizeof(*channel), (void**)&channel);
+  status = iree_allocator_malloc(context_wrapper->host_allocator,
+                                 sizeof(*channel), (void**)&channel);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_cuda_nccl_channel_vtable,
                                  &channel->resource);
@@ -110,7 +118,7 @@
   IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->rank);
   IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->count);
 
-  // TODO(#9580): tear down nccl - blocking if needed.
+  // TODO(#9580): support async tear down
   // We could be smarter about starting finalization of all channels async and
   // then waiting for them to complete but we aren't currently optimizing for
   // lifetime performance. To do that we'd probably want to track each open
@@ -122,9 +130,11 @@
   //  syms->ncclCommDestroy(channel->comm)
   // Should work the same (as we are doing a blocking teardown):
   //  syms->ncclCommDestroy(channel->comm)
-
+  NCCL_IGNORE_ERROR(channel->context_wrapper->syms,
+                    ncclCommFinalize(channel->comm));
+  NCCL_IGNORE_ERROR(channel->context_wrapper->syms,
+                    ncclCommDestroy(channel->comm));
   iree_allocator_free(host_allocator, channel);
-
   IREE_TRACE_ZONE_END(z0);
 }
 
@@ -140,36 +150,252 @@
   *out_count = channel->count;
 }
 
-ncclComm_t iree_hal_cuda_nccl_channel_comm(iree_hal_channel_t* base_channel) {
+// Returns the NCCL communicator for the given |channel|, if available.
+static ncclComm_t iree_hal_cuda_nccl_channel_comm(
+    iree_hal_channel_t* base_channel) {
   IREE_ASSERT_ARGUMENT(base_channel);
   iree_hal_cuda_nccl_channel_t* channel =
       iree_hal_cuda_nccl_channel_cast(base_channel);
   return channel->comm;
 }
 
+static iree_status_t get_nccl_data_type(iree_hal_collective_element_type_t in,
+                                        ncclDataType_t* out) {
+  switch (in) {
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8:
+      *out = ncclInt8;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8:
+      *out = ncclUint8;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "SINT16 is not supported for collective op");
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "UINT16 is not supported for collective op");
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32:
+      *out = ncclInt32;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32:
+      *out = ncclUint32;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64:
+      *out = ncclInt64;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64:
+      *out = ncclUint64;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16:
+      *out = ncclFloat16;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32:
+      *out = ncclFloat32;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64:
+      *out = ncclFloat64;
+      break;
+    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16:
+      *out = ncclBfloat16;  // nccl.h only declares this with CUDA 11+
+      break;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t get_nccl_red_type(iree_hal_collective_reduction_t in,
+                                       ncclRedOp_t* out) {
+  switch (in) {
+    case IREE_HAL_COLLECTIVE_REDUCTION_SUM:
+      *out = ncclSum;
+      break;
+    case IREE_HAL_COLLECTIVE_REDUCTION_PRODUCT:
+      *out = ncclProd;
+      break;
+    case IREE_HAL_COLLECTIVE_REDUCTION_MINIMUM:
+      *out = ncclMin;
+      break;
+    case IREE_HAL_COLLECTIVE_REDUCTION_MAXIMUM:
+      *out = ncclMax;
+      break;
+    case IREE_HAL_COLLECTIVE_REDUCTION_AVERAGE:
+      *out = ncclAvg;
+      break;
+  }
+
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_nccl_submit_batch_entry(
+    const iree_hal_collective_batch_entry_t* entry, CUstream stream) {
+  IREE_ASSERT_ARGUMENT(entry);
+  IREE_ASSERT_ARGUMENT(stream);
+
+  iree_hal_cuda_nccl_channel_t* channel =
+      iree_hal_cuda_nccl_channel_cast(entry->channel);
+  iree_hal_cuda_dynamic_symbols_t* syms = channel->context_wrapper->syms;
+  ncclComm_t comm = iree_hal_cuda_nccl_channel_comm(entry->channel);
+  ncclDataType_t datatype;
+  IREE_RETURN_IF_ERROR(get_nccl_data_type(entry->op.element_type, &datatype));
+
+  switch (entry->op.kind) {
+    case IREE_HAL_COLLECTIVE_KIND_ALL_GATHER: {
+      CUdeviceptr sendbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+          entry->send_binding.offset;
+      CUdeviceptr recvbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+          entry->recv_binding.offset;
+      NCCL_RETURN_IF_ERROR(
+          syms,
+          ncclAllGather((const void*)sendbuff, (void*)recvbuff,
+                        entry->element_count, datatype, comm, stream),
+          "ncclAllGather");
+      break;
+    }
+    case IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE: {
+      CUdeviceptr sendbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+          entry->send_binding.offset;
+      CUdeviceptr recvbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+          entry->recv_binding.offset;
+      ncclRedOp_t redop;
+      IREE_RETURN_IF_ERROR(get_nccl_red_type(entry->op.reduction, &redop));
+      NCCL_RETURN_IF_ERROR(
+          syms,
+          ncclAllReduce((const void*)sendbuff, (void*)recvbuff,
+                        entry->element_count, datatype, redop, comm, stream),
+          "ncclAllReduce");
+      break;
+    }
+    case IREE_HAL_COLLECTIVE_KIND_BROADCAST: {
+      CUdeviceptr sendbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+          entry->send_binding.offset;
+      CUdeviceptr recvbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+          entry->recv_binding.offset;
+      NCCL_RETURN_IF_ERROR(syms,
+                           ncclBroadcast((const void*)sendbuff, (void*)recvbuff,
+                                         entry->element_count, datatype,
+                                         entry->param, comm, stream),
+                           "ncclBroadcast");
+      break;
+    }
+    case IREE_HAL_COLLECTIVE_KIND_REDUCE: {
+      CUdeviceptr sendbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+          entry->send_binding.offset;
+      CUdeviceptr recvbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+          entry->recv_binding.offset;
+      ncclRedOp_t redop;
+      IREE_RETURN_IF_ERROR(get_nccl_red_type(entry->op.reduction, &redop));
+      NCCL_RETURN_IF_ERROR(syms,
+                           ncclReduce((const void*)sendbuff, (void*)recvbuff,
+                                      entry->element_count, datatype, redop,
+                                      entry->param, comm, stream),
+                           "ncclReduce");
+      break;
+    }
+    case IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER: {
+      CUdeviceptr sendbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+          entry->send_binding.offset;
+      CUdeviceptr recvbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+          entry->recv_binding.offset;
+      ncclRedOp_t redop;
+      IREE_RETURN_IF_ERROR(get_nccl_red_type(entry->op.reduction, &redop));
+      NCCL_RETURN_IF_ERROR(
+          syms,
+          ncclReduceScatter((const void*)sendbuff, (void*)recvbuff,
+                            entry->element_count, datatype, redop, comm,
+                            stream),
+          "ncclReduceScatter");
+      break;
+    }
+    case IREE_HAL_COLLECTIVE_KIND_SEND: {
+      CUdeviceptr sendbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->send_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->send_binding.buffer) +
+          entry->send_binding.offset;
+      NCCL_RETURN_IF_ERROR(syms,
+                           ncclSend((const void*)sendbuff, entry->element_count,
+                                    datatype, entry->param, comm, stream),
+                           "ncclSend");
+      break;
+    }
+    case IREE_HAL_COLLECTIVE_KIND_RECV: {
+      CUdeviceptr recvbuff =
+          iree_hal_cuda_buffer_device_pointer(
+              iree_hal_buffer_allocated_buffer(entry->recv_binding.buffer)) +
+          iree_hal_buffer_byte_offset(entry->recv_binding.buffer) +
+          entry->recv_binding.offset;
+      NCCL_RETURN_IF_ERROR(syms,
+                           ncclRecv((void*)recvbuff, entry->element_count,
+                                    datatype, entry->param, comm, stream),
+                           "ncclRecv");
+      break;
+    }
+  }  // switch
+  return iree_ok_status();
+}
+
 iree_status_t iree_hal_cuda_nccl_submit_batch(
     iree_hal_cuda_context_wrapper_t* context,
     const iree_hal_collective_batch_t* batch, CUstream stream) {
   IREE_ASSERT_ARGUMENT(context);
   IREE_ASSERT_ARGUMENT(batch);
   IREE_ASSERT_ARGUMENT(stream);
-
-  // TODO(#9580): issue the operations in the batch. Note that the channel may
-  // change between ops and the communicator should be retrieved from each.
-  //
-  // Something like:
-  //  make context->cu_context active (for when using multiple devices)
-  //  syms->ncclGroupStart();
-  //  for each entry in batch:
-  //    ncclComm_t comm = iree_hal_cuda_nccl_channel_comm(entry->channel);
-  //    syms->nccl*(comm, ...);
-  //  syms->ncclGroupEnd();
-
-  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                          "NCCL submission not yet implemented");
+
+  // Issue every entry of the batch inside one ncclGroupStart/ncclGroupEnd
+  // pair on |stream|.
+  NCCL_RETURN_IF_ERROR(context->syms, ncclGroupStart(), "ncclGroupStart");
+
+  // Stop at the first entry that fails; its status is what the caller sees.
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < batch->count; ++i) {
+    status = iree_hal_cuda_nccl_submit_batch_entry(&batch->entries[i], stream);
+    if (!iree_status_is_ok(status)) break;
+  }
+
+  // Always close the group opened above, even after a failed entry, so the
+  // start/end calls stay balanced.
+  iree_status_t end_status =
+      NCCL_RESULT_TO_STATUS(context->syms, ncclGroupEnd(), "ncclGroupEnd");
+  if (iree_status_is_ok(status)) return end_status;
+
+  // An entry already failed: report that failure and drop the secondary one.
+  iree_status_ignore(end_status);
+  return status;
 }
 
 static const iree_hal_channel_vtable_t iree_hal_cuda_nccl_channel_vtable = {
     .destroy = iree_hal_cuda_nccl_channel_destroy,
     .query_rank_and_count = iree_hal_cuda_nccl_channel_query_rank_and_count,
 };
+
+#else  // IREE_HAL_DRIVER_CUDA_NCCL
+
+// Fallback compiled when IREE_HAL_DRIVER_CUDA_NCCL is disabled: no channel is
+// created, |out_channel| is not written and UNIMPLEMENTED is returned.
+iree_status_t iree_hal_cuda_nccl_channel_create(
+    iree_hal_cuda_context_wrapper_t* context_wrapper,
+    const iree_hal_cuda_nccl_id_t* id, int rank, int count,
+    iree_hal_channel_t** out_channel) {
+  iree_status_t unimplemented_status = iree_make_status(
+      IREE_STATUS_UNIMPLEMENTED, "iree_hal_cuda_nccl_channel_create()");
+  return unimplemented_status;
+}
+
+// Fallback compiled when IREE_HAL_DRIVER_CUDA_NCCL is disabled: nothing is
+// submitted to |stream| and UNIMPLEMENTED is returned.
+iree_status_t iree_hal_cuda_nccl_submit_batch(
+    iree_hal_cuda_context_wrapper_t* context,
+    const iree_hal_collective_batch_t* batch, CUstream stream) {
+  iree_status_t unimplemented_status = iree_make_status(
+      IREE_STATUS_UNIMPLEMENTED, "iree_hal_cuda_nccl_submit_batch()");
+  return unimplemented_status;
+}
+
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
diff --git a/runtime/src/iree/hal/drivers/cuda/nccl_channel.h b/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
index 51ffd54..8e65e59 100644
--- a/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
+++ b/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
@@ -18,17 +18,11 @@
 extern "C" {
 #endif  // __cplusplus
 
-// Creates a new NCCL communicator channel.
-typedef struct ncclComm* ncclComm_t;
-
 iree_status_t iree_hal_cuda_nccl_channel_create(
     iree_hal_cuda_context_wrapper_t* context_wrapper,
     const iree_hal_cuda_nccl_id_t* id, int rank, int count,
     iree_hal_channel_t** out_channel);
 
-// Returns the NCCL communicator for the given |channel|, if available.
-ncclComm_t iree_hal_cuda_nccl_channel_comm(iree_hal_channel_t* channel);
-
 // Performs a non-blocking submission of |batch| to |stream|.
 // The backing storage of |batch| is dropped immediately but all resources
 // referenced will be retained by the parent command buffer for its lifetime.
diff --git a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
index f609b6d..ab4164f 100644
--- a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
+++ b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
@@ -39,6 +39,34 @@
   return iree_ok_status();
 }
 
+#if IREE_HAL_DRIVER_CUDA_NCCL
+// Parses |str| as a base-10 integer that must span the whole string and fit
+// in an int. Returns false and leaves |out_value| untouched otherwise.
+static bool iree_hal_cuda_parse_env_int(const char* str, int* out_value) {
+  char* end = NULL;
+  long value = strtol(str, &end, 10);
+  if (end == str || *end != '\0') return false;  // empty or trailing junk
+  if ((long)(int)value != value) return false;   // does not fit in an int
+  *out_value = (int)value;
+  return true;
+}
+
+// Reads the default NCCL process count and rank from the environment.
+//
+// When IREE_SPMD_NPROCS is unset both defaults are set to 0 and OK is
+// returned. When it is set it must be an integer > 0 and IREE_SPMD_PROCID
+// must be an integer in [0, NPROCS):
+//   - OUT_OF_RANGE: NPROCS or PROCID is not a valid integer in its range.
+//   - INVALID_ARGUMENT: NPROCS is set but PROCID is missing.
+// |params| is only written on success so a failure never leaves a count
+// without a matching rank.
+static iree_status_t iree_hal_cuda_init_nccl_rank_and_count(
+    iree_hal_cuda_device_params_t* params) {
+  const char* nprocs_str = getenv("IREE_SPMD_NPROCS");
+  if (!nprocs_str) {
+    params->nccl_default_count = 0;
+    params->nccl_default_rank = 0;
+    return iree_ok_status();
+  }
+
+  int nprocs = 0;
+  if (!iree_hal_cuda_parse_env_int(nprocs_str, &nprocs) || nprocs <= 0) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "IREE_SPMD_NPROCS must be an integer > 0; got '%s'",
+                            nprocs_str);
+  }
+
+  const char* procid_str = getenv("IREE_SPMD_PROCID");
+  if (!procid_str) {
+    // Expected PROCID when NPROCS is set.
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "IREE_SPMD_PROCID must be set when IREE_SPMD_NPROCS is set");
+  }
+
+  int procid = 0;
+  if (!iree_hal_cuda_parse_env_int(procid_str, &procid) || procid < 0 ||
+      procid >= nprocs) {
+    return iree_make_status(
+        IREE_STATUS_OUT_OF_RANGE,
+        "IREE_SPMD_PROCID must be an integer in [0, %d); got '%s'", nprocs,
+        procid_str);
+  }
+
+  params->nccl_default_count = nprocs;
+  params->nccl_default_rank = procid;
+  return iree_ok_status();
+}
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
+
 static iree_status_t iree_hal_cuda_driver_factory_try_create(
     void* self, iree_string_view_t driver_name, iree_allocator_t host_allocator,
     iree_hal_driver_t** out_driver) {
@@ -58,6 +86,12 @@
         IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;
   }
   default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
+#if IREE_HAL_DRIVER_CUDA_NCCL
+  iree_hal_cuda_init_nccl_rank_and_count(&default_params);
+
+  // Note that nccl_default_id can't be initialized until the driver imports
+  // the NCCL symbols from the dynamic library.
+#endif
 
   iree_hal_cuda_driver_options_t driver_options;
   iree_hal_cuda_driver_options_initialize(&driver_options);
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.c b/runtime/src/iree/hal/drivers/cuda/status_util.c
index 7881d2c..14790c1 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.c
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.c
@@ -30,3 +30,18 @@
                                         "CUDA driver error '%s' (%d): %s",
                                         error_name, result, error_string);
 }
+
+#if IREE_HAL_DRIVER_CUDA_NCCL
+// Maps |result| onto an iree_status_t: ncclSuccess becomes OK, anything else
+// becomes an INTERNAL status located at |file|:|line| whose message carries
+// the numeric code and the text from |syms|->ncclGetErrorString.
+iree_status_t iree_hal_nccl_result_to_status(
+    iree_hal_cuda_dynamic_symbols_t* syms, ncclResult_t result,
+    const char* file, uint32_t line) {
+  if (IREE_UNLIKELY(result != ncclSuccess)) {
+    // Failure path: look up the description through the loaded symbols.
+    const char* description = syms->ncclGetErrorString(result);
+    return iree_make_status_with_location(file, line, IREE_STATUS_INTERNAL,
+                                          "NCCL error %d: %s", result,
+                                          description);
+  }
+  return iree_ok_status();
+}
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.h b/runtime/src/iree/hal/drivers/cuda/status_util.h
index 9cb9d74..0d8409a 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.h
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.h
@@ -47,6 +47,40 @@
     iree_hal_cuda_dynamic_symbols_t* syms, CUresult result, const char* file,
     uint32_t line);
 
+#if IREE_HAL_DRIVER_CUDA_NCCL
+// Invokes the NCCL call |expr| through the |syms| symbol table and converts
+// the resulting ncclResult_t to an iree_status_t tagged with the call site's
+// __FILE__/__LINE__. |expr| is written without the symbol table prefix; the
+// macro expands it to (syms)->expr. Trailing arguments are accepted but
+// unused.
+//
+// Usage:
+//   iree_status_t status =
+//       NCCL_RESULT_TO_STATUS(syms, ncclDoThing(...), "ncclDoThing");
+#define NCCL_RESULT_TO_STATUS(syms, expr, ...) \
+  iree_hal_nccl_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// Converts the ncclResult_t |result| to an iree_status_t attributed to the
+// source location |file|:|line|. |syms| is the dynamic symbol table passed by
+// the NCCL_* macros in this header, which supply __FILE__ and __LINE__.
+iree_status_t iree_hal_nccl_result_to_status(
+    iree_hal_cuda_dynamic_symbols_t* syms, ncclResult_t result,
+    const char* file, uint32_t line);
+
+// IREE_RETURN_IF_ERROR for NCCL calls: invokes |expr| as (syms)->expr,
+// converts the ncclResult_t to an iree_status_t tagged with the call site's
+// __FILE__/__LINE__ and passes it, together with any trailing arguments, to
+// IREE_RETURN_IF_ERROR.
+//
+// Usage:
+//   NCCL_RETURN_IF_ERROR(syms, ncclDoThing(...), "message");
+#define NCCL_RETURN_IF_ERROR(syms, expr, ...)                                 \
+  IREE_RETURN_IF_ERROR(iree_hal_nccl_result_to_status((syms), ((syms)->expr), \
+                                                      __FILE__, __LINE__),    \
+                       __VA_ARGS__)
+
+// IREE_IGNORE_ERROR for NCCL calls: invokes |expr| as (syms)->expr, converts
+// the ncclResult_t to an iree_status_t tagged with the call site's
+// __FILE__/__LINE__ and hands it to IREE_IGNORE_ERROR.
+//
+// Usage:
+//   NCCL_IGNORE_ERROR(syms, ncclDoThing(...));
+#define NCCL_IGNORE_ERROR(syms, expr)                                      \
+  IREE_IGNORE_ERROR(iree_hal_nccl_result_to_status((syms), ((syms)->expr), \
+                                                   __FILE__, __LINE__))
+
+#endif  // IREE_HAL_DRIVER_CUDA_NCCL
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 8386b18..ec8d411 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -736,7 +736,7 @@
       .length = iree_hal_cast_device_size(args->i9),
   };
   IREE_RETURN_IF_ERROR(
-      iree_hal_buffer_check_deref_or_null(args->r7, &send_binding.buffer));
+      iree_hal_buffer_check_deref_or_null(args->r7, &recv_binding.buffer));
   iree_device_size_t element_count = iree_hal_cast_device_size(args->i10);
   return iree_hal_command_buffer_collective(command_buffer, channel, op, param,
                                             send_binding, recv_binding,