| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
#include "iree/hal/utils/mpi_channel_provider.h"

#include <stdlib.h>
#include <string.h>

#include "iree/hal/utils/libmpi.h"
| |
// Returns true if |var_name| is set to a non-empty value in the environment.
static bool iree_hal_mpi_env_is_set(const char* var_name) {
  const char* var_value = getenv(var_name);
  // Checking the first character is enough to detect "non-empty" and avoids a
  // strlen scan (and the <string.h> dependency the original check relied on).
  return var_value != NULL && var_value[0] != '\0';
}
| |
| // For now this simply checks that the world size is set. We could verify |
| // more of the environment but if a user is partially configuring things |
| // manually YMMV. |
| IREE_API_EXPORT bool iree_hal_mpi_is_configured(void) { |
| // TODO: find a better approach that is more portable across implementations. |
| // PMI_RANK/PMI_SIZE seem common ones (MS/Intel/Slurm) but OpenMPI uses their |
| // own. |
| return iree_hal_mpi_env_is_set("PMI_SIZE") || |
| iree_hal_mpi_env_is_set("OMPI_COMM_WORLD_SIZE") || |
| iree_hal_mpi_env_is_set("MPIEXEC_HOSTNAME"); |
| } |
| |
// MPI-backed channel provider state.
// Loads the MPI shared library at creation time and (optionally) owns the
// process-wide MPI context.
typedef struct iree_hal_mpi_channel_provider_t {
  iree_hal_resource_t resource;
  // Allocator used for this struct; retained so destroy can free it.
  iree_allocator_t host_allocator;

  // Dynamically-loaded MPI implementation and its resolved symbol table.
  iree_dynamic_library_t* library;
  iree_hal_mpi_dynamic_symbols_t symbols;

  // MPI_Init and MPI_Finalize must be called exactly once in an application.
  // It may be the case that the user of IREE has already initialized MPI.
  // Deciding if we are owners of the context based on whether MPI is already
  // initialized or not is not ideal since other parts of the user's application
  // that use MPI will depend on the channel provider lifespan.
  // It may be better to let the user be the owner of MPI's context.
  bool is_mpi_context_owner;
} iree_hal_mpi_channel_provider_t;
| |
// Forward-declared so the cast/type-check helpers can reference the vtable
// before its definition at the bottom of the file.
static const iree_hal_channel_provider_vtable_t
    iree_hal_mpi_channel_provider_vtable;

// Downcasts |base_value| to the MPI provider type, asserting (in debug
// builds) that the resource actually carries the MPI vtable.
static iree_hal_mpi_channel_provider_t* iree_hal_mpi_channel_provider_cast(
    iree_hal_channel_provider_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_mpi_channel_provider_vtable);
  return (iree_hal_mpi_channel_provider_t*)base_value;
}
| |
| IREE_API_EXPORT iree_status_t iree_hal_mpi_channel_provider_create( |
| iree_allocator_t host_allocator, |
| iree_hal_channel_provider_t** out_channel_provider) { |
| IREE_ASSERT_ARGUMENT(out_channel_provider); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_hal_mpi_channel_provider_t* channel_provider = NULL; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_allocator_malloc(host_allocator, sizeof(*channel_provider), |
| (void**)&channel_provider)); |
| iree_hal_resource_initialize(&iree_hal_mpi_channel_provider_vtable, |
| &channel_provider->resource); |
| channel_provider->host_allocator = host_allocator; |
| |
| // Attempt to load the shared library. This will fail if it's not found, |
| // not compatible with the process (wrong arch), or missing symbols (out of |
| // date). |
| iree_status_t status = iree_hal_mpi_library_load( |
| host_allocator, &channel_provider->library, &channel_provider->symbols); |
| |
| // If the library successfully loaded then try to initialize MPI. |
| int is_mpi_initialized_already; |
| if (iree_status_is_ok(status)) { |
| IREE_TRACE_ZONE_BEGIN_NAMED(z1, "MPI_Initialized"); |
| status = MPI_RESULT_TO_STATUS(&channel_provider->symbols, |
| MPI_Initialized(&is_mpi_initialized_already), |
| "MPI_Initialized"); |
| IREE_TRACE_ZONE_END(z1); |
| } |
| if (iree_status_is_ok(status)) { |
| if (!is_mpi_initialized_already) { |
| IREE_TRACE_ZONE_BEGIN_NAMED(z2, "MPI_Init"); |
| status = MPI_RESULT_TO_STATUS(&channel_provider->symbols, |
| MPI_Init(NULL, NULL), "MPI_Init"); |
| IREE_TRACE_ZONE_END(z2); |
| } |
| channel_provider->is_mpi_context_owner = !is_mpi_initialized_already; |
| } |
| |
| if (iree_status_is_ok(status)) { |
| *out_channel_provider = (iree_hal_channel_provider_t*)channel_provider; |
| } else { |
| iree_hal_channel_provider_release( |
| (iree_hal_channel_provider_t*)channel_provider); |
| } |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| // Returns true if MPI has been initialized. |
| static bool iree_hal_mpi_channel_provider_is_initialized( |
| iree_hal_mpi_channel_provider_t* channel_provider) { |
| if (!channel_provider->library) return false; |
| int flag = 0; |
| MPI_IGNORE_ERROR(&channel_provider->symbols, MPI_Initialized(&flag)); |
| return flag ? true : false; |
| } |
| |
// Tears down the provider: finalizes MPI (only if we initialized it), scrubs
// the symbol table, unloads the library, and frees the struct — strictly in
// that order.
static void iree_hal_mpi_channel_provider_destroy(
    iree_hal_channel_provider_t* base_channel_provider) {
  iree_hal_mpi_channel_provider_t* channel_provider =
      iree_hal_mpi_channel_provider_cast(base_channel_provider);
  // Copied out before the struct is freed below.
  iree_allocator_t host_allocator = channel_provider->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  // We must finalize MPI before unloading the library.
  // NOTE: once finalized MPI can never be used in the process again!
  if (channel_provider->is_mpi_context_owner &&
      iree_hal_mpi_channel_provider_is_initialized(channel_provider)) {
    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "MPI_Finalize");
    MPI_IGNORE_ERROR(&channel_provider->symbols, MPI_Finalize());
    IREE_TRACE_ZONE_END(z1);
  }

  // Reset the symbols in case anyone is hanging on to them. ASAN should
  // complain anyway but it's not available everywhere.
  memset(&channel_provider->symbols, 0, sizeof(channel_provider->symbols));

  iree_dynamic_library_release(channel_provider->library);

  iree_allocator_free(host_allocator, channel_provider);

  IREE_TRACE_ZONE_END(z0);
}
| |
// Returns true if |channel_provider| is an MPI channel provider (i.e. it
// carries the MPI provider vtable).
IREE_API_EXPORT bool iree_hal_mpi_channel_provider_isa(
    iree_hal_channel_provider_t* channel_provider) {
  return iree_hal_resource_is(channel_provider,
                              &iree_hal_mpi_channel_provider_vtable);
}
| |
| IREE_API_EXPORT iree_hal_mpi_dynamic_symbols_t* |
| iree_hal_mpi_channel_provider_symbols( |
| iree_hal_channel_provider_t* base_channel_provider) { |
| IREE_ASSERT_ARGUMENT(base_channel_provider); |
| if (!iree_hal_mpi_channel_provider_isa(base_channel_provider)) return NULL; |
| iree_hal_mpi_channel_provider_t* channel_provider = |
| iree_hal_mpi_channel_provider_cast(base_channel_provider); |
| return &channel_provider->symbols; |
| } |
| |
| static iree_status_t iree_hal_mpi_channel_provider_query_default_rank_and_count( |
| iree_hal_channel_provider_t* base_channel_provider, int32_t* out_rank, |
| int32_t* out_count) { |
| iree_hal_mpi_channel_provider_t* channel_provider = |
| iree_hal_mpi_channel_provider_cast(base_channel_provider); |
| |
| static_assert(sizeof(int32_t) == sizeof(int), "MPI uses int"); |
| MPI_RETURN_IF_ERROR( |
| &channel_provider->symbols, |
| MPI_Comm_rank(IREE_MPI_COMM_WORLD(&channel_provider->symbols), |
| (int*)out_rank), |
| "MPI_Comm_rank"); |
| MPI_RETURN_IF_ERROR( |
| &channel_provider->symbols, |
| MPI_Comm_size(IREE_MPI_COMM_WORLD(&channel_provider->symbols), |
| (int*)out_count), |
| "MPI_Comm_size"); |
| |
| return iree_ok_status(); |
| } |
| |
| IREE_API_EXPORT iree_status_t iree_hal_mpi_channel_provider_exchange_default_id( |
| iree_hal_channel_provider_t* base_channel_provider, iree_byte_span_t id) { |
| iree_hal_mpi_channel_provider_t* channel_provider = |
| iree_hal_mpi_channel_provider_cast(base_channel_provider); |
| |
| // Exchange the ID with all other participants. The root participant will |
| // send its ID while the others will receive it. |
| MPI_RETURN_IF_ERROR( |
| &channel_provider->symbols, |
| MPI_Bcast(id.data, id.data_length, |
| IREE_MPI_BYTE(&channel_provider->symbols), 0, |
| IREE_MPI_COMM_WORLD(&channel_provider->symbols)), |
| "MPI_Bcast"); |
| |
| return iree_ok_status(); |
| } |
| |
// Vtable binding the MPI implementation to the generic channel provider
// interface (forward-declared near the top of the file).
static const iree_hal_channel_provider_vtable_t
    iree_hal_mpi_channel_provider_vtable = {
        .destroy = iree_hal_mpi_channel_provider_destroy,
        .query_default_rank_and_count =
            iree_hal_mpi_channel_provider_query_default_rank_and_count,
        .exchange_default_id =
            iree_hal_mpi_channel_provider_exchange_default_id,
};