Use MPI for NCCL unique ID exchange by default (#12902)
Now the default implementation of the channel provider uses MPI to
exchange NCCL unique ID.
A few MPI functions are dynamically called just like how NCCL functions
are called. With the default implementation, OpenMPI must be available
in the system to run a module with collective operations. (Note that We
can't use MPICH-variants because of the difference in defining
MPI_Datatype and MPI_Comm.)
The device index is overridden by the MPI rank automatically.
Here is the command line before and after.
before
```
NCCL_COMM_ID=127.0.0.1:8000 iree-run-module --cuda_nccl_default_rank=0 --cuda_nccl_default_count=2 --device=cuda://0 --module=all_gather.vmfb --function=all_gather_dim_0 & \
NCCL_COMM_ID=127.0.0.1:8000 iree-run-module --cuda_nccl_default_rank=1 --cuda_nccl_default_count=2 --device=cuda://1 --module=all_gather.vmfb --function=all_gather_dim_0
```
after
```
mpirun -np 2 iree-run-module --device=cuda --module=all_gather.vmfb --function=all_gather_dim_0
```
diff --git a/runtime/src/iree/hal/drivers/cuda/api.h b/runtime/src/iree/hal/drivers/cuda/api.h
index ee8733d..c026eb8 100644
--- a/runtime/src/iree/hal/drivers/cuda/api.h
+++ b/runtime/src/iree/hal/drivers/cuda/api.h
@@ -79,6 +79,13 @@
IREE_API_EXPORT iree_status_t iree_hal_cuda_nccl_get_unique_id(
iree_hal_device_t* device, iree_hal_cuda_nccl_id_t* out_id);
+// Default implementation of the collective channel provider that uses MPI.
+// Hosting layers would want to use their own implementation to exchange IDs.
+IREE_API_EXPORT iree_status_t iree_hal_cuda_nccl_query_group_params(
+ void* self, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity, iree_byte_span_t id_storage,
+ iree_hal_channel_params_t* out_params);
+
//===----------------------------------------------------------------------===//
// iree_hal_cuda_driver_t
//===----------------------------------------------------------------------===//
@@ -88,6 +95,10 @@
// Index of the default CUDA device to use within the list of available
// devices.
int default_device_index;
+ // The rank in the default collective channel.
+ int default_rank;
+ // The count of the default collective channel.
+ int default_count;
} iree_hal_cuda_driver_options_t;
IREE_API_EXPORT void iree_hal_cuda_driver_options_initialize(
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index bf3d4da..7c27c69 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -198,6 +198,13 @@
return iree_ok_status();
}
+iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_get_dynamic_symbols(
+ iree_hal_device_t* base_device) {
+ IREE_ASSERT_ARGUMENT(base_device);
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ return device->context_wrapper.syms;
+}
+
static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
@@ -294,6 +301,49 @@
out_id);
}
+IREE_API_EXPORT iree_status_t iree_hal_cuda_nccl_query_group_params(
+ void* self, iree_hal_device_t* device,
+ iree_hal_queue_affinity_t queue_affinity, iree_byte_span_t id_storage,
+ iree_hal_channel_params_t* out_params) {
+ IREE_ASSERT_EQ(id_storage.data_length, sizeof(iree_hal_cuda_nccl_id_t));
+
+ iree_hal_cuda_dynamic_symbols_t* syms =
+ iree_hal_cuda_get_dynamic_symbols(device);
+
+ if (!syms->mpi_library) {
+ return iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "MPI should be loaded to use NCCL collective operations.");
+ }
+
+ // Until we have multi channel support, we only create the default channel.
+ IREE_ASSERT_EQ(out_params->rank, IREE_HAL_CHANNEL_RANK_DEFAULT);
+ IREE_ASSERT_EQ(out_params->count, IREE_HAL_CHANNEL_COUNT_DEFAULT);
+
+ // Update the rank and count.
+ MPI_RETURN_IF_ERROR(
+ syms, MPI_Comm_rank(syms->ompi_mpi_comm_world, &out_params->rank),
+ "MPI_Comm_rank");
+ MPI_RETURN_IF_ERROR(
+ syms, MPI_Comm_size(syms->ompi_mpi_comm_world, &out_params->count),
+ "MPI_Comm_size");
+
+ iree_hal_cuda_nccl_id_t* id = (iree_hal_cuda_nccl_id_t*)id_storage.data;
+ if (out_params->rank == 0) {
+ // The root process of the group creates the unique ID and broadcasts it
+ // to the others.
+ IREE_RETURN_IF_ERROR(iree_hal_cuda_nccl_get_unique_id(device, id));
+ }
+
+ MPI_RETURN_IF_ERROR(syms,
+ MPI_Bcast(id, id_storage.data_length, syms->ompi_mpi_byte,
+ 0, syms->ompi_mpi_comm_world),
+ "MPI_Bcast");
+ out_params->id = iree_const_cast_byte_span(id_storage);
+
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_cuda_device_create_channel(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.h b/runtime/src/iree/hal/drivers/cuda/cuda_device.h
index 7b8ded4..2ce444f 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.h
@@ -28,6 +28,10 @@
iree_status_t iree_hal_cuda_device_get_context(iree_hal_device_t* base_device,
CUcontext* out_context);
+// Returns the dynamic symbol table from the device's context.
+iree_hal_cuda_dynamic_symbols_t* iree_hal_cuda_get_dynamic_symbols(
+ iree_hal_device_t* base_device);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
index a41402c..1186bde 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_driver.c
@@ -46,6 +46,27 @@
iree_hal_cuda_driver_options_t* out_options) {
memset(out_options, 0, sizeof(*out_options));
out_options->default_device_index = 0;
+ out_options->default_rank = 0;
+ out_options->default_count = 0;
+}
+
+// Gets the MPI world size from the environmental variable.
+// Returns 0 if the variable is not set.
+static iree_status_t iree_hal_cuda_get_mpi_comm_world_size_from_env(
+ int32_t* size) {
+ *size = 0;
+
+ const char* comm_world_size_str = getenv("OMPI_COMM_WORLD_SIZE");
+ if (!comm_world_size_str) {
+ return iree_ok_status();
+ }
+
+ if (!iree_string_view_atoi_uint32(iree_make_cstring_view(comm_world_size_str),
+ size)) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "OMPI_COMM_WORLD_SIZE=%s", comm_world_size_str);
+ }
+ return iree_ok_status();
}
static iree_status_t iree_hal_cuda_driver_create_internal(
@@ -70,12 +91,31 @@
iree_status_t status =
iree_hal_cuda_dynamic_symbols_initialize(host_allocator, &driver->syms);
- // Initialize NCCL too if a channel provider is defined or any default
- // collective group values.
+ int comm_world_size = 0;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_cuda_get_mpi_comm_world_size_from_env(&comm_world_size);
+ }
+
+ // Initialize NCCL too if MPI is used or default_count is set.
if (iree_status_is_ok(status) &&
- default_params->channel_provider.query_group_params) {
+ (comm_world_size > 0 || options->default_count > 0)) {
status = iree_hal_cuda_nccl_dynamic_symbols_initialize(host_allocator,
&driver->syms);
+ if (iree_status_is_ok(status) && comm_world_size > 0) {
+ status = iree_hal_mpi_dynamic_symbols_initialize(host_allocator,
+ &driver->syms);
+ if (iree_status_is_ok(status)) {
+ MPI_RETURN_IF_ERROR(&driver->syms, MPI_Init(NULL, NULL), "MPI_Init");
+
+ // Override the device index with the MPI rank.
+ int rank = 0;
+ MPI_RETURN_IF_ERROR(
+ &driver->syms,
+ MPI_Comm_rank(driver->syms.ompi_mpi_comm_world, &rank),
+ "MPI_Comm_rank");
+ driver->default_device_index = rank;
+ }
+ }
}
if (iree_status_is_ok(status)) {
@@ -90,7 +130,9 @@
iree_hal_cuda_driver_t* driver = iree_hal_cuda_driver_cast(base_driver);
iree_allocator_t host_allocator = driver->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
-
+ if (driver->syms.mpi_library) {
+ MPI_IGNORE_ERROR(&driver->syms, MPI_Finalize());
+ }
iree_hal_cuda_dynamic_symbols_deinitialize(&driver->syms);
iree_allocator_free(host_allocator, driver);
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
index f2775a0..c2bfd54 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_headers.h
@@ -7,6 +7,7 @@
#ifndef IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
#define IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
-#include "cuda.h" // IWYU pragma: export
+#include "cuda.h" // IWYU pragma: export
#include "third_party/nccl/nccl.h" // IWYU pragma: export
+
#endif // IREE_HAL_DRIVERS_CUDA_CUDA_HEADERS_H_
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
index e85ec7a..8c69823 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbol_tables.h
@@ -106,3 +106,15 @@
cudaStream_t)
NCCL_PFN_DECL(ncclGroupStart)
NCCL_PFN_DECL(ncclGroupEnd)
+
+// MPI
+
+MPI_PFN_DECL(MPI_Init, int*, char***)
+MPI_PFN_DECL(MPI_Finalize)
+MPI_PFN_DECL(MPI_Bcast, void* buffer, int count, void* datatype, int root,
+ void* comm)
+MPI_PFN_DECL(MPI_Comm_rank, void* comm, int* rank)
+MPI_PFN_DECL(MPI_Comm_size, void* comm, int* size)
+MPI_PFN_DECL(MPI_Comm_split, void* comm, int color, int key, void** newcomm)
+MPI_PFN_DECL(ompi_mpi_byte)
+MPI_PFN_DECL(ompi_mpi_comm_world)
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
index 44d1244..88a74c3 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.c
@@ -29,6 +29,15 @@
#endif // IREE_PLATFORM_WINDOWS
};
+// TODO(okkwon): move this to a place that can be used by other drivers.
+static const char* kMPILoaderSearchNames[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+ "mpi.dll",
+#else
+ "libmpi.so",
+#endif // IREE_PLATFORM_WINDOWS
+};
+
#define concat(A, B) A B
// Load CUDA entry points, prefer _v2 version if it exists.
@@ -46,10 +55,12 @@
}
#define NCCL_PFN_DECL(ncclSymbolName, ...)
#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#define MPI_PFN_DECL(mpiSymbolName, ...)
#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h" // IWYU pragma: keep
#undef CU_PFN_DECL
#undef NCCL_PFN_DECL
#undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
return iree_ok_status();
}
@@ -69,10 +80,32 @@
IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
syms->nccl_library, kName, (void**)&syms->ncclSymbolName)); \
}
+#define MPI_PFN_DECL(mpiSymbolName, ...)
#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h" // IWYU pragma: keep
#undef CU_PFN_DECL
#undef NCCL_PFN_DECL
#undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
+ return iree_ok_status();
+}
+
+// Load MPI entry points.
+static iree_status_t iree_hal_mpi_dynamic_symbols_resolve_all(
+ iree_hal_cuda_dynamic_symbols_t* syms) {
+#define CU_PFN_DECL(cudaSymbolName, ...)
+#define NCCL_PFN_DECL(ncclSymbolName, ...)
+#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#define MPI_PFN_DECL(mpiSymbolName, ...) \
+ { \
+ static const char* kName = #mpiSymbolName; \
+ IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \
+ syms->mpi_library, kName, (void**)&syms->mpiSymbolName)); \
+ }
+#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h" // IWYU pragma: keep
+#undef CU_PFN_DECL
+#undef NCCL_PFN_DECL
+#undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
return iree_ok_status();
}
@@ -163,11 +196,39 @@
return status;
}
+// FIXME(okkwon): it is unrelated to CUDA, but iree_hal_cuda_dynamic_symbols_t
+// is used.
+iree_status_t iree_hal_mpi_dynamic_symbols_initialize(
+ iree_allocator_t host_allocator,
+ iree_hal_cuda_dynamic_symbols_t* out_syms) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ out_syms->mpi_library = NULL;
+ iree_status_t status = iree_dynamic_library_load_from_files(
+ IREE_ARRAYSIZE(kMPILoaderSearchNames), kMPILoaderSearchNames,
+ IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->mpi_library);
+ if (iree_status_is_not_found(status)) {
+ iree_status_ignore(status);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_UNAVAILABLE,
+ "MPI runtime library not available; ensure "
+ "installed and on path");
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_mpi_dynamic_symbols_resolve_all(out_syms);
+ }
+ if (!iree_status_is_ok(status)) {
+ iree_hal_cuda_dynamic_symbols_deinitialize(out_syms);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
void iree_hal_cuda_dynamic_symbols_deinitialize(
iree_hal_cuda_dynamic_symbols_t* syms) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_dynamic_library_release(syms->cuda_library);
iree_dynamic_library_release(syms->nccl_library);
+ iree_dynamic_library_release(syms->mpi_library);
memset(syms, 0, sizeof(*syms));
IREE_TRACE_ZONE_END(z0);
}
diff --git a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
index 6fc8e98..8c95867 100644
--- a/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
+++ b/runtime/src/iree/hal/drivers/cuda/dynamic_symbols.h
@@ -10,7 +10,6 @@
#include "iree/base/api.h"
#include "iree/base/internal/dynamic_library.h"
#include "iree/hal/drivers/cuda/cuda_headers.h"
-#include "third_party/nccl/nccl.h"
#ifdef __cplusplus
extern "C" {
@@ -23,6 +22,7 @@
typedef struct iree_hal_cuda_dynamic_symbols_t {
iree_dynamic_library_t* cuda_library;
iree_dynamic_library_t* nccl_library;
+ iree_dynamic_library_t* mpi_library;
#define CU_PFN_DECL(cudaSymbolName, ...) \
CUresult (*cudaSymbolName)(__VA_ARGS__);
@@ -30,10 +30,12 @@
ncclResult_t (*ncclSymbolName)(__VA_ARGS__);
#define NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...) \
const char* (*ncclSymbolName)(__VA_ARGS__);
+#define MPI_PFN_DECL(mpiSymbolName, ...) int (*mpiSymbolName)(__VA_ARGS__);
#include "iree/hal/drivers/cuda/dynamic_symbol_tables.h" // IWYU pragma: export
#undef CU_PFN_DECL
#undef NCCL_PFN_DECL
#undef NCCL_PFN_DECL_STR_RETURN
+#undef MPI_PFN_DECL
} iree_hal_cuda_dynamic_symbols_t;
// Initializes |out_syms| in-place with dynamically loaded CUDA symbols.
@@ -48,6 +50,12 @@
iree_status_t iree_hal_cuda_nccl_dynamic_symbols_initialize(
iree_allocator_t host_allocator, iree_hal_cuda_dynamic_symbols_t* out_syms);
+// Initializes |out_syms| in-place with dynamically loaded MPI symbols.
+// iree_hal_cuda_dynamic_symbols_deinitialize must be used to release the
+// library resources.
+iree_status_t iree_hal_mpi_dynamic_symbols_initialize(
+ iree_allocator_t host_allocator, iree_hal_cuda_dynamic_symbols_t* out_syms);
+
// Deinitializes |syms| by unloading the backing library. All function pointers
// will be invalidated. They _may_ still work if there are other reasons the
// library remains loaded so be careful.
diff --git a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
index d3ecb55..dfd625f 100644
--- a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
+++ b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
@@ -36,44 +36,6 @@
IREE_FLAG(int32_t, cuda_nccl_default_count, 0,
"Participant count of the default collective group");
-// Default implementation of the collective channel provider that just uses the
-// NCCL_COMM_ID environment variable for configuration. Hosting layers would
-// want to use their own implementation to exchange IDs.
-static iree_status_t iree_hal_cuda_nccl_query_group_params(
- void* self, iree_hal_device_t* device,
- iree_hal_queue_affinity_t queue_affinity, iree_byte_span_t id_storage,
- iree_hal_channel_params_t* params) {
- IREE_ASSERT_EQ(id_storage.data_length, sizeof(iree_hal_cuda_nccl_id_t));
-
- // Users can either specify a specific rank or allow this device
- // implementation to decide. This allows us to run the same programs acting as
- // different ranks by setting flags/environment variables/API options/etc.
- if (params->rank == IREE_HAL_CHANNEL_RANK_DEFAULT) {
- params->rank = FLAG_cuda_nccl_default_rank;
- }
- if (params->count == IREE_HAL_CHANNEL_COUNT_DEFAULT) {
- params->count = FLAG_cuda_nccl_default_count;
- }
-
- // Let NCCL configure itself and return the ID to use.
- //
- // HACK: this may not be correct and should only be used for testing.
- // TODO(benvanik): a string form we can use and a flag.
- if (iree_const_byte_span_is_empty(params->id)) {
- if (!getenv("NCCL_COMM_ID")) {
- return iree_make_status(
- IREE_STATUS_INVALID_ARGUMENT,
- "the NCCL_COMM_ID environment variable must be set "
- "when using the default NCCL configuration");
- }
- iree_hal_cuda_nccl_id_t* id = (iree_hal_cuda_nccl_id_t*)id_storage.data;
- IREE_RETURN_IF_ERROR(iree_hal_cuda_nccl_get_unique_id(device, id));
- params->id = iree_const_cast_byte_span(id_storage);
- }
-
- return iree_ok_status();
-}
-
static iree_status_t iree_hal_cuda_driver_factory_enumerate(
void* self, iree_host_size_t* out_driver_info_count,
const iree_hal_driver_info_t** out_driver_infos) {
@@ -107,19 +69,16 @@
}
default_params.allow_inline_execution = FLAG_cuda_allow_inline_execution;
default_params.stream_tracing = FLAG_cuda_tracing;
-
- // Only setup channels if we're running collectives. Setting this will require
- // NCCL to be available at runtime.
- if (FLAG_cuda_nccl_default_count != 0) {
- default_params.channel_provider = (iree_hal_channel_provider_t){
- .self = NULL,
- .query_group_params = iree_hal_cuda_nccl_query_group_params,
- };
- }
+ default_params.channel_provider = (iree_hal_channel_provider_t){
+ .self = NULL,
+ .query_group_params = iree_hal_cuda_nccl_query_group_params,
+ };
iree_hal_cuda_driver_options_t driver_options;
iree_hal_cuda_driver_options_initialize(&driver_options);
driver_options.default_device_index = FLAG_cuda_default_index;
+ driver_options.default_rank = FLAG_cuda_nccl_default_rank;
+ driver_options.default_count = FLAG_cuda_nccl_default_count;
iree_status_t status =
iree_hal_cuda_driver_create(driver_name, &default_params, &driver_options,
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.c b/runtime/src/iree/hal/drivers/cuda/status_util.c
index 2e5eb1b..1e2cd66 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.c
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.c
@@ -68,3 +68,19 @@
result,
syms->ncclGetErrorString(result));
}
+
+iree_status_t iree_hal_mpi_result_to_status(
+ iree_hal_cuda_dynamic_symbols_t* syms, int result, const char* file,
+ uint32_t line) {
+ iree_status_code_t code;
+
+ switch (result) {
+ case 0: // MPI_SUCCESS
+ return iree_ok_status();
+ default:
+ code = IREE_STATUS_INTERNAL;
+ break;
+ }
+ return iree_make_status_with_location(file, line, code, "MPI error %d",
+ result);
+}
diff --git a/runtime/src/iree/hal/drivers/cuda/status_util.h b/runtime/src/iree/hal/drivers/cuda/status_util.h
index d860c43..b81a826 100644
--- a/runtime/src/iree/hal/drivers/cuda/status_util.h
+++ b/runtime/src/iree/hal/drivers/cuda/status_util.h
@@ -78,6 +78,37 @@
IREE_IGNORE_ERROR(iree_hal_nccl_result_to_status((syms), ((syms)->expr), \
__FILE__, __LINE__))
+// Converts a mpi result to an iree_status_t.
+//
+// Usage:
+// iree_status_t status = MPI_RESULT_TO_STATUS(mpiDoThing(...));
+#define MPI_RESULT_TO_STATUS(syms, expr, ...) \
+ iree_hal_mpi_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// Converts a mpi result to a Status object.
+iree_status_t iree_hal_mpi_result_to_status(
+ iree_hal_cuda_dynamic_symbols_t* syms, int result, const char* file,
+ uint32_t line);
+
+// IREE_RETURN_IF_ERROR but implicitly converts the mpi return value to
+// a Status.
+//
+// Usage:
+// MPI_RETURN_IF_ERROR(mpiDoThing(...), "message");
+#define MPI_RETURN_IF_ERROR(syms, expr, ...) \
+ IREE_RETURN_IF_ERROR(iree_hal_mpi_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__), \
+ __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the mpi return value to a
+// Status.
+//
+// Usage:
+// MPI_IGNORE_ERROR(mpiDoThing(...));
+#define MPI_IGNORE_ERROR(syms, expr) \
+ IREE_IGNORE_ERROR(iree_hal_mpi_result_to_status((syms), ((syms)->expr), \
+ __FILE__, __LINE__))
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus