diff --git a/include/internal/common.h b/include/internal/common.h index 694ad01..b4551d5 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -69,7 +69,7 @@ struct cudecompHandle { cudecomp::ncclComm nccl_comm; // NCCL communicator (global) cudecomp::ncclComm nccl_local_comm; // NCCL communicator (intra-node, or intra-clique on MNNVL systems) bool nccl_enable_ubr = false; // Flag to control NCCL user buffer registration usage - std::unordered_map<void*, std::vector<std::pair<ncclComm_t, void*>>> + std::unordered_map<void*, std::vector<std::pair<cudecomp::ncclComm, void*>>> nccl_ubr_handles; // map of allocated buffer address to NCCL registration handle(s) std::vector streams; // internal streams for concurrent scheduling diff --git a/src/cudecomp.cc b/src/cudecomp.cc index 61bbf7e..edfdfad 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -579,6 +579,15 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { try { checkHandle(handle); +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0) + for (auto& entry : handle->nccl_ubr_handles) { + for (const auto& ubr_handle : entry.second) { + CHECK_NCCL(ncclCommDeregister(*ubr_handle.first, ubr_handle.second)); + } + } + handle->nccl_ubr_handles.clear(); +#endif + handle->nccl_comm.reset(); handle->nccl_local_comm.reset(); @@ -1195,6 +1204,8 @@ cudecompResult_t cudecompMalloc(cudecompHandle_t handle, cudecompGridDesc_t grid try { checkHandle(handle); checkGridDesc(grid_desc); + if (!buffer) { THROW_INVALID_USAGE("buffer argument cannot be null"); } + if (buffer_size_bytes == 0) { THROW_INVALID_USAGE("buffer size cannot be zero"); } if (transposeBackendRequiresNvshmem(grid_desc->config.transpose_comm_backend) || haloBackendRequiresNvshmem(grid_desc->config.halo_comm_backend)) { @@ -1266,14 +1277,20 @@ cudecompResult_t cudecompMalloc(cudecompHandle_t handle, cudecompGridDesc_t grid haloBackendRequiresNccl(grid_desc->config.halo_comm_backend)) { if (handle->nccl_enable_ubr) { - void* nccl_ubr_handle; - if (grid_desc->nccl_comm) { - CHECK_NCCL(ncclCommRegister(*grid_desc->nccl_comm, buffer, 
buffer_size_bytes, &nccl_ubr_handle)); - handle->nccl_ubr_handles[*buffer].push_back(std::make_pair(*grid_desc->nccl_comm, nccl_ubr_handle)); - } - if (grid_desc->nccl_local_comm) { - CHECK_NCCL(ncclCommRegister(*grid_desc->nccl_local_comm, buffer, buffer_size_bytes, &nccl_ubr_handle)); - handle->nccl_ubr_handles[*buffer].push_back(std::make_pair(*grid_desc->nccl_local_comm, nccl_ubr_handle)); + try { + void* nccl_ubr_handle; + if (grid_desc->nccl_comm) { + CHECK_NCCL(ncclCommRegister(*grid_desc->nccl_comm, *buffer, buffer_size_bytes, &nccl_ubr_handle)); + handle->nccl_ubr_handles[*buffer].push_back(std::make_pair(grid_desc->nccl_comm, nccl_ubr_handle)); + } + if (grid_desc->nccl_local_comm) { + CHECK_NCCL(ncclCommRegister(*grid_desc->nccl_local_comm, *buffer, buffer_size_bytes, &nccl_ubr_handle)); + handle->nccl_ubr_handles[*buffer].push_back(std::make_pair(grid_desc->nccl_local_comm, nccl_ubr_handle)); + } + } catch (...) { + cudecompFree(handle, grid_desc, *buffer); + *buffer = nullptr; + throw; } } } @@ -1294,15 +1311,11 @@ cudecompResult_t cudecompFree(cudecompHandle_t handle, cudecompGridDesc_t grid_d checkGridDesc(grid_desc); #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0) - if (transposeBackendRequiresNccl(grid_desc->config.transpose_comm_backend) || - haloBackendRequiresNccl(grid_desc->config.halo_comm_backend)) { - - if (handle->nccl_ubr_handles.count(buffer) != 0) { - for (const auto& entry : handle->nccl_ubr_handles[buffer]) { - CHECK_NCCL(ncclCommDeregister(entry.first, entry.second)); - } - handle->nccl_ubr_handles.erase(buffer); + if (handle->nccl_ubr_handles.count(buffer) != 0) { + for (const auto& entry : handle->nccl_ubr_handles[buffer]) { + CHECK_NCCL(ncclCommDeregister(*entry.first, entry.second)); } + handle->nccl_ubr_handles.erase(buffer); } #endif