Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 6 additions & 24 deletions include/internal/comm_routines.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,14 +281,8 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD
}

#ifdef ENABLE_NVSHMEM
if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued &&
transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) &&
(nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) {
printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is "
"being used with an MPI communication backend. This may cause issues with some MPI "
"implementations. See the documentation for additional details and possible workarounds "
"if you encounter issues.\n");
handle->nvshmem_mixed_buffer_warning_issued = true;
if (transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend)) {
warnIfNvshmemBufferUsedWithMpi(handle, send_buff, recv_buff);
}
#endif

Expand Down Expand Up @@ -470,14 +464,8 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
}

#ifdef ENABLE_NVSHMEM
if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued &&
transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) &&
(nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) {
printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is "
"being used with an MPI communication backend. This may cause issues with some MPI "
"implementations. See the documentation for additional details and possible workarounds "
"if you encounter issues.\n");
handle->nvshmem_mixed_buffer_warning_issued = true;
if (transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend)) {
warnIfNvshmemBufferUsedWithMpi(handle, send_buff, recv_buff);
}
#endif

Expand Down Expand Up @@ -666,14 +654,8 @@ static void cudecompSendRecvPair(const cudecompHandle_t& handle, const cudecompG
}

#ifdef ENABLE_NVSHMEM
if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued &&
haloBackendRequiresMpi(grid_desc->config.halo_comm_backend) &&
(nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) {
printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is "
"being used with an MPI communication backend. This may cause issues with some MPI "
"implementations. See the documentation for additional details and possible workarounds "
"if you encounter issues.\n");
handle->nvshmem_mixed_buffer_warning_issued = true;
if (haloBackendRequiresMpi(grid_desc->config.halo_comm_backend)) {
warnIfNvshmemBufferUsedWithMpi(handle, send_buff, recv_buff);
}
#endif

Expand Down
36 changes: 23 additions & 13 deletions include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,36 @@ typedef std::pair<std::array<unsigned char, 1>, unsigned int> mnnvl_info;
typedef std::shared_ptr<ncclComm_t> ncclComm;
struct nvshmemRuntimeState {
#ifdef ENABLE_NVSHMEM
~nvshmemRuntimeState() { finalize(); }
~nvshmemRuntimeState() noexcept { finalize(); }

void finalize() {
if (!finalized) {
void finalize() noexcept {
if (initialized) {
nvshmem_finalize();
finalized = true;
initialized = false;
}
nvshmem_allocations.clear();
nvshmem_allocation_size = 0;
}

bool finalized = false;
bool initialized = false; // Flag to track NVSHMEM initialization
size_t nvshmem_symmetric_size = 0; // NVSHMEM symmetric size
bool nvshmem_vmm = true; // Flag to track if NVSHMEM is using VMM allocations
std::unordered_map<void*, size_t> nvshmem_allocations; // Table to record NVSHMEM allocations
size_t nvshmem_allocation_size = 0; // Total of NVSHMEM allocations
#endif
};
typedef std::shared_ptr<nvshmemRuntimeState> nvshmemRuntime;
#ifdef ENABLE_NVSHMEM
struct nvshmemProcessState {
// NVSHMEM fixes the PE mapping at initialization and does not support reinitializing
// on a different set or order of ranks, even after nvshmem_finalize().
std::vector<int> init_world_ranks;
std::weak_ptr<nvshmemRuntimeState> active_runtime;
bool mixed_buffer_warning_issued = false;
};

void warnIfNvshmemBufferUsedWithMpi(cudecompHandle_t handle, const void* send_buff, const void* recv_buff);
#endif
} // namespace cudecomp

// cuDecomp handle containing general information
Expand Down Expand Up @@ -116,14 +133,6 @@ struct cudecompHandle {

bool initialized = false;

// Entries for NVSHMEM management and warning generation
cudecomp::nvshmemRuntime nvshmem_runtime; // Shared reference to initialized NVSHMEM runtime
bool nvshmem_mixed_buffer_warning_issued = false; // Warn once if NVSHMEM buffer is used with MPI
size_t nvshmem_symmetric_size; // NVSHMEM symmetric size
bool nvshmem_vmm; // Flag to track if NVSHMEM is using VMM allocations
std::unordered_map<void*, size_t> nvshmem_allocations; // Table to record NVSHMEM allocations
size_t nvshmem_allocation_size = 0; // Total of NVSHMEM allocations

// Multi-node NVLINK (MNNVL)
bool cuda_cumem_enable = false; // Flag to control whether cuMem* APIs are used for cudecompMalloc/Free
std::vector<cudecomp::mnnvl_info> rank_to_mnnvl_info; // list of mnnvl information (clusterUuid, cliqueId) by rank
Expand Down Expand Up @@ -243,6 +252,7 @@ struct cudecompGridDesc {
#endif
}

cudecompHandle_t handle = nullptr; // owning cuDecomp handle
cudecompGridDescConfig_t config; // configuration struct
bool gdims_dist_set = false; // flag to record if gdims_dist was set to non-default values
bool transpose_mem_order_set = false; // flag to record if transpose_mem_order was set to non-default values
Expand Down
31 changes: 18 additions & 13 deletions src/cuda_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
* limitations under the License.
*/

#include <mutex>

#include <cuda_runtime.h>

#include "internal/checks.h"
Expand Down Expand Up @@ -48,21 +50,24 @@ namespace cudecomp {
cuFunctionTable cuFnTable; // global table of required CUDA driver functions

void initCuFunctionTable() {
static std::once_flag init_once;
std::call_once(init_once, []() {
#if CUDART_VERSION >= 11030
LOAD_SYM(cuDeviceGet, 2000);
LOAD_SYM(cuDeviceGetAttribute, 2000);
LOAD_SYM(cuGetErrorString, 6000);
LOAD_SYM(cuMemAddressFree, 10020);
LOAD_SYM(cuMemAddressReserve, 10020);
LOAD_SYM(cuMemCreate, 10020);
LOAD_SYM(cuMemGetAddressRange, 3020);
LOAD_SYM(cuMemGetAllocationGranularity, 10020);
LOAD_SYM(cuMemMap, 10020);
LOAD_SYM(cuMemRetainAllocationHandle, 11000);
LOAD_SYM(cuMemRelease, 10020);
LOAD_SYM(cuMemSetAccess, 10020);
LOAD_SYM(cuMemUnmap, 10020);
LOAD_SYM(cuDeviceGet, 2000);
LOAD_SYM(cuDeviceGetAttribute, 2000);
LOAD_SYM(cuGetErrorString, 6000);
LOAD_SYM(cuMemAddressFree, 10020);
LOAD_SYM(cuMemAddressReserve, 10020);
LOAD_SYM(cuMemCreate, 10020);
LOAD_SYM(cuMemGetAddressRange, 3020);
LOAD_SYM(cuMemGetAllocationGranularity, 10020);
LOAD_SYM(cuMemMap, 10020);
LOAD_SYM(cuMemRetainAllocationHandle, 11000);
LOAD_SYM(cuMemRelease, 10020);
LOAD_SYM(cuMemSetAccess, 10020);
LOAD_SYM(cuMemUnmap, 10020);
#endif
});
}

} // namespace cudecomp
Expand Down
Loading
Loading