diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h index f613f55..b132066 100644 --- a/include/internal/comm_routines.h +++ b/include/internal/comm_routines.h @@ -511,7 +511,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc #endif } case CUDECOMP_TRANSPOSE_COMM_NCCL_PL: { - auto comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; + const auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; // For fully intra-group alltoall, use distinct NCCL local comm instead of global comm as it is faster. auto comm = (comm_info.ngroups == 1) ? *grid_desc->nccl_local_comm : *grid_desc->nccl_comm; auto pl_stream = handle->streams[0].get(); diff --git a/include/internal/common.h b/include/internal/common.h index 62ac55e..675864f 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -74,6 +74,13 @@ typedef std::shared_ptr nvshmemRuntime; // cuDecomp handle containing general information struct cudecompHandle { + cudecompHandle() = default; + ~cudecompHandle() noexcept; + + cudecompHandle(const cudecompHandle&) = delete; + cudecompHandle& operator=(const cudecompHandle&) = delete; + cudecompHandle(cudecompHandle&&) = delete; + cudecompHandle& operator=(cudecompHandle&&) = delete; MPI_Comm mpi_comm = MPI_COMM_NULL; // MPI communicator int32_t rank; // MPI rank @@ -96,10 +103,12 @@ struct cudecompHandle { std::vector streams; // internal streams for concurrent scheduling - cutensorHandle_t cutensor_handle; // cuTENSOR handle; #if CUTENSOR_MAJOR >= 2 - cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference; - bool cutensor_needs_permute_chunking = false; // Flag to enable large tensor workaround + cutensorHandle_t cutensor_handle = nullptr; // cuTENSOR handle; + cutensorPlanPreference_t cutensor_plan_pref = nullptr; // cuTENSOR plan preference; + bool cutensor_needs_permute_chunking = false; // Flag to enable large tensor workaround +#else + cutensorHandle_t cutensor_handle; // cuTENSOR handle; #endif std::vector> hostnames; // list of hostnames by rank @@ -134,6 +143,7 @@ struct cudecompHandle { ""; // directory to write CSV performance reports, empty means no file writing // Miscellaneous + bool nvml_initialized = false; // Flag to track NVML initialization int32_t device_p2p_ce_count = 0; // number of P2P CEs available int32_t device_num_sms = 0; // number of SMs on the device int32_t device_max_threads_per_sm = 0; // maximum threads per SM @@ -142,9 +152,40 @@ struct cudecompHandle { // Structure with information about row/column communicator struct cudecompCommInfo { + cudecompCommInfo() = default; + ~cudecompCommInfo() noexcept { reset(); } + + cudecompCommInfo(const cudecompCommInfo&) = delete; + cudecompCommInfo& operator=(const cudecompCommInfo&) = delete; + cudecompCommInfo(cudecompCommInfo&&) = delete; + cudecompCommInfo& operator=(cudecompCommInfo&&) = delete; + + void reset() noexcept { + if (mpi_comm != MPI_COMM_NULL) { + MPI_Comm comm = mpi_comm; + mpi_comm = MPI_COMM_NULL; + MPI_Comm_free(&comm); + } +#ifdef ENABLE_NVSHMEM + if (nvshmem_team != NVSHMEM_TEAM_INVALID) { + nvshmem_team_destroy(nvshmem_team); + nvshmem_team = NVSHMEM_TEAM_INVALID; + } + if (nvshmem_signals) { + nvshmem_free(nvshmem_signals); + nvshmem_signals = nullptr; + } +#endif + rank = 0; + nranks = 0; + ngroups = 0; + npergroup = 0; + mnnvl_active = false; + } + MPI_Comm mpi_comm = MPI_COMM_NULL; - int32_t rank; - int32_t nranks; + int32_t rank = 0; + int32_t nranks = 0; int32_t ngroups = 0; // number of p2p groups (i.e. grouping of ranks with fast interconnect) in communicator int32_t npergroup = 0; // number of ranks per p2p group @@ -194,6 +235,14 @@ struct cudecompHaloPerformanceSampleCollection { // cuDecomp grid descriptor containing grid-specific information struct cudecompGridDesc { + ~cudecompGridDesc() noexcept { + row_comm_info.reset(); + col_comm_info.reset(); +#ifdef ENABLE_NVSHMEM + if (nvshmem_block_counters) { cudaFree(nvshmem_block_counters); } +#endif + } + cudecompGridDescConfig_t config; // configuration struct bool gdims_dist_set = false; // flag to record if gdims_dist was set to non-default values bool transpose_mem_order_set = false; // flag to record if transpose_mem_order was set to non-default values @@ -355,6 +404,7 @@ static void setCommInfo(cudecompHandle_t& handle, cudecompGridDesc_t& grid_desc, cudecompCommAxis comm_axis) { auto& info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; + info.reset(); info.mpi_comm = mpi_comm; CHECK_MPI(MPI_Comm_rank(info.mpi_comm, &info.rank)); CHECK_MPI(MPI_Comm_size(info.mpi_comm, &info.nranks)); @@ -420,6 +470,43 @@ static void setCommInfo(cudecompHandle_t& handle, cudecompGridDesc_t& grid_desc, info.ngroups = info.nranks / info.npergroup; } +static void createCommInfo(cudecompHandle_t& handle, cudecompGridDesc_t& grid_desc, bool need_nvshmem = false) { + grid_desc->row_comm_info.reset(); + grid_desc->col_comm_info.reset(); + + setProcessGridIndex(handle, grid_desc); + + MPI_Comm row_comm; + CHECK_MPI(MPI_Comm_split(handle->mpi_comm, grid_desc->pidx[0], handle->rank, &row_comm)); + setCommInfo(handle, grid_desc, row_comm, CUDECOMP_COMM_ROW); + + MPI_Comm col_comm; + CHECK_MPI(MPI_Comm_split(handle->mpi_comm, grid_desc->pidx[1], handle->rank, &col_comm)); + setCommInfo(handle, grid_desc, col_comm, CUDECOMP_COMM_COL); + +#ifdef ENABLE_NVSHMEM + if (need_nvshmem) { + nvshmem_team_config_t tmp; + nvshmem_team_split_2d(NVSHMEM_TEAM_WORLD, grid_desc->config.pdims[1], &tmp, 0, + &grid_desc->row_comm_info.nvshmem_team, &tmp, 0, &grid_desc->col_comm_info.nvshmem_team); + + grid_desc->row_comm_info.nvshmem_signals = + (uint64_t*)nvshmem_malloc(grid_desc->row_comm_info.nranks * sizeof(uint64_t)); + if (!grid_desc->row_comm_info.nvshmem_signals) { THROW_NVSHMEM_ERROR("nvshmem_malloc failed"); } + CHECK_CUDA( + cudaMemset(grid_desc->row_comm_info.nvshmem_signals, 0, grid_desc->row_comm_info.nranks * sizeof(uint64_t))); + + grid_desc->col_comm_info.nvshmem_signals = + (uint64_t*)nvshmem_malloc(grid_desc->col_comm_info.nranks * sizeof(uint64_t)); + if (!grid_desc->col_comm_info.nvshmem_signals) { THROW_NVSHMEM_ERROR("nvshmem_malloc failed"); } + CHECK_CUDA( + cudaMemset(grid_desc->col_comm_info.nvshmem_signals, 0, grid_desc->col_comm_info.nranks * sizeof(uint64_t))); + } +#else + if (need_nvshmem) { THROW_NOT_SUPPORTED("build does not support NVSHMEM communication backends."); } +#endif +} + static inline void getAlltoallPeerRanks(cudecompGridDesc_t grid_desc, cudecompCommAxis comm_axis, int iter, int& src_rank, int& dst_rank) { diff --git a/src/autotune.cc b/src/autotune.cc index 03b7f70..e2649f3 100644 --- a/src/autotune.cc +++ b/src/autotune.cc @@ -274,28 +274,9 @@ void autotuneTransposeBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_d } // Create test row and column communicators - int color_row = grid_desc->pidx[0]; - MPI_Comm row_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_row, handle->rank, &row_comm)); - setCommInfo(handle, grid_desc, row_comm, CUDECOMP_COMM_ROW); - - int color_col = grid_desc->pidx[1]; - MPI_Comm col_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_col, handle->rank, &col_comm)); - setCommInfo(handle, grid_desc, col_comm, CUDECOMP_COMM_COL); + createCommInfo(handle, grid_desc, need_nvshmem); if (need_nvshmem) { #ifdef ENABLE_NVSHMEM - nvshmem_team_config_t tmp; - nvshmem_team_split_2d(NVSHMEM_TEAM_WORLD, grid_desc->config.pdims[1], &tmp, 0, - &grid_desc->row_comm_info.nvshmem_team, &tmp, 0, &grid_desc->col_comm_info.nvshmem_team); - grid_desc->row_comm_info.nvshmem_signals = - (uint64_t*)nvshmem_malloc(grid_desc->row_comm_info.nranks * sizeof(uint64_t)); - CHECK_CUDA( - cudaMemset(grid_desc->row_comm_info.nvshmem_signals, 0, grid_desc->row_comm_info.nranks * sizeof(uint64_t))); - grid_desc->col_comm_info.nvshmem_signals = - (uint64_t*)nvshmem_malloc(grid_desc->col_comm_info.nranks * sizeof(uint64_t)); - CHECK_CUDA( - cudaMemset(grid_desc->col_comm_info.nvshmem_signals, 0, grid_desc->col_comm_info.nranks * sizeof(uint64_t))); CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int))); CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int))); #endif @@ -490,18 +471,15 @@ void autotuneTransposeBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_d } } - // Destroy test communicators - CHECK_MPI(MPI_Comm_free(&grid_desc->row_comm_info.mpi_comm)); - CHECK_MPI(MPI_Comm_free(&grid_desc->col_comm_info.mpi_comm)); + // Destroy test communicator resources + grid_desc->row_comm_info.reset(); + grid_desc->col_comm_info.reset(); if (need_nvshmem) { #ifdef ENABLE_NVSHMEM - nvshmem_team_destroy(grid_desc->row_comm_info.nvshmem_team); - nvshmem_team_destroy(grid_desc->col_comm_info.nvshmem_team); - nvshmem_free(grid_desc->row_comm_info.nvshmem_signals); - nvshmem_free(grid_desc->col_comm_info.nvshmem_signals); - CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters)); - grid_desc->row_comm_info.nvshmem_team = NVSHMEM_TEAM_INVALID; - grid_desc->col_comm_info.nvshmem_team = NVSHMEM_TEAM_INVALID; + if (grid_desc->nvshmem_block_counters) { + CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters)); + grid_desc->nvshmem_block_counters = nullptr; + } #endif } } @@ -720,28 +698,9 @@ void autotuneHaloBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_desc, } // Create test row and column communicators - int color_row = grid_desc->pidx[0]; - MPI_Comm row_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_row, handle->rank, &row_comm)); - setCommInfo(handle, grid_desc, row_comm, CUDECOMP_COMM_ROW); - - int color_col = grid_desc->pidx[1]; - MPI_Comm col_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_col, handle->rank, &col_comm)); - setCommInfo(handle, grid_desc, col_comm, CUDECOMP_COMM_COL); + createCommInfo(handle, grid_desc, need_nvshmem); if (need_nvshmem) { #ifdef ENABLE_NVSHMEM - nvshmem_team_config_t tmp; - nvshmem_team_split_2d(NVSHMEM_TEAM_WORLD, grid_desc->config.pdims[1], &tmp, 0, - &grid_desc->row_comm_info.nvshmem_team, &tmp, 0, &grid_desc->col_comm_info.nvshmem_team); - grid_desc->row_comm_info.nvshmem_signals = - (uint64_t*)nvshmem_malloc(grid_desc->row_comm_info.nranks * sizeof(uint64_t)); - CHECK_CUDA( - cudaMemset(grid_desc->row_comm_info.nvshmem_signals, 0, grid_desc->row_comm_info.nranks * sizeof(uint64_t))); - grid_desc->col_comm_info.nvshmem_signals = - (uint64_t*)nvshmem_malloc(grid_desc->col_comm_info.nranks * sizeof(uint64_t)); - CHECK_CUDA( - cudaMemset(grid_desc->col_comm_info.nvshmem_signals, 0, grid_desc->col_comm_info.nranks * sizeof(uint64_t))); CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int))); CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int))); #endif @@ -850,18 +809,15 @@ void autotuneHaloBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_desc, } } - // Destroy test communicators - CHECK_MPI(MPI_Comm_free(&grid_desc->row_comm_info.mpi_comm)); - CHECK_MPI(MPI_Comm_free(&grid_desc->col_comm_info.mpi_comm)); + // Destroy test communicator resources + grid_desc->row_comm_info.reset(); + grid_desc->col_comm_info.reset(); if (need_nvshmem) { #ifdef ENABLE_NVSHMEM - nvshmem_team_destroy(grid_desc->row_comm_info.nvshmem_team); - nvshmem_team_destroy(grid_desc->col_comm_info.nvshmem_team); - nvshmem_free(grid_desc->row_comm_info.nvshmem_signals); - nvshmem_free(grid_desc->col_comm_info.nvshmem_signals); - CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters)); - grid_desc->row_comm_info.nvshmem_team = NVSHMEM_TEAM_INVALID; - grid_desc->col_comm_info.nvshmem_team = NVSHMEM_TEAM_INVALID; + if (grid_desc->nvshmem_block_counters) { + CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters)); + grid_desc->nvshmem_block_counters = nullptr; + } #endif } } diff --git a/src/cudecomp.cc b/src/cudecomp.cc index 532c669..175fe34 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -501,9 +501,54 @@ static nvshmemRuntime createNvshmemRuntime(cudecompHandle_t& handle) { } #endif +static void releaseUnusedHandleResources(cudecompHandle_t handle, bool release_streams = false) noexcept { + if (!handle || !handle->initialized) { return; } + + // Destroy NCCL communicators to reclaim resources if not used by other grid descriptors + if (handle->nccl_comm && handle->nccl_comm.use_count() == 1) { handle->nccl_comm.reset(); } + if (handle->nccl_local_comm && handle->nccl_local_comm.use_count() == 1) { handle->nccl_local_comm.reset(); } + if (release_streams) { handle->streams.clear(); } + +#ifdef ENABLE_NVSHMEM + // If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources + if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) { + handle->nvshmem_allocations.clear(); + handle->nvshmem_runtime.reset(); + handle->nvshmem_allocation_size = 0; + } +#endif +} + +static void cleanupFailedGridDescCreate(cudecompHandle_t handle, cudecompGridDesc_t grid_desc, + bool release_streams) noexcept { + if (grid_desc) { delete grid_desc; } + releaseUnusedHandleResources(handle, release_streams); +} + } // namespace } // namespace cudecomp +cudecompHandle::~cudecompHandle() noexcept { +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0) + for (auto& entry : nccl_ubr_handles) { + for (const auto& ubr_handle : entry.second) { + ncclCommDeregister(*ubr_handle.first, ubr_handle.second); + } + } +#endif + + if (mpi_clique_comm != MPI_COMM_NULL) { MPI_Comm_free(&mpi_clique_comm); } + + if (mpi_local_comm != MPI_COMM_NULL) { MPI_Comm_free(&mpi_local_comm); } + +#if CUTENSOR_MAJOR >= 2 + if (cutensor_handle) { cutensorDestroy(cutensor_handle); } + if (cutensor_plan_pref) { cutensorDestroyPlanPreference(cutensor_plan_pref); } +#endif + + if (nvml_initialized) { cudecomp::nvmlFnTable.pfn_nvmlShutdown(); } +} + cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) { using namespace cudecomp; cudecompHandle_t handle = nullptr; @@ -528,6 +573,7 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) { initNvmlFunctionTable(); CHECK_NVML(nvmlInit()); + handle->nvml_initialized = true; // Initialize cuTENSOR library #if CUTENSOR_MAJOR >= 2 @@ -628,24 +674,6 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { handle->nccl_ubr_handles.clear(); #endif - handle->nccl_comm.reset(); - handle->nccl_local_comm.reset(); - -#ifdef ENABLE_NVSHMEM - if (handle->nvshmem_runtime) { handle->nvshmem_runtime->finalize(); } - handle->nvshmem_allocations.clear(); - handle->nvshmem_allocation_size = 0; - handle->nvshmem_runtime.reset(); -#endif - CHECK_MPI(MPI_Comm_free(&handle->mpi_local_comm)); - -#if CUTENSOR_MAJOR >= 2 - CHECK_CUTENSOR(cutensorDestroy(handle->cutensor_handle)); - CHECK_CUTENSOR(cutensorDestroyPlanPreference(handle->cutensor_plan_pref)); -#endif - - CHECK_NVML(nvmlShutdown()); - delete handle; cudecomp_initialized = false; @@ -669,6 +697,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes using namespace cudecomp; cudecompGridDesc_t grid_desc = nullptr; + bool created_streams = false; try { checkHandle(handle); if (!grid_desc_in) { THROW_INVALID_USAGE("grid_desc argument cannot be null"); } @@ -740,16 +769,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes if (grid_desc->config.pdims[0] > 0 && grid_desc->config.pdims[1] > 0) { // If pdims are set, temporarily set up comm info stuctures to determine if we need to create a local NCCL // communicator - setProcessGridIndex(handle, grid_desc); - int color_row = grid_desc->pidx[0]; - MPI_Comm row_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_row, handle->rank, &row_comm)); - setCommInfo(handle, grid_desc, row_comm, CUDECOMP_COMM_ROW); - - int color_col = grid_desc->pidx[1]; - MPI_Comm col_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_col, handle->rank, &col_comm)); - setCommInfo(handle, grid_desc, col_comm, CUDECOMP_COMM_COL); + createCommInfo(handle, grid_desc); // Create local NCCL communicator if row or column communicator uses it int need_local_nccl_comm = @@ -765,8 +785,8 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes handle->nccl_local_comm = ncclCommFromMPIComm( handle->mpi_clique_comm != MPI_COMM_NULL ? handle->mpi_clique_comm : handle->mpi_local_comm); } - CHECK_MPI(MPI_Comm_free(&row_comm)); - CHECK_MPI(MPI_Comm_free(&col_comm)); + grid_desc->row_comm_info.reset(); + grid_desc->col_comm_info.reset(); } else { // If pdims are not set, set up local NCCL communicator for use during autotuning handle->nccl_local_comm = ncclCommFromMPIComm( @@ -790,7 +810,10 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes #endif } - if (handle->streams.empty()) { handle->streams.resize(handle->device_p2p_ce_count + 1); } + if (handle->streams.empty()) { + handle->streams.resize(handle->device_p2p_ce_count + 1); + created_streams = true; + } // Create CUDA events for scheduling grid_desc->events.resize(handle->nranks); @@ -808,45 +831,19 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes } } - setProcessGridIndex(handle, grid_desc); - // Setup final row and column communicators - int color_row = grid_desc->pidx[0]; - MPI_Comm row_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_row, handle->rank, &row_comm)); - setCommInfo(handle, grid_desc, row_comm, CUDECOMP_COMM_ROW); - - int color_col = grid_desc->pidx[1]; - MPI_Comm col_comm; - CHECK_MPI(MPI_Comm_split(handle->mpi_comm, color_col, handle->rank, &col_comm)); - setCommInfo(handle, grid_desc, col_comm, CUDECOMP_COMM_COL); + bool need_nvshmem = transposeBackendRequiresNvshmem(grid_desc->config.transpose_comm_backend) || + haloBackendRequiresNvshmem(grid_desc->config.halo_comm_backend); + createCommInfo(handle, grid_desc, need_nvshmem); #ifdef ENABLE_NVSHMEM - if (transposeBackendRequiresNvshmem(grid_desc->config.transpose_comm_backend) || - haloBackendRequiresNvshmem(grid_desc->config.halo_comm_backend)) { - nvshmem_team_config_t tmp; - nvshmem_team_split_2d(NVSHMEM_TEAM_WORLD, grid_desc->config.pdims[1], &tmp, 0, - &grid_desc->row_comm_info.nvshmem_team, &tmp, 0, &grid_desc->col_comm_info.nvshmem_team); - grid_desc->row_comm_info.nvshmem_signals = - (uint64_t*)nvshmem_malloc(grid_desc->row_comm_info.nranks * sizeof(uint64_t)); - CHECK_CUDA( - cudaMemset(grid_desc->row_comm_info.nvshmem_signals, 0, grid_desc->row_comm_info.nranks * sizeof(uint64_t))); - grid_desc->col_comm_info.nvshmem_signals = - (uint64_t*)nvshmem_malloc(grid_desc->col_comm_info.nranks * sizeof(uint64_t)); - CHECK_CUDA( - cudaMemset(grid_desc->col_comm_info.nvshmem_signals, 0, grid_desc->col_comm_info.nranks * sizeof(uint64_t))); + if (need_nvshmem) { CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int))); CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int))); // Set grid descriptor reference to NVSHMEM runtime grid_desc->nvshmem_runtime = handle->nvshmem_runtime; - } else { - // If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources - if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) { - handle->nvshmem_allocations.clear(); - handle->nvshmem_runtime.reset(); - handle->nvshmem_allocation_size = 0; - } } + #endif if (transposeBackendRequiresNccl(grid_desc->config.transpose_comm_backend) || haloBackendRequiresNccl(grid_desc->config.halo_comm_backend)) { @@ -856,23 +853,19 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes if ((grid_desc->row_comm_info.ngroups > 1 || grid_desc->row_comm_info.nranks == 1) && (grid_desc->col_comm_info.ngroups > 1 || grid_desc->col_comm_info.nranks == 1)) { grid_desc->nccl_local_comm.reset(); - - // If handle has the only remaining reference to the local NCCL communicator, destroy it to reclaim resources - if (handle->nccl_local_comm.use_count() == 1) { handle->nccl_local_comm.reset(); } } } } else { // Release grid descriptor references to NCCL communicators grid_desc->nccl_comm.reset(); grid_desc->nccl_local_comm.reset(); - - // Destroy NCCL communicators to reclaim resources if not used by other grid descriptors - if (handle->nccl_comm && handle->nccl_comm.use_count() == 1) { handle->nccl_comm.reset(); } - if (handle->nccl_local_comm && handle->nccl_local_comm.use_count() == 1) { handle->nccl_local_comm.reset(); } } + releaseUnusedHandleResources(handle); + *grid_desc_in = grid_desc; *config = grid_desc->config; + // If gdims_dist was not set, return config with default values if (!grid_desc->gdims_dist_set) { config->gdims_dist[0] = 0; @@ -889,7 +882,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes } } } - CUDECOMP_CATCH_C_API_ERRORS(if (grid_desc) { delete grid_desc; }) + CUDECOMP_CATCH_C_API_ERRORS(cleanupFailedGridDescCreate(handle, grid_desc, created_streams)) return CUDECOMP_RESULT_SUCCESS; } @@ -899,53 +892,13 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe checkHandle(handle); checkGridDesc(grid_desc); - if (grid_desc->row_comm_info.mpi_comm != MPI_COMM_NULL) { - CHECK_MPI(MPI_Comm_free(&grid_desc->row_comm_info.mpi_comm)); - } - if (grid_desc->col_comm_info.mpi_comm != MPI_COMM_NULL) { - CHECK_MPI(MPI_Comm_free(&grid_desc->col_comm_info.mpi_comm)); - } - // Print performance report if enabled if (handle->performance_report_enable) { printPerformanceReport(handle, grid_desc); } - if (transposeBackendRequiresNccl(grid_desc->config.transpose_comm_backend) || - haloBackendRequiresNccl(grid_desc->config.halo_comm_backend)) { - // Release grid descriptor references to NCCL communicators - grid_desc->nccl_comm.reset(); - grid_desc->nccl_local_comm.reset(); - - // Destroy NCCL communicators to reclaim resources if not used by other grid descriptors - if (handle->nccl_comm && handle->nccl_comm.use_count() == 1) { handle->nccl_comm.reset(); } - if (handle->nccl_local_comm && handle->nccl_local_comm.use_count() == 1) { handle->nccl_local_comm.reset(); } - } - -#ifdef ENABLE_NVSHMEM - if (transposeBackendRequiresNvshmem(grid_desc->config.transpose_comm_backend) || - haloBackendRequiresNvshmem(grid_desc->config.halo_comm_backend)) { - if (grid_desc->row_comm_info.nvshmem_team != NVSHMEM_TEAM_INVALID) { - nvshmem_team_destroy(grid_desc->row_comm_info.nvshmem_team); - nvshmem_free(grid_desc->row_comm_info.nvshmem_signals); - } - if (grid_desc->col_comm_info.nvshmem_team != NVSHMEM_TEAM_INVALID) { - nvshmem_team_destroy(grid_desc->col_comm_info.nvshmem_team); - nvshmem_free(grid_desc->col_comm_info.nvshmem_signals); - } - CHECK_CUDA(cudaFree(grid_desc->nvshmem_block_counters)); - // Release grid descriptor reference to NVSHMEM runtime - grid_desc->nvshmem_runtime.reset(); - - // If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources - if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) { - handle->nvshmem_allocations.clear(); - handle->nvshmem_runtime.reset(); - handle->nvshmem_allocation_size = 0; - } - } -#endif - delete grid_desc; grid_desc = nullptr; + + releaseUnusedHandleResources(handle); } CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS;