Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/internal/comm_routines.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD
}

#ifdef ENABLE_NVSHMEM
if (handle->rank == 0 && handle->nvshmem_initialized && !handle->nvshmem_mixed_buffer_warning_issued &&
if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued &&
transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) &&
(nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) {
printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is "
Expand Down Expand Up @@ -429,7 +429,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
}

#ifdef ENABLE_NVSHMEM
if (handle->rank == 0 && handle->nvshmem_initialized && !handle->nvshmem_mixed_buffer_warning_issued &&
if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued &&
transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) &&
(nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) {
printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is "
Expand Down Expand Up @@ -625,7 +625,7 @@ static void cudecompSendRecvPair(const cudecompHandle_t& handle, const cudecompG
}

#ifdef ENABLE_NVSHMEM
if (handle->rank == 0 && handle->nvshmem_initialized && !handle->nvshmem_mixed_buffer_warning_issued &&
if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued &&
haloBackendRequiresMpi(grid_desc->config.halo_comm_backend) &&
(nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) {
printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is "
Expand Down
21 changes: 18 additions & 3 deletions include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ typedef std::pair<std::array<unsigned char, NVML_GPU_FABRIC_UUID_LEN>, unsigned
typedef std::pair<std::array<unsigned char, 1>, unsigned int> mnnvl_info;
#endif
typedef std::shared_ptr<ncclComm_t> ncclComm;
struct nvshmemRuntimeState {
#ifdef ENABLE_NVSHMEM
~nvshmemRuntimeState() { finalize(); }

void finalize() {
if (!finalized) {
nvshmem_finalize();
finalized = true;
}
}

bool finalized = false;
#endif
};
typedef std::shared_ptr<nvshmemRuntimeState> nvshmemRuntime;
} // namespace cudecomp

// cuDecomp handle containing general information
Expand Down Expand Up @@ -92,8 +107,7 @@ struct cudecompHandle {
bool initialized = false;

// Entries for NVSHMEM management and warning generation
bool nvshmem_initialized = false; // Flag to track NVSHMEM initialization
int n_grid_descs_using_nvshmem = 0; // Count of grid descriptors using NVSHMEM
cudecomp::nvshmemRuntime nvshmem_runtime; // Shared reference to initialized NVSHMEM runtime
bool nvshmem_mixed_buffer_warning_issued = false; // Warn once if NVSHMEM buffer is used with MPI
size_t nvshmem_symmetric_size; // NVSHMEM symmetric size
bool nvshmem_vmm; // Flag to track if NVSHMEM is using VMM allocations
Expand Down Expand Up @@ -192,7 +206,8 @@ struct cudecompGridDesc {
cudaEvent_t nvshmem_sync_event = nullptr; // NVSHMEM event used for synchronization

#ifdef ENABLE_NVSHMEM
int* nvshmem_block_counters = nullptr; // device memory counters for SM alltoallv last-block detection
int* nvshmem_block_counters = nullptr; // device memory counters for SM alltoallv last-block detection
cudecomp::nvshmemRuntime nvshmem_runtime; // Shared reference to initialized NVSHMEM runtime
#endif

cudecomp::graphCache graph_cache; // CUDA graph cache
Expand Down
45 changes: 24 additions & 21 deletions src/cudecomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,13 @@ static void checkNvshmemVersion(cudecompHandle_t& handle) {
setenv("NVSHMEM_CUMEM_GRANULARITY", "2147483648", 1);
}
}

static nvshmemRuntime createNvshmemRuntime(cudecompHandle_t& handle) {
checkNvshmemVersion(handle);
inspectNvshmemEnvVars(handle);
initNvshmemFromMPIComm(handle->mpi_comm);
return std::make_shared<nvshmemRuntimeState>();
}
#endif

} // namespace
Expand Down Expand Up @@ -628,12 +635,10 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) {
CHECK_CUDA(cudaStreamDestroy(stream));
}
#ifdef ENABLE_NVSHMEM
if (handle->nvshmem_initialized) {
nvshmem_finalize();
handle->nvshmem_initialized = false;
handle->nvshmem_allocations.clear();
handle->nvshmem_allocation_size = 0;
}
if (handle->nvshmem_runtime) { handle->nvshmem_runtime->finalize(); }
handle->nvshmem_allocations.clear();
handle->nvshmem_allocation_size = 0;
handle->nvshmem_runtime.reset();
#endif
CHECK_MPI(MPI_Comm_free(&handle->mpi_local_comm));

Expand Down Expand Up @@ -781,11 +786,8 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes
if (transposeBackendRequiresNvshmem(comm_backend) || haloBackendRequiresNvshmem(halo_comm_backend) ||
((autotune_transpose_backend || autotune_halo_backend) && !autotune_disable_nvshmem_backends)) {
#ifdef ENABLE_NVSHMEM
if (!handle->nvshmem_initialized) {
checkNvshmemVersion(handle);
inspectNvshmemEnvVars(handle);
initNvshmemFromMPIComm(handle->mpi_comm);
handle->nvshmem_initialized = true;
if (!handle->nvshmem_runtime) {
handle->nvshmem_runtime = createNvshmemRuntime(handle);
handle->nvshmem_allocation_size = 0;
}
#endif
Expand Down Expand Up @@ -850,13 +852,14 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes
cudaMemset(grid_desc->col_comm_info.nvshmem_signals, 0, grid_desc->col_comm_info.nranks * sizeof(uint64_t)));
CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int)));
CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int)));
handle->n_grid_descs_using_nvshmem++;

// Set grid descriptor reference to NVSHMEM runtime
grid_desc->nvshmem_runtime = handle->nvshmem_runtime;
} else {
// Finalize nvshmem to reclaim symmetric heap memory if not used
if (handle->nvshmem_initialized && handle->n_grid_descs_using_nvshmem == 0) {
nvshmem_finalize();
handle->nvshmem_initialized = false;
// If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources
if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) {
handle->nvshmem_allocations.clear();
handle->nvshmem_runtime.reset();
handle->nvshmem_allocation_size = 0;
}
}
Expand Down Expand Up @@ -981,13 +984,13 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe
nvshmem_free(grid_desc->col_comm_info.nvshmem_signals);
}
CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters));
handle->n_grid_descs_using_nvshmem--;
// Release grid descriptor reference to NVSHMEM runtime
grid_desc->nvshmem_runtime.reset();

// Finalize nvshmem to reclaim symmetric heap memory if not used
if (handle->nvshmem_initialized && handle->n_grid_descs_using_nvshmem == 0) {
nvshmem_finalize();
handle->nvshmem_initialized = false;
// If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources
if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) {
handle->nvshmem_allocations.clear();
handle->nvshmem_runtime.reset();
handle->nvshmem_allocation_size = 0;
}
}
Expand Down
Loading