Use MPI for NCCL unique ID exchange by default (#12902)

The default implementation of the channel provider now uses MPI to
exchange the NCCL unique ID.

A few MPI functions are loaded and called dynamically, in the same way
as the NCCL functions. With the default implementation, OpenMPI must be
available on the system to run a module with collective operations.
(Note that we can't use MPICH variants because they define MPI_Datatype
and MPI_Comm differently.)

The device index is overridden by the MPI rank automatically.

Here is the command line before and after.

before
```
NCCL_COMM_ID=127.0.0.1:8000 iree-run-module --cuda_nccl_default_rank=0 --cuda_nccl_default_count=2 --device=cuda://0 --module=all_gather.vmfb --function=all_gather_dim_0 & \
NCCL_COMM_ID=127.0.0.1:8000 iree-run-module --cuda_nccl_default_rank=1 --cuda_nccl_default_count=2 --device=cuda://1 --module=all_gather.vmfb --function=all_gather_dim_0
```

after
```
mpirun -np 2 iree-run-module --device=cuda --module=all_gather.vmfb --function=all_gather_dim_0
```
diff --git a/runtime/src/iree/hal/drivers/cuda/api.h b/runtime/src/iree/hal/drivers/cuda/api.h
index ee8733d..c026eb8 100644
--- a/runtime/src/iree/hal/drivers/cuda/api.h
+++ b/runtime/src/iree/hal/drivers/cuda/api.h
@@ -79,6 +79,13 @@
 IREE_API_EXPORT iree_status_t iree_hal_cuda_nccl_get_unique_id(
     iree_hal_device_t* device, iree_hal_cuda_nccl_id_t* out_id);
 
+// Default implementation of the collective channel provider that uses MPI.
+// Hosting layers would want to use their own implementation to exchange IDs.
+IREE_API_EXPORT iree_status_t iree_hal_cuda_nccl_query_group_params(
+    void* self, iree_hal_device_t* device,
+    iree_hal_queue_affinity_t queue_affinity, iree_byte_span_t id_storage,
+    iree_hal_channel_params_t* out_params);
+
 //===----------------------------------------------------------------------===//
 // iree_hal_cuda_driver_t
 //===----------------------------------------------------------------------===//
@@ -88,6 +95,10 @@
   // Index of the default CUDA device to use within the list of available
   // devices.
   int default_device_index;
+  // The rank in the default collective channel.
+  int default_rank;
+  // The count of the default collective channel.
+  int default_count;
 } iree_hal_cuda_driver_options_t;
 
 IREE_API_EXPORT void iree_hal_cuda_driver_options_initialize(
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index bf3d4da..7c27c69 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -198,6 +198,13 @@
   return iree_ok_status();
 }
 
+iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_get_dynamic_symbols(
+    iree_hal_device_t* base_device) {
+  IREE_ASSERT_ARGUMENT(base_device);
+  iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+  return device->context_wrapper.syms;
+}
+
 static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
   iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
@@ -294,6 +301,49 @@
                                                        out_id);
 }
 
+IREE_API_EXPORT iree_status_t iree_hal_cuda_nccl_query_group_params(
+    void* self, iree_hal_device_t* device,
+    iree_hal_queue_affinity_t queue_affinity, iree_byte_span_t id_storage,
+    iree_hal_channel_params_t* out_params) {
+  IREE_ASSERT_EQ(id_storage.data_length, sizeof(iree_hal_cuda_nccl_id_t));
+
+  iree_hal_cuda_dynamic_symbols_t* syms =
+      iree_hal_cuda_get_dynamic_symbols(device);
+
+  if (!syms->mpi_library) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "MPI should be loaded to use NCCL collective operations.");
+  }
+
+  // Until we have multi channel support, we only create the default channel.
+  IREE_ASSERT_EQ(out_params->rank, IREE_HAL_CHANNEL_RANK_DEFAULT);
+  IREE_ASSERT_EQ(out_params->count, IREE_HAL_CHANNEL_COUNT_DEFAULT);
+
+  // Update the rank and count.
+  MPI_RETURN_IF_ERROR(
+      syms, MPI_Comm_rank(syms->ompi_mpi_comm_world, &out_params->rank),
+      "MPI_Comm_rank");
+  MPI_RETURN_IF_ERROR(
+      syms, MPI_Comm_size(syms->ompi_mpi_comm_world, &out_params->count),
+      "MPI_Comm_size");
+
+  iree_hal_cuda_nccl_id_t* id = (iree_hal_cuda_nccl_id_t*)id_storage.data;
+  if (out_params->rank == 0) {
+    // The root process of the group creates the unique ID and broadcasts it
+    // to the others.
+    IREE_RETURN_IF_ERROR(iree_hal_cuda_nccl_get_unique_id(device, id));
+  }
+
+  MPI_RETURN_IF_ERROR(syms,
+                      MPI_Bcast(id, id_storage.data_length, syms->ompi_mpi_byte,
+                                0, syms->ompi_mpi_comm_world),
+                      "MPI_Bcast");
+  out_params->id = iree_const_cast_byte_span(id_storage);
+
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_cuda_device_create_channel(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
     iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.h b/runtime/src/iree/hal/drivers/cuda/cuda_device.h
index 7b8ded4..2ce444f 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.h
@@ -28,6 +28,10 @@
 iree_status_t iree_hal_cuda_device_get_context(iree_hal_device_t* base_device,
                                                CUcontext* out_context);
 
+// Returns the dynamic symbol table from the device's context.
+iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_get_dynamic_symbols(
+    iree_hal_device_t* base_device);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
index a41402c..1186bde 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
@@ -46,6 +46,27 @@
     iree_hal_cuda_driver_options_t* out_options) {
   memset(out_options, 0, sizeof(*out_options));
   out_options->default_device_index = 0;
+  out_options->default_rank = 0;
+  out_options->default_count = 0;
+}
+
+// Gets the MPI world size from the OMPI_COMM_WORLD_SIZE environment variable.
+// Sets |size| to 0 if the variable is not set.
+static iree_status_t iree_hal_cuda_get_mpi_comm_world_size_from_env(
+    int32_t* size) {
+  *size = 0;
+
+  const char* comm_world_size_str = getenv("OMPI_COMM_WORLD_SIZE");
+  if (!comm_world_size_str) {
+    return iree_ok_status();
+  }
+
+  if (!iree_string_view_atoi_uint32(iree_make_cstring_view(comm_world_size_str),
+                                    size)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "OMPI_COMM_WORLD_SIZE=%s", comm_world_size_str);
+  }
+  return iree_ok_status();
 }
 
 static iree_status_t iree_hal_cuda_driver_create_internal(
@@ -70,12 +91,31 @@
   iree_status_t status =
       iree_hal_cuda_dynamic_symbols_initialize(host_allocator, &driver->syms);
 
-  // Initialize NCCL too if a channel provider is defined or any default
-  // collective group values.
+  int comm_world_size = 0;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda_get_mpi_comm_world_size_from_env(&comm_world_size);
+  }
+
+  // Initialize NCCL too if MPI is used or default_count is set.
   if (iree_status_is_ok(status) &&
-      default_params->channel_provider.query_group_params) {
+      (comm_world_size > 0 || options->default_count > 0)) {
     status = iree_hal_cuda_nccl_dynamic_symbols_initialize(host_allocator,
                                                            &driver->syms);
+    if (iree_status_is_ok(status) && comm_world_size > 0) {
+      status = iree_hal_mpi_dynamic_symbols_initialize(host_allocator,
+                                                       &driver->syms);
+      if (iree_status_is_ok(status)) {
+        MPI_RETURN_IF_ERROR(&driver->syms, MPI_Init(NULL, NULL), "MPI_Init");
+
+        // Override the device index with the MPI rank.
+        int rank = 0;
+        MPI_RETURN_IF_ERROR(
+            &driver->syms,
+            MPI_Comm_rank(driver->syms.ompi_mpi_comm_world, &rank),
+            "MPI_Comm_rank");
+        driver->default_device_index = rank;
+      }
+    }
   }
 
   if (iree_status_is_ok(status)) {
@@ -90,7 +130,9 @@
   iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
   iree_allocator_t host_allocator = driver->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
-
+  if (driver->syms.mpi_library) {
+    MPI_IGNORE_ERROR(&driver->syms, MPI_Finalize());
+  }
   iree_hal_cuda_dynamic_symbols_deinitialize(&driver->syms);
   iree_allocator_free(host_allocator, driver);
 
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
index f2775a0..c2bfd54 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
@@ -7,6 +7,7 @@
 #ifndef IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
 #define IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
 
-#include "cuda.h"  // IWYU pragma: export
+#include "cuda.h"                   // IWYU pragma: export
 #include "third_party/nccl/nccl.h"  // IWYU pragma: export
+
 #endif  // IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
index e85ec7a..8c69823 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
@@ -106,3 +106,15 @@
               cudaStream_t)
 NCCL_PFN_DECL(ncclGroupStart)
 NCCL_PFN_DECL(ncclGroupEnd)
+
+// MPI
+
+MPI_PFN_DECL(MPI_Init, int*, char***)
+MPI_PFN_DECL(MPI_Finalize)
+MPI_PFN_DECL(MPI_Bcast, void* buffer, int count, void* datatype, int root,
+             void* comm)
+MPI_PFN_DECL(MPI_Comm_rank, void* comm, int* rank)
+MPI_PFN_DECL(MPI_Comm_size, void* comm, int* size)
+MPI_PFN_DECL(MPI_Comm_split, void* comm, int color, int key, void** newcomm)
+MPI_PFN_DECL(ompi_mpi_byte)
+MPI_PFN_DECL(ompi_mpi_comm_world)
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
index 44d1244..88a74c3 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
@@ -29,6 +29,15 @@
 #endif  // IREE_PLATFORM_WINDOWS
 };
 
+// TODO(okkwon): move this to a place that can be used by other drivers.
+static const char* kMPILoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "mpi.dll",
+#else
+    "libmpi.so",
+#endif  // IREE_PLATFORM_WINDOWS
+};
+
 #define concat(A, B) A B
 
 // Load CUDA entry points, prefer _v2 version if it exists.
@@ -46,10 +55,12 @@
   }
 #define NCCL_PFN_DECL(ncclSymbolName, ...)
 #define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#define MPI_PFN_DECL(mpiSymbolName, ...)
 #include "iree/hal/drivers/cuda/dynamic_symbol_tables.h"  // IWYU pragma: keep
 #undef CU_PFN_DECL
 #undef NCCL_PFN_DECL
 #undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
   return iree_ok_status();
 }
 
@@ -69,10 +80,32 @@
     IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(        \
         syms->nccl_library, kName, (void**)&syms->ncclSymbolName)); \
   }
+#define MPI_PFN_DECL(mpiSymbolName, ...)
 #include "iree/hal/drivers/cuda/dynamic_symbol_tables.h"  // IWYU pragma: keep
 #undef CU_PFN_DECL
 #undef NCCL_PFN_DECL
 #undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
+  return iree_ok_status();
+}
+
+// Load MPI entry points.
+static iree_status_t iree_hal_mpi_dynamic_symbols_resolve_all(
+    iree_hal_cuda_dynamic_symbols_t* syms) {
+#define CU_PFN_DECL(cudaSymbolName, ...)
+#define NCCL_PFN_DECL(ncclSymbolName, ...)
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#define MPI_PFN_DECL(mpiSymbolName, ...)                          \
+  {                                                               \
+    static const char* kName = #mpiSymbolName;                    \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(      \
+        syms->mpi_library, kName, (void**)&syms->mpiSymbolName)); \
+  }
+#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h"  // IWYU pragma: keep
+#undef CU_PFN_DECL
+#undef NCCL_PFN_DECL
+#undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
   return iree_ok_status();
 }
 
@@ -163,11 +196,39 @@
   return status;
 }
 
+// FIXME(okkwon): this is unrelated to CUDA, but it still uses
+// iree_hal_cuda_dynamic_symbols_t to hold the MPI library and symbols.
+iree_status_t iree_hal_mpi_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_cuda_dynamic_symbols_t* out_syms) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  out_syms->mpi_library = NULL;
+  iree_status_t status = iree_dynamic_library_load_from_files(
+      IREE_ARRAYSIZE(kMPILoaderSearchNames), kMPILoaderSearchNames,
+      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->mpi_library);
+  if (iree_status_is_not_found(status)) {
+    iree_status_ignore(status);
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "MPI runtime library not available; ensure "
+                            "installed and on path");
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_mpi_dynamic_symbols_resolve_all(out_syms);
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_hal_cuda_dynamic_symbols_deinitialize(out_syms);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
 void iree_hal_cuda_dynamic_symbols_deinitialize(
     iree_hal_cuda_dynamic_symbols_t* syms) {
   IREE_TRACE_ZONE_BEGIN(z0);
   iree_dynamic_library_release(syms->cuda_library);
   iree_dynamic_library_release(syms->nccl_library);
+  iree_dynamic_library_release(syms->mpi_library);
   memset(syms, 0, sizeof(*syms));
   IREE_TRACE_ZONE_END(z0);
 }
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
index 6fc8e98..8c95867 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
@@ -10,7 +10,6 @@
 #include "iree/base/api.h"
 #include "iree/base/internal/dynamic_library.h"
 #include "iree/hal/drivers/cuda/cuda_headers.h"
-#include "third_party/nccl/nccl.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -23,6 +22,7 @@
 typedef struct iree_hal_cuda_dynamic_symbols_t {
   iree_dynamic_library_t* cuda_library;
   iree_dynamic_library_t* nccl_library;
+  iree_dynamic_library_t* mpi_library;
 
 #define CU_PFN_DECL(cudaSymbolName, ...) \
   CUresult (*cudaSymbolName)(__VA_ARGS__);
@@ -30,10 +30,12 @@
   ncclResult_t (*ncclSymbolName)(__VA_ARGS__);
 #define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...) \
   const char* (*ncclSymbolName)(__VA_ARGS__);
+#define MPI_PFN_DECL(mpiSymbolName, ...) int (*mpiSymbolName)(__VA_ARGS__);
 #include "iree/hal/drivers/cuda/dynamic_symbol_tables.h"  // IWYU pragma: export
 #undef CU_PFN_DECL
 #undef NCCL_PFN_DECL
 #undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
 } iree_hal_cuda_dynamic_symbols_t;
 
 // Initializes |out_syms| in-place with dynamically loaded CUDA symbols.
@@ -48,6 +50,12 @@
 iree_status_t iree_hal_cuda_nccl_dynamic_symbols_initialize(
     iree_allocator_t host_allocator, iree_hal_cuda_dynamic_symbols_t* out_syms);
 
+// Initializes |out_syms| in-place with dynamically loaded MPI symbols.
+// iree_hal_cuda_dynamic_symbols_deinitialize must be used to release the
+// library resources.
+iree_status_t iree_hal_mpi_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator, iree_hal_cuda_dynamic_symbols_t* out_syms);
+
 // Deinitializes |syms| by unloading the backing library. All function pointers
 // will be invalidated. They _may_ still work if there are other reasons the
 // library remains loaded so be careful.
diff --git a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
index d3ecb55..dfd625f 100644
--- a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
+++ b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
@@ -36,44 +36,6 @@
 IREE_FLAG(int32_t, cuda_nccl_default_count, 0,
           "Participant count of the default collective group");
 
-// Default implementation of the collective channel provider that just uses the
-// NCCL_COMM_ID environment variable for configuration. Hosting layers would
-// want to use their own implementation to exchange IDs.
-static iree_status_t iree_hal_cuda_nccl_query_group_params(
-    void* self, iree_hal_device_t* device,
-    iree_hal_queue_affinity_t queue_affinity, iree_byte_span_t id_storage,
-    iree_hal_channel_params_t* params) {
-  IREE_ASSERT_EQ(id_storage.data_length, sizeof(iree_hal_cuda_nccl_id_t));
-
-  // Users can either specify a specific rank or allow this device
-  // implementation to decide. This allows us to run the same programs acting as
-  // different ranks by setting flags/environment variables/API options/etc.
-  if (params->rank == IREE_HAL_CHANNEL_RANK_DEFAULT) {
-    params->rank = FLAG_cuda_nccl_default_rank;
-  }
-  if (params->count == IREE_HAL_CHANNEL_COUNT_DEFAULT) {
-    params->count = FLAG_cuda_nccl_default_count;
-  }
-
-  // Let NCCL configure itself and return the ID to use.
-  //
-  // HACK: this may not be correct and should only be used for testing.
-  // TODO(benvanik): a string form we can use and a flag.
-  if (iree_const_byte_span_is_empty(params->id)) {
-    if (!getenv("NCCL_COMM_ID")) {
-      return iree_make_status(
-          IREE_STATUS_INVALID_ARGUMENT,
-          "the NCCL_COMM_ID environment variable must be set "
-          "when using the default NCCL configuration");
-    }
-    iree_hal_cuda_nccl_id_t* id = (iree_hal_cuda_nccl_id_t*)id_storage.data;
-    IREE_RETURN_IF_ERROR(iree_hal_cuda_nccl_get_unique_id(device, id));
-    params->id = iree_const_cast_byte_span(id_storage);
-  }
-
-  return iree_ok_status();
-}
-
 static iree_status_t iree_hal_cuda_driver_factory_enumerate(
     void* self, iree_host_size_t* out_driver_info_count,
     const iree_hal_driver_info_t** out_driver_infos) {
@@ -107,19 +69,16 @@
   }
   default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
   default_params.stream_tracing = FLAG_cuda_tracing;
-
-  // Only setup channels if we're running collectives. Setting this will require
-  // NCCL to be available at runtime.
-  if (FLAG_cuda_nccl_default_count != 0) {
-    default_params.channel_provider = (iree_hal_channel_provider_t){
-        .self = NULL,
-        .query_group_params = iree_hal_cuda_nccl_query_group_params,
-    };
-  }
+  default_params.channel_provider = (iree_hal_channel_provider_t){
+      .self = NULL,
+      .query_group_params = iree_hal_cuda_nccl_query_group_params,
+  };
 
   iree_hal_cuda_driver_options_t driver_options;
   iree_hal_cuda_driver_options_initialize(&driver_options);
   driver_options.default_device_index = FLAG_cuda_default_index;
+  driver_options.default_rank = FLAG_cuda_nccl_default_rank;
+  driver_options.default_count = FLAG_cuda_nccl_default_count;
 
   iree_status_t status =
       iree_hal_cuda_driver_create(driver_name, &default_params, &driver_options,
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.c b/runtime/src/iree/hal/drivers/cuda/status_util.c
index 2e5eb1b..1e2cd66 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.c
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.c
@@ -68,3 +68,19 @@
                                         result,
                                         syms->ncclGetErrorString(result));
 }
+
+iree_status_t iree_hal_mpi_result_to_status(
+    iree_hal_cuda_dynamic_symbols_t* syms, int result, const char* file,
+    uint32_t line) {
+  iree_status_code_t code;
+
+  switch (result) {
+    case 0:  // MPI_SUCCESS
+      return iree_ok_status();
+    default:
+      code = IREE_STATUS_INTERNAL;
+      break;
+  }
+  return iree_make_status_with_location(file, line, code, "MPI error %d",
+                                        result);
+}
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.h b/runtime/src/iree/hal/drivers/cuda/status_util.h
index d860c43..b81a826 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.h
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.h
@@ -78,6 +78,37 @@
   IREE_IGNORE_ERROR(iree_hal_nccl_result_to_status((syms), ((syms)->expr), \
                                                    __FILE__, __LINE__))
 
+// Converts an MPI result to an iree_status_t.
+//
+// Usage:
+//   iree_status_t status = MPI_RESULT_TO_STATUS(syms, MPI_DoThing(...));
+#define MPI_RESULT_TO_STATUS(syms, expr, ...) \
+  iree_hal_mpi_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// Converts an MPI result to an iree_status_t.
+iree_status_t iree_hal_mpi_result_to_status(
+    iree_hal_cuda_dynamic_symbols_t* syms, int result, const char* file,
+    uint32_t line);
+
+// IREE_RETURN_IF_ERROR but implicitly converts the MPI return value to
+// an iree_status_t.
+//
+// Usage:
+//   MPI_RETURN_IF_ERROR(syms, MPI_DoThing(...), "message");
+#define MPI_RETURN_IF_ERROR(syms, expr, ...)                                 \
+  IREE_RETURN_IF_ERROR(iree_hal_mpi_result_to_status((syms), ((syms)->expr), \
+                                                     __FILE__, __LINE__),    \
+                       __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the MPI return value to an
+// iree_status_t.
+//
+// Usage:
+//   MPI_IGNORE_ERROR(syms, MPI_DoThing(...));
+#define MPI_IGNORE_ERROR(syms, expr)                                      \
+  IREE_IGNORE_ERROR(iree_hal_mpi_result_to_status((syms), ((syms)->expr), \
+                                                  __FILE__, __LINE__))
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus