From 2a7fea33de6b8a3640c718200888d62932de0db9 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 30 Apr 2026 15:11:13 -0700 Subject: [PATCH 1/2] Use ref-count based handling of NVSHMEM initialization state. Signed-off-by: Josh Romero --- include/internal/comm_routines.h | 6 ++--- include/internal/common.h | 12 ++++++--- src/cudecomp.cc | 44 +++++++++++++++++--------------- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h index e56d531..36388ea 100644 --- a/include/internal/comm_routines.h +++ b/include/internal/comm_routines.h @@ -248,7 +248,7 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD } #ifdef ENABLE_NVSHMEM - if (handle->rank == 0 && handle->nvshmem_initialized && !handle->nvshmem_mixed_buffer_warning_issued && + if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued && transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) && (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) { printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " @@ -429,7 +429,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc } #ifdef ENABLE_NVSHMEM - if (handle->rank == 0 && handle->nvshmem_initialized && !handle->nvshmem_mixed_buffer_warning_issued && + if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued && transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) && (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) { printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " @@ -625,7 +625,7 @@ static void cudecompSendRecvPair(const cudecompHandle_t& handle, const cudecompG } #ifdef ENABLE_NVSHMEM - if (handle->rank == 0 && handle->nvshmem_initialized && !handle->nvshmem_mixed_buffer_warning_issued && + if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued && haloBackendRequiresMpi(grid_desc->config.halo_comm_backend) && (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) { printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " diff --git a/include/internal/common.h b/include/internal/common.h index 091e900..cce69a7 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -54,6 +54,12 @@ typedef std::pair, unsigned typedef std::pair, unsigned int> mnnvl_info; #endif typedef std::shared_ptr ncclComm; +struct nvshmemRuntimeState { +#ifdef ENABLE_NVSHMEM + ~nvshmemRuntimeState() { nvshmem_finalize(); } +#endif +}; +typedef std::shared_ptr nvshmemRuntime; } // namespace cudecomp // cuDecomp handle containing general information @@ -92,8 +98,7 @@ struct cudecompHandle { bool initialized = false; // Entries for NVSHMEM management and warning generation - bool nvshmem_initialized = false; // Flag to track NVSHMEM initialization - int n_grid_descs_using_nvshmem = 0; // Count of grid descriptors using NVSHMEM + cudecomp::nvshmemRuntime nvshmem_runtime; // Shared reference to initialized NVSHMEM runtime bool nvshmem_mixed_buffer_warning_issued = false; // Warn once if NVSHMEM buffer is used with MPI size_t nvshmem_symmetric_size; // NVSHMEM symmetric size bool nvshmem_vmm; // Flag to track if NVSHMEM is using VMM allocations @@ -192,7 +197,8 @@ struct cudecompGridDesc { cudaEvent_t nvshmem_sync_event = nullptr; // NVSHMEM event used for synchronization #ifdef ENABLE_NVSHMEM - int* nvshmem_block_counters = nullptr; // device memory counters for SM alltoallv last-block detection + int* nvshmem_block_counters = nullptr; // device memory counters for SM alltoallv last-block detection + cudecomp::nvshmemRuntime nvshmem_runtime; // Shared reference to initialized NVSHMEM runtime #endif cudecomp::graphCache graph_cache; // CUDA graph cache diff --git a/src/cudecomp.cc b/src/cudecomp.cc index b2283d2..f8e19bf 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -492,6 +492,13 @@ static void checkNvshmemVersion(cudecompHandle_t& handle) { setenv("NVSHMEM_CUMEM_GRANULARITY", "2147483648", 1); } } + +static nvshmemRuntime createNvshmemRuntime(cudecompHandle_t& handle) { + checkNvshmemVersion(handle); + inspectNvshmemEnvVars(handle); + initNvshmemFromMPIComm(handle->mpi_comm); + return std::make_shared(); +} #endif } // namespace @@ -628,12 +635,9 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { CHECK_CUDA(cudaStreamDestroy(stream)); } #ifdef ENABLE_NVSHMEM - if (handle->nvshmem_initialized) { - nvshmem_finalize(); - handle->nvshmem_initialized = false; - handle->nvshmem_allocations.clear(); - handle->nvshmem_allocation_size = 0; - } + handle->nvshmem_allocations.clear(); + handle->nvshmem_allocation_size = 0; + handle->nvshmem_runtime.reset(); #endif CHECK_MPI(MPI_Comm_free(&handle->mpi_local_comm)); @@ -781,11 +785,8 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes if (transposeBackendRequiresNvshmem(comm_backend) || haloBackendRequiresNvshmem(halo_comm_backend) || ((autotune_transpose_backend || autotune_halo_backend) && !autotune_disable_nvshmem_backends)) { #ifdef ENABLE_NVSHMEM - if (!handle->nvshmem_initialized) { - checkNvshmemVersion(handle); - inspectNvshmemEnvVars(handle); - initNvshmemFromMPIComm(handle->mpi_comm); - handle->nvshmem_initialized = true; + if (!handle->nvshmem_runtime) { + handle->nvshmem_runtime = createNvshmemRuntime(handle); handle->nvshmem_allocation_size = 0; } #endif @@ -850,13 +851,14 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes cudaMemset(grid_desc->col_comm_info.nvshmem_signals, 0, grid_desc->col_comm_info.nranks * sizeof(uint64_t))); CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int))); CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int))); - handle->n_grid_descs_using_nvshmem++; + + // Set grid descriptor reference to NVSHMEM runtime + grid_desc->nvshmem_runtime = handle->nvshmem_runtime; } else { - // Finalize nvshmem to reclaim symmetric heap memory if not used - if (handle->nvshmem_initialized && handle->n_grid_descs_using_nvshmem == 0) { - nvshmem_finalize(); - handle->nvshmem_initialized = false; + // If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources + if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) { handle->nvshmem_allocations.clear(); + handle->nvshmem_runtime.reset(); handle->nvshmem_allocation_size = 0; } } @@ -981,13 +983,13 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe nvshmem_free(grid_desc->col_comm_info.nvshmem_signals); } CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters)); - handle->n_grid_descs_using_nvshmem--; + // Release grid descriptor reference to NVSHMEM runtime + grid_desc->nvshmem_runtime.reset(); - // Finalize nvshmem to reclaim symmetric heap memory if not used - if (handle->nvshmem_initialized && handle->n_grid_descs_using_nvshmem == 0) { - nvshmem_finalize(); - handle->nvshmem_initialized = false; + // If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources + if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) { handle->nvshmem_allocations.clear(); + handle->nvshmem_runtime.reset(); handle->nvshmem_allocation_size = 0; } } From c224515ba0730f54e97cea64abdc1d86aa312241 Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Thu, 30 Apr 2026 16:08:18 -0700 Subject: [PATCH 2/2] Add explicit NVSHMEM finalization call in cudecompFree to override dangling references. Signed-off-by: Josh Romero --- include/internal/common.h | 11 ++++++++++- src/cudecomp.cc | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/include/internal/common.h b/include/internal/common.h index cce69a7..d763cdd 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -56,7 +56,16 @@ typedef std::pair, unsigned int> mnnvl_info; typedef std::shared_ptr ncclComm; struct nvshmemRuntimeState { #ifdef ENABLE_NVSHMEM - ~nvshmemRuntimeState() { nvshmem_finalize(); } + ~nvshmemRuntimeState() { finalize(); } + + void finalize() { + if (!finalized) { + nvshmem_finalize(); + finalized = true; + } + } + + bool finalized = false; #endif }; typedef std::shared_ptr nvshmemRuntime; diff --git a/src/cudecomp.cc b/src/cudecomp.cc index f8e19bf..53b71b2 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -635,6 +635,7 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { CHECK_CUDA(cudaStreamDestroy(stream)); } #ifdef ENABLE_NVSHMEM + if (handle->nvshmem_runtime) { handle->nvshmem_runtime->finalize(); } handle->nvshmem_allocations.clear(); handle->nvshmem_allocation_size = 0; handle->nvshmem_runtime.reset();