From d5d62bf6e579d3c1183fb03cdc1bb5969c22d42e Mon Sep 17 00:00:00 2001 From: Josh Romero Date: Mon, 4 May 2026 11:31:30 -0700 Subject: [PATCH] Use RAII cudaStream wrapper to simplify stream creation/cleanup. Signed-off-by: Josh Romero --- include/internal/comm_routines.h | 6 +-- include/internal/common.h | 4 +- include/internal/graph.h | 4 +- .../{cuda_event.h => raii_wrappers.h} | 44 +++++++++++++++++-- include/internal/transpose.h | 2 +- src/cudecomp.cc | 12 +---- src/graph.cc | 7 +-- 7 files changed, 51 insertions(+), 28 deletions(-) rename include/internal/{cuda_event.h => raii_wrappers.h} (60%) diff --git a/include/internal/comm_routines.h b/include/internal/comm_routines.h index 36388ea..f613f55 100644 --- a/include/internal/comm_routines.h +++ b/include/internal/comm_routines.h @@ -107,7 +107,7 @@ static void nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridD auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; auto team = comm_info.nvshmem_team; int self_rank = comm_info.rank; - auto aux_stream = handle->streams[handle->device_p2p_ce_count]; + auto aux_stream = handle->streams[handle->device_p2p_ce_count].get(); // Enforce sync dependency between transpose operations CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->nvshmem_sync_event)); @@ -445,7 +445,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc #ifdef ENABLE_NVSHMEM if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) { auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; - auto pl_stream = handle->streams[0]; + auto pl_stream = handle->streams[0].get(); int self_rank = (comm_axis == CUDECOMP_COMM_ROW) ? 
grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank; // Enforce sync dependency between transpose operations @@ -514,7 +514,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc auto comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; // For fully intra-group alltoall, use distinct NCCL local comm instead of global comm as it is faster. auto comm = (comm_info.ngroups == 1) ? *grid_desc->nccl_local_comm : *grid_desc->nccl_comm; - auto pl_stream = handle->streams[0]; + auto pl_stream = handle->streams[0].get(); int self_rank = comm_info.rank; bool group_started = false; diff --git a/include/internal/common.h b/include/internal/common.h index 514cf0a..62ac55e 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -45,8 +45,8 @@ #include "cudecomp.h" #include "internal/checks.h" -#include "internal/cuda_event.h" #include "internal/graph.h" +#include "internal/raii_wrappers.h" namespace cudecomp { #if NVML_API_VERSION >= 12 && CUDART_VERSION >= 12040 @@ -94,7 +94,7 @@ struct cudecompHandle { std::unordered_map>> nccl_ubr_handles; // map of allocated buffer address to NCCL registration handle(s) - std::vector<cudaStream_t> streams; // internal streams for concurrent scheduling + std::vector<cudaStream> streams; // internal streams for concurrent scheduling cutensorHandle_t cutensor_handle; // cuTENSOR handle; #if CUTENSOR_MAJOR >= 2 diff --git a/include/internal/graph.h b/include/internal/graph.h index d13d9e0..578b420 100644 --- a/include/internal/graph.h +++ b/include/internal/graph.h @@ -26,6 +26,7 @@ #include "cudecomp.h" #include "internal/checks.h" #include "internal/hashes.h" +#include "internal/raii_wrappers.h" #include "internal/utils.h" namespace cudecomp { @@ -34,7 +35,6 @@ class graphCache { using key_type = std::tuple; public: - graphCache(); ~graphCache(); void replay(const key_type& key, cudaStream_t stream) const; cudaStream_t startCapture(const key_type& key, cudaStream_t 
stream) const; @@ -44,7 +44,7 @@ class graphCache { private: std::unordered_map graph_cache_; - cudaStream_t graph_stream_; + cudaStream graph_stream_; }; } // namespace cudecomp diff --git a/include/internal/cuda_event.h b/include/internal/raii_wrappers.h similarity index 60% rename from include/internal/cuda_event.h rename to include/internal/raii_wrappers.h index 27e7100..94c22e5 100644 --- a/include/internal/cuda_event.h +++ b/include/internal/raii_wrappers.h @@ -15,8 +15,8 @@ * limitations under the License. */ -#ifndef CUDECOMP_CUDA_EVENT_H -#define CUDECOMP_CUDA_EVENT_H +#ifndef CUDECOMP_RAII_WRAPPERS_H +#define CUDECOMP_RAII_WRAPPERS_H #include @@ -61,6 +61,44 @@ template class cudaEventBase { using cudaEvent = cudaEventBase; using cudaEventTimed = cudaEventBase; +template <unsigned int flags> class cudaStreamBase { +public: + cudaStreamBase() { + int greatest_priority; + CHECK_CUDA(cudaDeviceGetStreamPriorityRange(nullptr, &greatest_priority)); + CHECK_CUDA(cudaStreamCreateWithPriority(&stream_, flags, greatest_priority)); + } + ~cudaStreamBase() noexcept { resetNoThrow(); } + + cudaStreamBase(const cudaStreamBase&) = delete; + cudaStreamBase& operator=(const cudaStreamBase&) = delete; + + cudaStreamBase(cudaStreamBase&& other) noexcept : stream_(std::exchange(other.stream_, nullptr)) {} + + cudaStreamBase& operator=(cudaStreamBase&& other) noexcept { + if (this != &other) { + resetNoThrow(); + stream_ = std::exchange(other.stream_, nullptr); + } + return *this; + } + + cudaStream_t get() const noexcept { return stream_; } + operator cudaStream_t() const noexcept { return stream_; } + +private: + void resetNoThrow() noexcept { + if (stream_) { + cudaStreamDestroy(stream_); + stream_ = nullptr; + } + } + + cudaStream_t stream_ = nullptr; +}; + +using cudaStream = cudaStreamBase<cudaStreamNonBlocking>; + } // namespace cudecomp -#endif // CUDECOMP_CUDA_EVENT_H +#endif // CUDECOMP_RAII_WRAPPERS_H diff --git a/include/internal/transpose.h b/include/internal/transpose.h index 745ddad..b53dcd1 100644 --- 
a/include/internal/transpose.h +++ b/include/internal/transpose.h @@ -292,7 +292,7 @@ static void cudecompTranspose_(int ax, int dir, const cudecompHandle_t handle, c if (splits_a.size() != 1) { auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info; auto team = comm_info.nvshmem_team; - auto aux_stream = handle->streams[handle->device_p2p_ce_count]; + auto aux_stream = handle->streams[handle->device_p2p_ce_count].get(); CHECK_CUDA(cudaEventRecord(grid_desc->nvshmem_sync_event, stream)); CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->nvshmem_sync_event)); // Zero out signal buffer for this team here. diff --git a/src/cudecomp.cc b/src/cudecomp.cc index a10a32e..532c669 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -631,9 +631,6 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { handle->nccl_comm.reset(); handle->nccl_local_comm.reset(); - for (auto& stream : handle->streams) { - CHECK_CUDA(cudaStreamDestroy(stream)); - } #ifdef ENABLE_NVSHMEM if (handle->nvshmem_runtime) { handle->nvshmem_runtime->finalize(); } handle->nvshmem_allocations.clear(); @@ -793,14 +790,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes #endif } - if (handle->streams.empty()) { - handle->streams.resize(handle->device_p2p_ce_count + 1); - int greatest_priority; - CHECK_CUDA(cudaDeviceGetStreamPriorityRange(nullptr, &greatest_priority)); - for (auto& stream : handle->streams) { - CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, greatest_priority)); - } - } + if (handle->streams.empty()) { handle->streams.resize(handle->device_p2p_ce_count + 1); } // Create CUDA events for scheduling grid_desc->events.resize(handle->nranks); diff --git a/src/graph.cc b/src/graph.cc index 88b62a1..2d6dc86 100644 --- a/src/graph.cc +++ b/src/graph.cc @@ -27,12 +27,7 @@ namespace cudecomp { -graphCache::graphCache() { CHECK_CUDA(cudaStreamCreateWithFlags(&graph_stream_, 
cudaStreamNonBlocking)); } - -graphCache::~graphCache() { - CHECK_CUDA(cudaStreamDestroy(graph_stream_)); - this->clear(); -} +graphCache::~graphCache() { this->clear(); } void graphCache::replay(const graphCache::key_type& key, cudaStream_t stream) const { CHECK_CUDA(cudaGraphLaunch(graph_cache_.at(key), stream));