diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h index cc372db..e645872 100644 --- a/include/internal/comm_routines.h +++ b/include/internal/comm_routines.h @@ -281,14 +281,8 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD } #ifdef ENABLE_NVSHMEM - if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued && - transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) && - (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) { - printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " - "being used with an MPI communication backend. This may cause issues with some MPI " - "implementations. See the documentation for additional details and possible workarounds " - "if you encounter issues.\n"); - handle->nvshmem_mixed_buffer_warning_issued = true; + if (transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend)) { + warnIfNvshmemBufferUsedWithMpi(handle, send_buff, recv_buff); } #endif @@ -470,14 +464,8 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc } #ifdef ENABLE_NVSHMEM - if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued && - transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend) && - (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) { - printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " - "being used with an MPI communication backend. This may cause issues with some MPI " - "implementations. 
See the documentation for additional details and possible workarounds " - "if you encounter issues.\n"); - handle->nvshmem_mixed_buffer_warning_issued = true; + if (transposeBackendRequiresMpi(grid_desc->config.transpose_comm_backend)) { + warnIfNvshmemBufferUsedWithMpi(handle, send_buff, recv_buff); } #endif @@ -666,14 +654,8 @@ static void cudecompSendRecvPair(const cudecompHandle_t& handle, const cudecompG } #ifdef ENABLE_NVSHMEM - if (handle->rank == 0 && handle->nvshmem_runtime && !handle->nvshmem_mixed_buffer_warning_issued && - haloBackendRequiresMpi(grid_desc->config.halo_comm_backend) && - (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank))) { - printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " - "being used with an MPI communication backend. This may cause issues with some MPI " - "implementations. See the documentation for additional details and possible workarounds " - "if you encounter issues.\n"); - handle->nvshmem_mixed_buffer_warning_issued = true; + if (haloBackendRequiresMpi(grid_desc->config.halo_comm_backend)) { + warnIfNvshmemBufferUsedWithMpi(handle, send_buff, recv_buff); } #endif diff --git a/include/internal/common.h b/include/internal/common.h index 675864f..0ab7ee5 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -57,19 +57,36 @@ typedef std::pair, unsigned int> mnnvl_info; typedef std::shared_ptr ncclComm; struct nvshmemRuntimeState { #ifdef ENABLE_NVSHMEM - ~nvshmemRuntimeState() { finalize(); } + ~nvshmemRuntimeState() noexcept { finalize(); } - void finalize() { - if (!finalized) { + void finalize() noexcept { + if (initialized) { nvshmem_finalize(); - finalized = true; + initialized = false; } + nvshmem_allocations.clear(); + nvshmem_allocation_size = 0; } - bool finalized = false; + bool initialized = false; // Flag to track NVSHMEM initialization + size_t nvshmem_symmetric_size = 0; // NVSHMEM symmetric size + bool nvshmem_vmm = true; // 
Flag to track if NVSHMEM is using VMM allocations + std::unordered_map nvshmem_allocations; // Table to record NVSHMEM allocations + size_t nvshmem_allocation_size = 0; // Total of NVSHMEM allocations #endif }; typedef std::shared_ptr nvshmemRuntime; +#ifdef ENABLE_NVSHMEM +struct nvshmemProcessState { + // NVSHMEM fixes the PE mapping at initialization and does not support reinitializing + // on a different set or order of ranks, even after nvshmem_finalize(). + std::vector init_world_ranks; + std::weak_ptr active_runtime; + bool mixed_buffer_warning_issued = false; +}; + +void warnIfNvshmemBufferUsedWithMpi(cudecompHandle_t handle, const void* send_buff, const void* recv_buff); +#endif } // namespace cudecomp // cuDecomp handle containing general information @@ -116,14 +133,6 @@ struct cudecompHandle { bool initialized = false; - // Entries for NVSHMEM management and warning generation - cudecomp::nvshmemRuntime nvshmem_runtime; // Shared reference to initialized NVSHMEM runtime - bool nvshmem_mixed_buffer_warning_issued = false; // Warn once if NVSHMEM buffer is used with MPI - size_t nvshmem_symmetric_size; // NVSHMEM symmetric size - bool nvshmem_vmm; // Flag to track if NVSHMEM is using VMM allocations - std::unordered_map nvshmem_allocations; // Table to record NVSHMEM allocations - size_t nvshmem_allocation_size = 0; // Total of NVSHMEM allocations - // Multi-node NVLINK (MNNVL) bool cuda_cumem_enable = false; // Flag to control whether cuMem* APIs are used for cudecompMalloc/Free std::vector rank_to_mnnvl_info; // list of mnnvl information (clusterUuid, cliqueId) by rank @@ -243,6 +252,7 @@ struct cudecompGridDesc { #endif } + cudecompHandle_t handle = nullptr; // owning cuDecomp handle cudecompGridDescConfig_t config; // configuration struct bool gdims_dist_set = false; // flag to record if gdims_dist was set to non-default values bool transpose_mem_order_set = false; // flag to record if transpose_mem_order was set to non-default values diff --git 
a/src/cuda_wrap.cc b/src/cuda_wrap.cc index a24dca5..25da7b2 100644 --- a/src/cuda_wrap.cc +++ b/src/cuda_wrap.cc @@ -15,6 +15,8 @@ * limitations under the License. */ +#include + #include #include "internal/checks.h" @@ -48,21 +50,24 @@ namespace cudecomp { cuFunctionTable cuFnTable; // global table of required CUDA driver functions void initCuFunctionTable() { + static std::once_flag init_once; + std::call_once(init_once, []() { #if CUDART_VERSION >= 11030 - LOAD_SYM(cuDeviceGet, 2000); - LOAD_SYM(cuDeviceGetAttribute, 2000); - LOAD_SYM(cuGetErrorString, 6000); - LOAD_SYM(cuMemAddressFree, 10020); - LOAD_SYM(cuMemAddressReserve, 10020); - LOAD_SYM(cuMemCreate, 10020); - LOAD_SYM(cuMemGetAddressRange, 3020); - LOAD_SYM(cuMemGetAllocationGranularity, 10020); - LOAD_SYM(cuMemMap, 10020); - LOAD_SYM(cuMemRetainAllocationHandle, 11000); - LOAD_SYM(cuMemRelease, 10020); - LOAD_SYM(cuMemSetAccess, 10020); - LOAD_SYM(cuMemUnmap, 10020); + LOAD_SYM(cuDeviceGet, 2000); + LOAD_SYM(cuDeviceGetAttribute, 2000); + LOAD_SYM(cuGetErrorString, 6000); + LOAD_SYM(cuMemAddressFree, 10020); + LOAD_SYM(cuMemAddressReserve, 10020); + LOAD_SYM(cuMemCreate, 10020); + LOAD_SYM(cuMemGetAddressRange, 3020); + LOAD_SYM(cuMemGetAllocationGranularity, 10020); + LOAD_SYM(cuMemMap, 10020); + LOAD_SYM(cuMemRetainAllocationHandle, 11000); + LOAD_SYM(cuMemRelease, 10020); + LOAD_SYM(cuMemSetAccess, 10020); + LOAD_SYM(cuMemUnmap, 10020); #endif + }); } } // namespace cudecomp diff --git a/src/cudecomp.cc b/src/cudecomp.cc index b15ddf3..1da8bf8 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include @@ -55,9 +55,6 @@ namespace cudecomp { namespace { -// Static flag to disable multiple handle creation -static bool cudecomp_initialized = false; - static cudecomp::ncclComm ncclCommFromMPIComm(MPI_Comm mpi_comm) { int rank, nranks; CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank)); @@ -74,6 +71,46 @@ static cudecomp::ncclComm 
ncclCommFromMPIComm(MPI_Comm mpi_comm) { } #ifdef ENABLE_NVSHMEM +static nvshmemProcessState process_nvshmem_state; + +static std::vector getCommWorldRanks(MPI_Comm mpi_comm) { + int nranks; + CHECK_MPI(MPI_Comm_size(mpi_comm, &nranks)); + + std::vector comm_ranks(nranks); + std::vector world_ranks(nranks); + for (int i = 0; i < nranks; ++i) { + comm_ranks[i] = i; + } + + MPI_Group comm_group = MPI_GROUP_NULL; + MPI_Group world_group = MPI_GROUP_NULL; + CHECK_MPI(MPI_Comm_group(mpi_comm, &comm_group)); + CHECK_MPI(MPI_Comm_group(MPI_COMM_WORLD, &world_group)); + CHECK_MPI(MPI_Group_translate_ranks(comm_group, nranks, comm_ranks.data(), world_group, world_ranks.data())); + CHECK_MPI(MPI_Group_free(&world_group)); + CHECK_MPI(MPI_Group_free(&comm_group)); + + return world_ranks; +} + +static std::vector checkNvshmemCommCongruent(MPI_Comm mpi_comm) { + auto world_ranks = getCommWorldRanks(mpi_comm); + + int local_noncongruent = 0; + if (!process_nvshmem_state.init_world_ranks.empty()) { + local_noncongruent = process_nvshmem_state.init_world_ranks != world_ranks; + } + + // In overlapping communicator cases, some ranks may not have prior process-local NVSHMEM state. + // Agree on incompatibility before any rank reuses or initializes the process-global runtime. 
+ int any_noncongruent = 0; + CHECK_MPI(MPI_Allreduce(&local_noncongruent, &any_noncongruent, 1, MPI_INT, MPI_MAX, mpi_comm)); + + if (any_noncongruent) { THROW_INVALID_USAGE("Multiple NVSHMEM-backed handles require congruent MPI communicators"); } + return world_ranks; +} + static void initNvshmemFromMPIComm(MPI_Comm mpi_comm) { int rank, nranks; CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank)); @@ -163,8 +200,9 @@ static void checkHandle(cudecompHandle_t handle) { if (!handle || !handle->initialized) { THROW_INVALID_USAGE("invalid handle"); } } -static void checkGridDesc(cudecompGridDesc_t grid_desc) { +static void checkGridDesc(cudecompHandle_t handle, cudecompGridDesc_t grid_desc) { if (!grid_desc || !grid_desc->initialized) { THROW_INVALID_USAGE("invalid grid descriptor"); } + if (grid_desc->handle != handle) { THROW_INVALID_USAGE("grid descriptor belongs to a different handle"); } } static cudecompResult_t handleException(const BaseException& e) { @@ -435,11 +473,11 @@ static void resolveRankOrder(cudecompHandle_t handle, cudecompGridDesc_t grid_de } #ifdef ENABLE_NVSHMEM -static void inspectNvshmemEnvVars(cudecompHandle_t& handle) { +static void inspectNvshmemEnvVars(nvshmemRuntimeState& runtime) { // Check NVSHMEM_DISABLE_CUDA_VMM - handle->nvshmem_vmm = true; + runtime.nvshmem_vmm = true; char* vmm_str = std::getenv("NVSHMEM_DISABLE_CUDA_VMM"); - if (vmm_str) { handle->nvshmem_vmm = std::strtol(vmm_str, nullptr, 10) == 0; } + if (vmm_str) { runtime.nvshmem_vmm = std::strtol(vmm_str, nullptr, 10) == 0; } // Check NVSHMEM_SYMMETRIC_SIZE char* symmetric_size_str = std::getenv("NVSHMEM_SYMMETRIC_SIZE"); @@ -456,10 +494,10 @@ static void inspectNvshmemEnvVars(cudecompHandle_t& handle) { } else { scale = 0; } - handle->nvshmem_symmetric_size = std::ceil(std::strtod(symmetric_size_str, nullptr) * (1ull << scale)); + runtime.nvshmem_symmetric_size = std::ceil(std::strtod(symmetric_size_str, nullptr) * (1ull << scale)); } else { // NVSHMEM symmetric size defaults to 1 GiB 
- handle->nvshmem_symmetric_size = 1ull << 30; + runtime.nvshmem_symmetric_size = 1ull << 30; } } @@ -493,11 +531,30 @@ static void checkNvshmemVersion(cudecompHandle_t& handle) { } } -static nvshmemRuntime createNvshmemRuntime(cudecompHandle_t& handle) { - checkNvshmemVersion(handle); - inspectNvshmemEnvVars(handle); - initNvshmemFromMPIComm(handle->mpi_comm); - return std::make_shared(); +static nvshmemRuntime acquireNvshmemRuntime(cudecompHandle_t& handle) { + auto world_ranks = checkNvshmemCommCongruent(handle->mpi_comm); + auto runtime = process_nvshmem_state.active_runtime.lock(); + if (!runtime) { + runtime = std::make_shared(); + inspectNvshmemEnvVars(*runtime); + if (process_nvshmem_state.init_world_ranks.empty()) { + process_nvshmem_state.init_world_ranks = std::move(world_ranks); + } + checkNvshmemVersion(handle); + initNvshmemFromMPIComm(handle->mpi_comm); + runtime->initialized = true; + process_nvshmem_state.active_runtime = runtime; + return runtime; + } + + if (!runtime->initialized) { + runtime->finalize(); + inspectNvshmemEnvVars(*runtime); + checkNvshmemVersion(handle); + initNvshmemFromMPIComm(handle->mpi_comm); + runtime->initialized = true; + } + return runtime; } #endif @@ -508,15 +565,6 @@ static void releaseUnusedHandleResources(cudecompHandle_t handle, bool release_s if (handle->nccl_comm && handle->nccl_comm.use_count() == 1) { handle->nccl_comm.reset(); } if (handle->nccl_local_comm && handle->nccl_local_comm.use_count() == 1) { handle->nccl_local_comm.reset(); } if (release_streams) { handle->streams.clear(); } - -#ifdef ENABLE_NVSHMEM - // If handle has the only remaining reference to the NVSHMEM runtime, destroy it to reclaim resources - if (handle->nvshmem_runtime && handle->nvshmem_runtime.use_count() == 1) { - handle->nvshmem_allocations.clear(); - handle->nvshmem_runtime.reset(); - handle->nvshmem_allocation_size = 0; - } -#endif } static void cleanupFailedGridDescCreate(cudecompHandle_t handle, cudecompGridDesc_t grid_desc, @@ 
-553,6 +601,26 @@ struct cuMemAllocationGuard { } // namespace } // namespace cudecomp +#ifdef ENABLE_NVSHMEM +namespace cudecomp { +void warnIfNvshmemBufferUsedWithMpi(cudecompHandle_t handle, const void* send_buff, const void* recv_buff) { + auto runtime = process_nvshmem_state.active_runtime.lock(); + if (!handle || handle->rank != 0 || !runtime || !runtime->initialized || + process_nvshmem_state.mixed_buffer_warning_issued) { + return; + } + + if (nvshmem_ptr(send_buff, handle->rank) || nvshmem_ptr(recv_buff, handle->rank)) { + printf("CUDECOMP:WARN: A work buffer allocated with nvshmem_malloc (via cudecompMalloc) is " + "being used with an MPI communication backend. This may cause issues with some MPI " + "implementations. See the documentation for additional details and possible workarounds " + "if you encounter issues.\n"); + process_nvshmem_state.mixed_buffer_warning_issued = true; + } +} +} // namespace cudecomp +#endif + cudecompHandle::~cudecompHandle() noexcept { #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0) for (auto& entry : nccl_ubr_handles) { @@ -579,9 +647,6 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) { cudecompHandle_t handle = nullptr; try { if (!handle_in) { THROW_INVALID_USAGE("handle argument cannot be null"); } - if (cudecomp_initialized) { - THROW_INVALID_USAGE("cuDecomp already initialized and multiple handles are not supported."); - } handle = new cudecompHandle; handle->mpi_comm = mpi_comm; CHECK_MPI(MPI_Comm_rank(mpi_comm, &handle->rank)); @@ -677,7 +742,6 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) { CHECK_CUDA(cudaDeviceGetAttribute(&handle->device_max_threads_per_sm, cudaDevAttrMaxThreadsPerMultiProcessor, dev)); handle->initialized = true; - cudecomp_initialized = true; *handle_in = handle; } @@ -700,8 +764,6 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { #endif delete handle; - - cudecomp_initialized = false; } CUDECOMP_CATCH_C_API_ERRORS() 
return CUDECOMP_RESULT_SUCCESS; @@ -723,6 +785,9 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes using namespace cudecomp; cudecompGridDesc_t grid_desc = nullptr; bool created_streams = false; +#ifdef ENABLE_NVSHMEM + nvshmemRuntime nvshmem_runtime; +#endif try { checkHandle(handle); if (!grid_desc_in) { THROW_INVALID_USAGE("grid_desc argument cannot be null"); } @@ -747,6 +812,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes grid_desc = new cudecompGridDesc; grid_desc->initialized = true; + grid_desc->handle = handle; grid_desc->config = *config; resolveRankOrder(handle, grid_desc); auto comm_backend = grid_desc->config.transpose_comm_backend; @@ -826,10 +892,8 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes if (transposeBackendRequiresNvshmem(comm_backend) || haloBackendRequiresNvshmem(halo_comm_backend) || ((autotune_transpose_backend || autotune_halo_backend) && !autotune_disable_nvshmem_backends)) { #ifdef ENABLE_NVSHMEM - if (!handle->nvshmem_runtime) { - handle->nvshmem_runtime = createNvshmemRuntime(handle); - handle->nvshmem_allocation_size = 0; - } + nvshmem_runtime = acquireNvshmemRuntime(handle); + grid_desc->nvshmem_runtime = nvshmem_runtime; #endif } @@ -860,11 +924,14 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes createCommInfo(handle, grid_desc, need_nvshmem); #ifdef ENABLE_NVSHMEM if (need_nvshmem) { + if (!nvshmem_runtime || !nvshmem_runtime->initialized) { nvshmem_runtime = acquireNvshmemRuntime(handle); } CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int))); CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int))); // Set grid descriptor reference to NVSHMEM runtime - grid_desc->nvshmem_runtime = handle->nvshmem_runtime; + grid_desc->nvshmem_runtime = nvshmem_runtime; + } else { + grid_desc->nvshmem_runtime.reset(); } #endif @@ 
-913,7 +980,7 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); // Print performance report if enabled if (handle->performance_report_enable) { printPerformanceReport(handle, grid_desc); } @@ -1005,7 +1072,7 @@ cudecompResult_t cudecompGetPencilInfo(cudecompHandle_t handle, cudecompGridDesc using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); if (!pencil_info) { THROW_INVALID_USAGE("pencil_info argument cannot be null."); } if (axis < 0 || axis > 2) { THROW_INVALID_USAGE("axis argument out of range"); } @@ -1063,7 +1130,7 @@ cudecompResult_t cudecompGetGridDescConfig(cudecompHandle_t handle, cudecompGrid using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); if (!config) { THROW_INVALID_USAGE("config argument cannot be null."); } *config = grid_desc->config; @@ -1092,7 +1159,7 @@ cudecompResult_t cudecompGetTransposeWorkspaceSize(cudecompHandle_t handle, cude using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); if (!workspace_size) { THROW_INVALID_USAGE("workspace_size argument cannot be null."); } int64_t max_pencil_size_x = getGlobalMaxPencilSize(handle, grid_desc, 0); int64_t max_pencil_size_y = getGlobalMaxPencilSize(handle, grid_desc, 1); @@ -1115,7 +1182,7 @@ cudecompResult_t cudecompGetHaloWorkspaceSize(cudecompHandle_t handle, cudecompG using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); if (axis < 0 || axis > 2) { THROW_INVALID_USAGE("axis argument out of range"); } if (!halo_extents) { THROW_INVALID_USAGE("halo_extents argument cannot be null."); } if (!workspace_size) { THROW_INVALID_USAGE("workspace_size argument cannot be null."); } @@ -1142,7 +1209,7 @@ 
cudecompResult_t cudecompMalloc(cudecompHandle_t handle, cudecompGridDesc_t grid using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); if (!buffer) { THROW_INVALID_USAGE("buffer argument cannot be null"); } if (buffer_size_bytes == 0) { THROW_INVALID_USAGE("buffer size cannot be zero"); } @@ -1152,21 +1219,27 @@ cudecompResult_t cudecompMalloc(cudecompHandle_t handle, cudecompGridDesc_t grid // NVSHMEM requires allocations to be the same size for all ranks. Find maximum. CHECK_MPI(MPI_Allreduce(MPI_IN_PLACE, &buffer_size_bytes, 1, MPI_LONG_LONG_INT, MPI_MAX, handle->mpi_comm)); - size_t nvshmem_free_size = handle->nvshmem_symmetric_size - handle->nvshmem_allocation_size; - if (!handle->nvshmem_vmm && handle->rank == 0 && buffer_size_bytes > nvshmem_free_size) { + auto nvshmem_runtime = grid_desc->nvshmem_runtime; + if (!nvshmem_runtime || !nvshmem_runtime->initialized) { THROW_INVALID_USAGE("NVSHMEM runtime is unavailable"); } + + size_t nvshmem_free_size = 0; + if (nvshmem_runtime->nvshmem_symmetric_size > nvshmem_runtime->nvshmem_allocation_size) { + nvshmem_free_size = nvshmem_runtime->nvshmem_symmetric_size - nvshmem_runtime->nvshmem_allocation_size; + } + if (!nvshmem_runtime->nvshmem_vmm && handle->rank == 0 && buffer_size_bytes > nvshmem_free_size) { fprintf(stderr, "CUDECOMP:WARN: Attempting an NVSHMEM allocation of %lld bytes but *approximately* " "%zu free bytes of %zu total bytes of symmetric heap space available. 
If the allocation fails, " "set NVSHMEM_SYMMETRIC_SIZE >= %zu and try again.\n", - buffer_size_bytes, nvshmem_free_size, handle->nvshmem_symmetric_size, - handle->nvshmem_symmetric_size + (buffer_size_bytes - nvshmem_free_size)); + buffer_size_bytes, nvshmem_free_size, nvshmem_runtime->nvshmem_symmetric_size, + nvshmem_runtime->nvshmem_symmetric_size + (buffer_size_bytes - nvshmem_free_size)); } *buffer = nvshmem_malloc(buffer_size_bytes); if (buffer_size_bytes != 0 && *buffer == nullptr) { THROW_NVSHMEM_ERROR("nvshmem_malloc failed"); } // Record NVSHMEM allocation details - handle->nvshmem_allocations[*buffer] = buffer_size_bytes; - handle->nvshmem_allocation_size += buffer_size_bytes; + nvshmem_runtime->nvshmem_allocations[*buffer] = buffer_size_bytes; + nvshmem_runtime->nvshmem_allocation_size += buffer_size_bytes; #else THROW_NOT_SUPPORTED("build does not support NVSHMEM communication backends."); #endif @@ -1251,7 +1324,7 @@ cudecompResult_t cudecompFree(cudecompHandle_t handle, cudecompGridDesc_t grid_d using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0) if (handle->nccl_ubr_handles.count(buffer) != 0) { @@ -1266,12 +1339,19 @@ cudecompResult_t cudecompFree(cudecompHandle_t handle, cudecompGridDesc_t grid_d haloBackendRequiresNvshmem(grid_desc->config.halo_comm_backend)) { #ifdef ENABLE_NVSHMEM if (buffer) { + auto nvshmem_runtime = grid_desc->nvshmem_runtime; + if (!nvshmem_runtime || !nvshmem_runtime->initialized) { + THROW_INVALID_USAGE("NVSHMEM runtime is unavailable"); + } + nvshmem_free(buffer); // Record NVSHMEM deallocation details - size_t buffer_size_bytes = handle->nvshmem_allocations[buffer]; - handle->nvshmem_allocation_size -= buffer_size_bytes; - handle->nvshmem_allocations.erase(buffer); + auto entry = nvshmem_runtime->nvshmem_allocations.find(buffer); + if (entry != nvshmem_runtime->nvshmem_allocations.end()) { + 
nvshmem_runtime->nvshmem_allocation_size -= entry->second; + nvshmem_runtime->nvshmem_allocations.erase(entry); + } } #else THROW_NOT_SUPPORTED("build does not support NVSHMEM communication backends."); @@ -1347,7 +1427,7 @@ cudecompResult_t cudecompGetShiftedRank(cudecompHandle_t handle, cudecompGridDes try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); if (axis < 0 || axis > 2) { THROW_INVALID_USAGE("axis argument out of range"); } if (dim < 0 || dim > 2) { THROW_INVALID_USAGE("dim argument out of range"); } if (!shifted_rank) { THROW_INVALID_USAGE("shifted_rank argument cannot be null."); } @@ -1395,7 +1475,7 @@ cudecompResult_t cudecompTransposeXToY(cudecompHandle_t handle, cudecompGridDesc using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!input) { THROW_INVALID_USAGE("input argument cannot be null"); } if (!output) { THROW_INVALID_USAGE("output argument cannot be null"); } @@ -1436,7 +1516,7 @@ cudecompResult_t cudecompTransposeYToZ(cudecompHandle_t handle, cudecompGridDesc using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!input) { THROW_INVALID_USAGE("input argument cannot be null"); } if (!output) { THROW_INVALID_USAGE("output argument cannot be null"); } @@ -1477,7 +1557,7 @@ cudecompResult_t cudecompTransposeZToY(cudecompHandle_t handle, cudecompGridDesc using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!input) { THROW_INVALID_USAGE("input argument cannot be null"); } if (!output) { THROW_INVALID_USAGE("output argument cannot be null"); } @@ -1518,7 +1598,7 @@ cudecompResult_t cudecompTransposeYToX(cudecompHandle_t handle, cudecompGridDesc using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + 
checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!input) { THROW_INVALID_USAGE("input argument cannot be null"); } if (!output) { THROW_INVALID_USAGE("output argument cannot be null"); } @@ -1558,7 +1638,7 @@ cudecompResult_t cudecompUpdateHalosX(cudecompHandle_t handle, cudecompGridDesc_ using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!halo_extents) { THROW_INVALID_USAGE("halo_extents argument cannot be null"); } if (halo_extents[0] == 0 && halo_extents[1] == 0 && halo_extents[2] == 0) { @@ -1601,7 +1681,7 @@ cudecompResult_t cudecompUpdateHalosY(cudecompHandle_t handle, cudecompGridDesc_ using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!halo_extents) { THROW_INVALID_USAGE("halo_extents argument cannot be null"); } if (halo_extents[0] == 0 && halo_extents[1] == 0 && halo_extents[2] == 0) { @@ -1644,7 +1724,7 @@ cudecompResult_t cudecompUpdateHalosZ(cudecompHandle_t handle, cudecompGridDesc_ using namespace cudecomp; try { checkHandle(handle); - checkGridDesc(grid_desc); + checkGridDesc(handle, grid_desc); checkDataType(dtype); if (!halo_extents) { THROW_INVALID_USAGE("halo_extents argument cannot be null"); } if (halo_extents[0] == 0 && halo_extents[1] == 0 && halo_extents[2] == 0) { diff --git a/src/nvml_wrap.cc b/src/nvml_wrap.cc index 9e0a715..2a70b42 100644 --- a/src/nvml_wrap.cc +++ b/src/nvml_wrap.cc @@ -15,8 +15,10 @@ * limitations under the License. 
*/ -#include #include +#include + +#include #include "internal/checks.h" #include "internal/exceptions.h" @@ -34,19 +36,22 @@ namespace cudecomp { nvmlFunctionTable nvmlFnTable; // global table of required NVML functions void initNvmlFunctionTable() { - void* nvml_handle = dlopen("libnvidia-ml.so.1", RTLD_NOW); - if (!nvml_handle) { THROW_INVALID_USAGE("Could not dlopen libnvidia-ml.so.1"); } - LOAD_SYM(nvmlInit); - LOAD_SYM(nvmlShutdown); - LOAD_SYM(nvmlErrorString); - LOAD_SYM(nvmlDeviceGetFieldValues); - LOAD_SYM(nvmlDeviceGetHandleByPciBusId); + static std::once_flag init_once; + std::call_once(init_once, []() { + void* nvml_handle = dlopen("libnvidia-ml.so.1", RTLD_NOW); + if (!nvml_handle) { THROW_INVALID_USAGE("Could not dlopen libnvidia-ml.so.1"); } + LOAD_SYM(nvmlInit); + LOAD_SYM(nvmlShutdown); + LOAD_SYM(nvmlErrorString); + LOAD_SYM(nvmlDeviceGetFieldValues); + LOAD_SYM(nvmlDeviceGetHandleByPciBusId); #if NVML_API_VERSION >= 12 && CUDART_VERSION >= 12040 - LOAD_SYM(nvmlDeviceGetGpuFabricInfoV); + LOAD_SYM(nvmlDeviceGetGpuFabricInfoV); #endif - LOAD_SYM(nvmlDeviceGetNvLinkCapability); - LOAD_SYM(nvmlDeviceGetNvLinkState); - LOAD_SYM(nvmlDeviceGetNvLinkRemotePciInfo); + LOAD_SYM(nvmlDeviceGetNvLinkCapability); + LOAD_SYM(nvmlDeviceGetNvLinkState); + LOAD_SYM(nvmlDeviceGetNvLinkRemotePciInfo); + }); } bool nvmlHasFabricSupport() { diff --git a/tests/ctest/CMakeLists.txt b/tests/ctest/CMakeLists.txt index 243c56b..7c1ba04 100644 --- a/tests/ctest/CMakeLists.txt +++ b/tests/ctest/CMakeLists.txt @@ -70,6 +70,10 @@ if (CUDECOMP_ENABLE_NVSHMEM) PUBLIC ${NVSHMEM_INCLUDE_DIR} ) + target_link_libraries(cudecomp_test_support + PUBLIC + ${NVSHMEM_LIBRARY_DIR}/libnvshmem_host.so + ) endif() target_link_libraries(cudecomp_test_support PUBLIC @@ -105,6 +109,11 @@ set_tests_properties(cudecomp_api PROPERTIES LABELS "api;mpi" TIMEOUT ${CUDECOMP_TEST_TIMEOUT} ) +if (CUDECOMP_ENABLE_NVSHMEM) + set_tests_properties(cudecomp_api PROPERTIES + ENVIRONMENT 
"NVSHMEM_DISABLE_NCCL=1" + ) +endif() add_executable(cudecomp_test_transpose) target_sources(cudecomp_test_transpose diff --git a/tests/ctest/api_tests.cc b/tests/ctest/api_tests.cc index 9418ba2..3f04e8d 100644 --- a/tests/ctest/api_tests.cc +++ b/tests/ctest/api_tests.cc @@ -11,6 +11,9 @@ #include #include +#ifdef ENABLE_NVSHMEM +#include +#endif #include "cudecomp.h" @@ -27,6 +30,9 @@ constexpr std::array kPdims{2, 2}; constexpr std::array kHaloExtents{1, 2, 1}; constexpr std::array kPadding{1, 0, 2}; constexpr std::array kHaloPeriods{false, true, false}; +#ifdef ENABLE_NVSHMEM +constexpr int kNvshmemTestRanks = 2; +#endif struct ExpectedPencilInfo { std::array shape; @@ -109,6 +115,34 @@ void setDistributedConfig(cudecompGridDescConfig_t& config) { config.pdims[1] = kPdims[1]; } +#ifdef ENABLE_NVSHMEM +void setNvshmemTestConfig(cudecompGridDescConfig_t& config) { + config.gdims[0] = kGdims[0]; + config.gdims[1] = kGdims[1]; + config.gdims[2] = kGdims[2]; + config.pdims[0] = 1; + config.pdims[1] = kNvshmemTestRanks; + config.transpose_comm_backend = CUDECOMP_TRANSPOSE_COMM_NVSHMEM; +} + +void expectNvshmemInitialized(const cudecomp_test::MpiTestComm& comm) { + const int local_status = nvshmemx_init_status(); + int all_ranks_match = + (local_status >= NVSHMEM_STATUS_IS_INITIALIZED && local_status < NVSHMEM_STATUS_INVALID) ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &all_ranks_match, 1, MPI_INT, MPI_MIN, comm.mpiComm())); + EXPECT_EQ(1, all_ranks_match) << "local NVSHMEM init status is " << local_status << ", expected initialized"; +} + +void expectNvshmemNotInitialized(const cudecomp_test::MpiTestComm& comm) { + const int local_status = nvshmemx_init_status(); + // NVSHMEM may leave the bootstrap layer active after nvshmem_finalize(). + int all_ranks_match = local_status < NVSHMEM_STATUS_IS_INITIALIZED ? 
1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &all_ranks_match, 1, MPI_INT, MPI_MIN, comm.mpiComm())); + EXPECT_EQ(1, all_ranks_match) << "local NVSHMEM init status is " << local_status + << ", expected below initialized"; +} +#endif + void setMemOrder(cudecompGridDescConfig_t& config, const std::array& order) { for (int axis = 0; axis < 3; ++axis) { for (int i = 0; i < 3; ++i) { @@ -355,12 +389,465 @@ class ApiTransposeTest : public ApiMpiTestBase {}; class ApiHaloTest : public ApiMpiTestBase {}; TEST_F(ApiInitTest, RejectsInvalidArguments) { - cudecompHandle_t second_handle = nullptr; - EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, cudecompInit(&second_handle, active_comm_.mpiComm())); - EXPECT_EQ(nullptr, second_handle); EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, cudecompInit(nullptr, active_comm_.mpiComm())); } +TEST_F(ApiInitTest, SupportsMultipleLiveHandlesWithIndependentResources) { + cudecompHandle_t second_handle = nullptr; + const cudecompResult_t second_init_result = cudecompInit(&second_handle, active_comm_.mpiComm()); + auto second_handle_guard = std::make_unique(second_handle); + CHECK_CUDECOMP_GLOBAL(active_comm_, second_init_result); + + auto config = distributedConfig(); + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, cudecompGridDescCreate(handle_, &grid_desc, &config, nullptr)); + cudecomp_test::gridDescGuard grid_desc_guard(handle_, grid_desc); + + auto second_config = distributedConfig(); + second_config.rank_order = CUDECOMP_RANK_ORDER_COL_MAJOR; + cudecompGridDesc_t second_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, + cudecompGridDescCreate(second_handle, &second_grid_desc, &second_config, nullptr)); + cudecomp_test::gridDescGuard second_grid_desc_guard(second_handle, second_grid_desc); + + cudecompGridDescConfig_t queried_config; + EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, cudecompGetGridDescConfig(second_handle, grid_desc, &queried_config)); + EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, 
cudecompGridDescDestroy(second_handle, grid_desc)); + + void* buffer = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, cudecompMalloc(handle_, grid_desc, &buffer, 1024)); + cudecomp_test::cudecompBufferGuard buffer_guard(handle_, grid_desc, buffer); + + void* second_buffer = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, cudecompMalloc(second_handle, second_grid_desc, &second_buffer, 2048)); + cudecomp_test::cudecompBufferGuard second_buffer_guard(second_handle, second_grid_desc, second_buffer); + + void* unused_buffer = nullptr; + EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, cudecompMalloc(second_handle, grid_desc, &unused_buffer, 1024)); + EXPECT_EQ(nullptr, unused_buffer); +} + +TEST_F(ApiInitTest, FinalizesMultipleHandlesInCreationOrder) { + cudecompHandle_t second_handle = nullptr; + const cudecompResult_t second_init_result = cudecompInit(&second_handle, active_comm_.mpiComm()); + auto second_handle_guard = std::make_unique(second_handle); + CHECK_CUDECOMP_GLOBAL(active_comm_, second_init_result); + + { + auto config = distributedConfig(); + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, cudecompGridDescCreate(handle_, &grid_desc, &config, nullptr)); + cudecomp_test::gridDescGuard grid_desc_guard(handle_, grid_desc); + + auto second_config = distributedConfig(); + cudecompGridDesc_t second_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, + cudecompGridDescCreate(second_handle, &second_grid_desc, &second_config, nullptr)); + cudecomp_test::gridDescGuard second_grid_desc_guard(second_handle, second_grid_desc); + } + + handle_guard_.reset(); + handle_ = nullptr; + second_handle_guard.reset(); +} + +TEST_F(ApiInitTest, SupportsMultipleNcclBackedHandles) { + const auto setup_decision = cudecomp_test::initializeGpuForTest(active_comm_, true); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + cudecompHandle_t second_handle = nullptr; + const 
cudecompResult_t second_init_result = cudecompInit(&second_handle, active_comm_.mpiComm()); + auto second_handle_guard = std::make_unique(second_handle); + CHECK_CUDECOMP_GLOBAL(active_comm_, second_init_result); + + auto config = distributedConfig(); + config.transpose_comm_backend = CUDECOMP_TRANSPOSE_COMM_NCCL; + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, cudecompGridDescCreate(handle_, &grid_desc, &config, nullptr)); + cudecomp_test::gridDescGuard grid_desc_guard(handle_, grid_desc); + + auto second_config = config; + cudecompGridDesc_t second_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(active_comm_, + cudecompGridDescCreate(second_handle, &second_grid_desc, &second_config, nullptr)); + cudecomp_test::gridDescGuard second_grid_desc_guard(second_handle, second_grid_desc); +} + +#ifdef ENABLE_NVSHMEM +TEST(ApiNvshmemInitTest, SupportsMultipleBackedHandlesWithCongruentCommunicators) { + auto world_comm = cudecomp_test::MpiTestComm::world(); + if (world_comm.size() < kApiTestRanks) { + GTEST_SKIP() << "NVSHMEM API tests require " << kApiTestRanks << " ranks, launched with " << world_comm.size(); + } + + const auto setup_decision = cudecomp_test::initializeGpuForTest(world_comm); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + auto nvshmem_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 0, kNvshmemTestRanks); + if (!nvshmem_comm.valid()) return; + + cudecompHandle_t handle = nullptr; + const cudecompResult_t init_result = cudecompInit(&handle, nvshmem_comm.mpiComm()); + auto handle_guard = std::make_unique(handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, init_result); + + cudecompHandle_t second_handle = nullptr; + const cudecompResult_t second_init_result = cudecompInit(&second_handle, nvshmem_comm.mpiComm()); + auto second_handle_guard = std::make_unique(second_handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, second_init_result); + + 
cudecompGridDescConfig_t config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&config)); + setNvshmemTestConfig(config); + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + cudecomp_test::gridDescGuard grid_desc_guard(handle, grid_desc); + + cudecompGridDescConfig_t second_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&second_config)); + setNvshmemTestConfig(second_config); + cudecompGridDesc_t second_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, + cudecompGridDescCreate(second_handle, &second_grid_desc, &second_config, nullptr)); + cudecomp_test::gridDescGuard second_grid_desc_guard(second_handle, second_grid_desc); + + void* buffer = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompMalloc(handle, grid_desc, &buffer, 1024)); + cudecomp_test::cudecompBufferGuard buffer_guard(handle, grid_desc, buffer); + + void* second_buffer = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompMalloc(second_handle, second_grid_desc, &second_buffer, 2048)); + cudecomp_test::cudecompBufferGuard second_buffer_guard(second_handle, second_grid_desc, second_buffer); +} + +TEST(ApiNvshmemInitTest, FinalizesAfterLastGridDescAcrossHandles) { + auto world_comm = cudecomp_test::MpiTestComm::world(); + if (world_comm.size() < kApiTestRanks) { + GTEST_SKIP() << "NVSHMEM API tests require " << kApiTestRanks << " ranks, launched with " << world_comm.size(); + } + + const auto setup_decision = cudecomp_test::initializeGpuForTest(world_comm); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + auto nvshmem_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 0, kNvshmemTestRanks); + if (!nvshmem_comm.valid()) return; + + cudecompHandle_t handle = nullptr; + const cudecompResult_t init_result = cudecompInit(&handle, 
nvshmem_comm.mpiComm()); + auto handle_guard = std::make_unique(handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, init_result); + + cudecompHandle_t second_handle = nullptr; + const cudecompResult_t second_init_result = cudecompInit(&second_handle, nvshmem_comm.mpiComm()); + auto second_handle_guard = std::make_unique(second_handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, second_init_result); + + cudecompGridDescConfig_t config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&config)); + setNvshmemTestConfig(config); + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + expectNvshmemInitialized(nvshmem_comm); + + cudecompGridDescConfig_t second_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&second_config)); + setNvshmemTestConfig(second_config); + cudecompGridDesc_t second_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, + cudecompGridDescCreate(second_handle, &second_grid_desc, &second_config, nullptr)); + expectNvshmemInitialized(nvshmem_comm); + + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescDestroy(handle, grid_desc)); + grid_desc = nullptr; + expectNvshmemInitialized(nvshmem_comm); + + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescDestroy(second_handle, second_grid_desc)); + second_grid_desc = nullptr; + expectNvshmemNotInitialized(nvshmem_comm); +} + +TEST(ApiNvshmemInitTest, FinalizesAfterLastGridDescAndReinitializes) { + auto world_comm = cudecomp_test::MpiTestComm::world(); + if (world_comm.size() < kApiTestRanks) { + GTEST_SKIP() << "NVSHMEM API tests require " << kApiTestRanks << " ranks, launched with " << world_comm.size(); + } + + const auto setup_decision = cudecomp_test::initializeGpuForTest(world_comm); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + auto nvshmem_comm = 
cudecomp_test::MpiTestComm::splitRange(world_comm, 0, kNvshmemTestRanks); + if (!nvshmem_comm.valid()) return; + + { + cudecompHandle_t handle = nullptr; + const cudecompResult_t init_result = cudecompInit(&handle, nvshmem_comm.mpiComm()); + auto handle_guard = std::make_unique(handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, init_result); + + cudecompGridDescConfig_t config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&config)); + setNvshmemTestConfig(config); + + cudecompGridDesc_t first_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescCreate(handle, &first_grid_desc, &config, nullptr)); + expectNvshmemInitialized(nvshmem_comm); + + cudecompGridDesc_t second_grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescCreate(handle, &second_grid_desc, &config, nullptr)); + expectNvshmemInitialized(nvshmem_comm); + + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescDestroy(handle, first_grid_desc)); + first_grid_desc = nullptr; + expectNvshmemInitialized(nvshmem_comm); + + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescDestroy(handle, second_grid_desc)); + second_grid_desc = nullptr; + expectNvshmemNotInitialized(nvshmem_comm); + } + + expectNvshmemNotInitialized(nvshmem_comm); + + { + cudecompHandle_t handle = nullptr; + const cudecompResult_t init_result = cudecompInit(&handle, nvshmem_comm.mpiComm()); + auto handle_guard = std::make_unique(handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, init_result); + + cudecompGridDescConfig_t config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&config)); + setNvshmemTestConfig(config); + + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + expectNvshmemInitialized(nvshmem_comm); + + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescDestroy(handle, grid_desc)); + grid_desc = nullptr; + expectNvshmemNotInitialized(nvshmem_comm); + } +} + 
+TEST(ApiNvshmemInitTest, RejectsBackedHandleWithNonCongruentCommunicator) { + auto world_comm = cudecomp_test::MpiTestComm::world(); + if (world_comm.size() < kApiTestRanks) { + GTEST_SKIP() << "NVSHMEM API tests require " << kApiTestRanks << " ranks, launched with " << world_comm.size(); + } + + const auto setup_decision = cudecomp_test::initializeGpuForTest(world_comm); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + auto nvshmem_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 0, kNvshmemTestRanks); + if (!nvshmem_comm.valid()) return; + + cudecompHandle_t handle = nullptr; + const cudecompResult_t init_result = cudecompInit(&handle, nvshmem_comm.mpiComm()); + auto handle_guard = std::make_unique(handle); + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, init_result); + + cudecompGridDescConfig_t config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&config)); + setNvshmemTestConfig(config); + cudecompGridDesc_t grid_desc = nullptr; + CHECK_CUDECOMP_GLOBAL(nvshmem_comm, cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + cudecomp_test::gridDescGuard grid_desc_guard(handle, grid_desc); + + MPI_Comm reversed_raw_comm = MPI_COMM_NULL; + CHECK_MPI_GLOBAL(nvshmem_comm, MPI_Comm_split(nvshmem_comm.mpiComm(), 0, + nvshmem_comm.size() - 1 - nvshmem_comm.rank(), &reversed_raw_comm)); + auto reversed_comm = cudecomp_test::MpiTestComm::fromComm(reversed_raw_comm); + if (reversed_raw_comm != MPI_COMM_NULL) { CHECK_MPI_GLOBAL(nvshmem_comm, MPI_Comm_free(&reversed_raw_comm)); } + + cudecompHandle_t reversed_handle = nullptr; + const cudecompResult_t reversed_init_result = cudecompInit(&reversed_handle, reversed_comm.mpiComm()); + auto reversed_handle_guard = std::make_unique(reversed_handle); + CHECK_CUDECOMP_GLOBAL(reversed_comm, reversed_init_result); + + cudecompGridDescConfig_t reversed_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, 
cudecompGridDescConfigSetDefaults(&reversed_config)); + setNvshmemTestConfig(reversed_config); + cudecompGridDesc_t reversed_grid_desc = nullptr; + EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, + cudecompGridDescCreate(reversed_handle, &reversed_grid_desc, &reversed_config, nullptr)); + if (reversed_grid_desc) { (void)cudecompGridDescDestroy(reversed_handle, reversed_grid_desc); } +} + +TEST(ApiNvshmemInitTest, RejectsOverlappingNonCongruentCommunicatorCollectively) { + auto world_comm = cudecomp_test::MpiTestComm::world(); + if (world_comm.size() < kApiTestRanks) { + GTEST_SKIP() << "NVSHMEM API tests require " << kApiTestRanks << " ranks, launched with " << world_comm.size(); + } + + const auto setup_decision = cudecomp_test::initializeGpuForTest(world_comm); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + auto first_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 0, kNvshmemTestRanks); + + cudecompHandle_t first_handle = nullptr; + std::unique_ptr first_handle_guard; + std::unique_ptr first_grid_desc_guard; + int first_ready = 1; + if (first_comm.valid()) { + const cudecompResult_t first_init_result = cudecompInit(&first_handle, first_comm.mpiComm()); + first_handle_guard = std::make_unique(first_handle); + + int first_init_ok = first_init_result == CUDECOMP_RESULT_SUCCESS ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_init_ok, 1, MPI_INT, MPI_MIN, first_comm.mpiComm())); + if (first_init_ok) { + cudecompGridDescConfig_t first_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&first_config)); + setNvshmemTestConfig(first_config); + + cudecompGridDesc_t first_grid_desc = nullptr; + const cudecompResult_t first_create_result = + cudecompGridDescCreate(first_handle, &first_grid_desc, &first_config, nullptr); + int first_create_ok = first_create_result == CUDECOMP_RESULT_SUCCESS ? 
1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_create_ok, 1, MPI_INT, MPI_MIN, first_comm.mpiComm())); + if (first_create_ok) { + first_grid_desc_guard = std::make_unique(first_handle, first_grid_desc); + } + first_ready = first_create_ok; + } else { + first_ready = 0; + } + } + + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_ready, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + ASSERT_EQ(1, first_ready) << "initial NVSHMEM communicator setup failed"; + + auto candidate_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 1, kNvshmemTestRanks); + + cudecompHandle_t candidate_handle = nullptr; + std::unique_ptr candidate_handle_guard; + int candidate_init_ready = 1; + int candidate_rejected = 1; + if (candidate_comm.valid()) { + const cudecompResult_t candidate_init_result = cudecompInit(&candidate_handle, candidate_comm.mpiComm()); + candidate_handle_guard = std::make_unique(candidate_handle); + + candidate_init_ready = candidate_init_result == CUDECOMP_RESULT_SUCCESS ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, + MPI_Allreduce(MPI_IN_PLACE, &candidate_init_ready, 1, MPI_INT, MPI_MIN, candidate_comm.mpiComm())); + if (candidate_init_ready) { + cudecompGridDescConfig_t candidate_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&candidate_config)); + setNvshmemTestConfig(candidate_config); + + cudecompGridDesc_t candidate_grid_desc = nullptr; + const cudecompResult_t candidate_create_result = + cudecompGridDescCreate(candidate_handle, &candidate_grid_desc, &candidate_config, nullptr); + candidate_rejected = candidate_create_result == CUDECOMP_RESULT_INVALID_USAGE ? 1 : 0; + + int candidate_create_succeeded = candidate_create_result == CUDECOMP_RESULT_SUCCESS ? 
1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &candidate_create_succeeded, 1, MPI_INT, MPI_MIN, + candidate_comm.mpiComm())); + if (candidate_create_succeeded && candidate_grid_desc) { + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescDestroy(candidate_handle, candidate_grid_desc)); + } + } + } + + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &candidate_init_ready, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &candidate_rejected, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + ASSERT_EQ(1, candidate_init_ready) << "candidate communicator initialization failed before NVSHMEM check"; + EXPECT_EQ(1, candidate_rejected); +} + +TEST(ApiNvshmemInitTest, RejectsNonCongruentCommunicatorAfterRuntimeFinalize) { + auto world_comm = cudecomp_test::MpiTestComm::world(); + if (world_comm.size() < kApiTestRanks) { + GTEST_SKIP() << "NVSHMEM API tests require " << kApiTestRanks << " ranks, launched with " << world_comm.size(); + } + + const auto setup_decision = cudecomp_test::initializeGpuForTest(world_comm); + ASSERT_FALSE(setup_decision.fail) << setup_decision.reason; + if (setup_decision.skip) { GTEST_SKIP() << setup_decision.reason; } + + auto first_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 0, kNvshmemTestRanks); + + cudecompHandle_t first_handle = nullptr; + std::unique_ptr first_handle_guard; + int first_ready = 1; + int first_finalized = 1; + if (first_comm.valid()) { + const cudecompResult_t first_init_result = cudecompInit(&first_handle, first_comm.mpiComm()); + first_handle_guard = std::make_unique(first_handle); + + int first_init_ok = first_init_result == CUDECOMP_RESULT_SUCCESS ? 
1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_init_ok, 1, MPI_INT, MPI_MIN, first_comm.mpiComm())); + if (first_init_ok) { + cudecompGridDescConfig_t first_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&first_config)); + setNvshmemTestConfig(first_config); + + cudecompGridDesc_t first_grid_desc = nullptr; + const cudecompResult_t first_create_result = + cudecompGridDescCreate(first_handle, &first_grid_desc, &first_config, nullptr); + int first_create_ok = first_create_result == CUDECOMP_RESULT_SUCCESS ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_create_ok, 1, MPI_INT, MPI_MIN, first_comm.mpiComm())); + first_ready = first_create_ok; + if (first_create_ok) { + const cudecompResult_t first_destroy_result = cudecompGridDescDestroy(first_handle, first_grid_desc); + int first_destroy_ok = first_destroy_result == CUDECOMP_RESULT_SUCCESS ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, + MPI_Allreduce(MPI_IN_PLACE, &first_destroy_ok, 1, MPI_INT, MPI_MIN, first_comm.mpiComm())); + first_finalized = first_destroy_ok; + if (first_destroy_ok) { expectNvshmemNotInitialized(first_comm); } + } + } else { + first_ready = 0; + } + } + + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_ready, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &first_finalized, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + ASSERT_EQ(1, first_ready) << "initial NVSHMEM communicator setup failed"; + ASSERT_EQ(1, first_finalized) << "initial NVSHMEM runtime did not finalize cleanly"; + + auto candidate_comm = cudecomp_test::MpiTestComm::splitRange(world_comm, 1, kNvshmemTestRanks); + + cudecompHandle_t candidate_handle = nullptr; + std::unique_ptr candidate_handle_guard; + int candidate_init_ready = 1; + int candidate_rejected = 1; + if (candidate_comm.valid()) { + const cudecompResult_t candidate_init_result = cudecompInit(&candidate_handle, candidate_comm.mpiComm()); + 
candidate_handle_guard = std::make_unique(candidate_handle); + + candidate_init_ready = candidate_init_result == CUDECOMP_RESULT_SUCCESS ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, + MPI_Allreduce(MPI_IN_PLACE, &candidate_init_ready, 1, MPI_INT, MPI_MIN, candidate_comm.mpiComm())); + if (candidate_init_ready) { + cudecompGridDescConfig_t candidate_config; + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescConfigSetDefaults(&candidate_config)); + setNvshmemTestConfig(candidate_config); + + cudecompGridDesc_t candidate_grid_desc = nullptr; + const cudecompResult_t candidate_create_result = + cudecompGridDescCreate(candidate_handle, &candidate_grid_desc, &candidate_config, nullptr); + candidate_rejected = candidate_create_result == CUDECOMP_RESULT_INVALID_USAGE ? 1 : 0; + + int candidate_create_succeeded = candidate_create_result == CUDECOMP_RESULT_SUCCESS ? 1 : 0; + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &candidate_create_succeeded, 1, MPI_INT, MPI_MIN, + candidate_comm.mpiComm())); + if (candidate_create_succeeded && candidate_grid_desc) { + EXPECT_EQ(CUDECOMP_RESULT_SUCCESS, cudecompGridDescDestroy(candidate_handle, candidate_grid_desc)); + } + } + } + + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &candidate_init_ready, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + EXPECT_EQ(MPI_SUCCESS, MPI_Allreduce(MPI_IN_PLACE, &candidate_rejected, 1, MPI_INT, MPI_MIN, world_comm.mpiComm())); + ASSERT_EQ(1, candidate_init_ready) << "candidate communicator initialization failed before NVSHMEM check"; + EXPECT_EQ(1, candidate_rejected); +} +#endif + TEST_F(ApiFinalizeTest, RejectsInvalidArguments) { EXPECT_EQ(CUDECOMP_RESULT_INVALID_USAGE, cudecompFinalize(nullptr)); } diff --git a/tests/ctest/mpi_test_utils.cc b/tests/ctest/mpi_test_utils.cc index af0fdcc..720c9f6 100644 --- a/tests/ctest/mpi_test_utils.cc +++ b/tests/ctest/mpi_test_utils.cc @@ -65,6 +65,20 @@ MpiTestComm MpiTestComm::split(const MpiTestComm& parent_comm, int requested_ran return result; } 
+/* Creates an MpiTestComm from the contiguous block of parent ranks [first_rank, first_rank + requested_ranks). parent_comm: communicator being split. first_rank: first parent rank in the block. requested_ranks: number of consecutive parent ranks in the block. The request counts as valid only when requested_ranks > 0, first_rank >= 0 and the block ends at or before parent_comm.size(); for an invalid request no rank is active. Returns whatever fromComm() produces for the handle written by MPI_Comm_split. The return codes of MPI_Comm_split and MPI_Comm_free are not checked. */ MpiTestComm MpiTestComm::splitRange(const MpiTestComm& parent_comm, int first_rank, int requested_ranks) { + const bool valid_request = + requested_ranks > 0 && first_rank >= 0 && first_rank + requested_ranks <= parent_comm.size(); + const bool active = + valid_request && parent_comm.rank() >= first_rank && parent_comm.rank() < first_rank + requested_ranks; + + MPI_Comm comm = MPI_COMM_NULL; + /* Active ranks use color 0 and inactive ranks use MPI_UNDEFINED; the parent rank is passed as the ordering key. */ MPI_Comm_split(parent_comm.mpiComm(), active ? 0 : MPI_UNDEFINED, parent_comm.rank(), &comm); + + MpiTestComm result = fromComm(comm); + /* The raw handle is freed only after fromComm() has consumed it. NOTE(review): presumably MpiTestComm keeps its own duplicate of the communicator - confirm against the constructor. */ if (comm != MPI_COMM_NULL) { MPI_Comm_free(&comm); } + return result; +} + MpiTestComm MpiTestComm::fromComm(MPI_Comm comm) { return MpiTestComm(comm); } MPI_Comm MpiTestComm::mpiComm() const { return comm_; } diff --git a/tests/ctest/mpi_test_utils.h b/tests/ctest/mpi_test_utils.h index 209a24b..cc13dd1 100644 --- a/tests/ctest/mpi_test_utils.h +++ b/tests/ctest/mpi_test_utils.h @@ -21,6 +21,7 @@ class MpiTestComm { static MpiTestComm world(); static MpiTestComm split(const MpiTestComm& parent_comm, int requested_ranks); + static MpiTestComm splitRange(const MpiTestComm& parent_comm, int first_rank, int requested_ranks); static MpiTestComm fromComm(MPI_Comm comm); MPI_Comm mpiComm() const;