Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/internal/comm_routines.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ static void nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridD
auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
auto team = comm_info.nvshmem_team;
int self_rank = comm_info.rank;
auto aux_stream = handle->streams[handle->device_p2p_ce_count];
auto aux_stream = handle->streams[handle->device_p2p_ce_count].get();

// Enforce sync dependency between transpose operations
CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->nvshmem_sync_event));
Expand Down Expand Up @@ -445,7 +445,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
#ifdef ENABLE_NVSHMEM
if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) {
auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
auto pl_stream = handle->streams[0];
auto pl_stream = handle->streams[0].get();
int self_rank = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank;

// Enforce sync dependency between transpose operations
Expand Down Expand Up @@ -514,7 +514,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
auto comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
// For fully intra-group alltoall, use distinct NCCL local comm instead of global comm as it is faster.
auto comm = (comm_info.ngroups == 1) ? *grid_desc->nccl_local_comm : *grid_desc->nccl_comm;
auto pl_stream = handle->streams[0];
auto pl_stream = handle->streams[0].get();
int self_rank = comm_info.rank;

bool group_started = false;
Expand Down
4 changes: 2 additions & 2 deletions include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@

#include "cudecomp.h"
#include "internal/checks.h"
#include "internal/cuda_event.h"
#include "internal/graph.h"
#include "internal/raii_wrappers.h"

namespace cudecomp {
#if NVML_API_VERSION >= 12 && CUDART_VERSION >= 12040
Expand Down Expand Up @@ -94,7 +94,7 @@ struct cudecompHandle {
std::unordered_map<void*, std::vector<std::pair<cudecomp::ncclComm, void*>>>
nccl_ubr_handles; // map of allocated buffer address to NCCL registration handle(s)

std::vector<cudaStream_t> streams; // internal streams for concurrent scheduling
std::vector<cudecomp::cudaStream> streams; // internal streams for concurrent scheduling

cutensorHandle_t cutensor_handle; // cuTENSOR handle;
#if CUTENSOR_MAJOR >= 2
Expand Down
4 changes: 2 additions & 2 deletions include/internal/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "cudecomp.h"
#include "internal/checks.h"
#include "internal/hashes.h"
#include "internal/raii_wrappers.h"
#include "internal/utils.h"

namespace cudecomp {
Expand All @@ -34,7 +35,6 @@ class graphCache {
using key_type = std::tuple<void*, void*, int, int, cudecompPencilInfo_t, cudecompPencilInfo_t, cudecompDataType_t>;

public:
graphCache();
~graphCache();
void replay(const key_type& key, cudaStream_t stream) const;
cudaStream_t startCapture(const key_type& key, cudaStream_t stream) const;
Expand All @@ -44,7 +44,7 @@ class graphCache {

private:
std::unordered_map<key_type, cudaGraphExec_t> graph_cache_;
cudaStream_t graph_stream_;
cudaStream graph_stream_;
};

} // namespace cudecomp
Expand Down
44 changes: 41 additions & 3 deletions include/internal/cuda_event.h → include/internal/raii_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
* limitations under the License.
*/

#ifndef CUDECOMP_CUDA_EVENT_H
#define CUDECOMP_CUDA_EVENT_H
#ifndef CUDECOMP_RAII_WRAPPERS_H
#define CUDECOMP_RAII_WRAPPERS_H

#include <utility>

Expand Down Expand Up @@ -61,6 +61,44 @@ template <unsigned int flags> class cudaEventBase {
using cudaEvent = cudaEventBase<cudaEventDisableTiming>;
using cudaEventTimed = cudaEventBase<cudaEventDefault>;

template <unsigned int flags> class cudaStreamBase {
public:
cudaStreamBase() {
int greatest_priority;
CHECK_CUDA(cudaDeviceGetStreamPriorityRange(nullptr, &greatest_priority));
CHECK_CUDA(cudaStreamCreateWithPriority(&stream_, flags, greatest_priority));
}
~cudaStreamBase() noexcept { resetNoThrow(); }

cudaStreamBase(const cudaStreamBase&) = delete;
cudaStreamBase& operator=(const cudaStreamBase&) = delete;

cudaStreamBase(cudaStreamBase&& other) noexcept : stream_(std::exchange(other.stream_, nullptr)) {}

cudaStreamBase& operator=(cudaStreamBase&& other) noexcept {
if (this != &other) {
resetNoThrow();
stream_ = std::exchange(other.stream_, nullptr);
}
return *this;
}

cudaStream_t get() const noexcept { return stream_; }
operator cudaStream_t() const noexcept { return stream_; }

private:
void resetNoThrow() noexcept {
if (stream_) {
cudaStreamDestroy(stream_);
stream_ = nullptr;
}
}

cudaStream_t stream_ = nullptr;
};

using cudaStream = cudaStreamBase<cudaStreamNonBlocking>;

} // namespace cudecomp

#endif // CUDECOMP_CUDA_EVENT_H
#endif // CUDECOMP_RAII_WRAPPERS_H
2 changes: 1 addition & 1 deletion include/internal/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ static void cudecompTranspose_(int ax, int dir, const cudecompHandle_t handle, c
if (splits_a.size() != 1) {
auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
auto team = comm_info.nvshmem_team;
auto aux_stream = handle->streams[handle->device_p2p_ce_count];
auto aux_stream = handle->streams[handle->device_p2p_ce_count].get();
CHECK_CUDA(cudaEventRecord(grid_desc->nvshmem_sync_event, stream));
CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->nvshmem_sync_event));
// Zero out signal buffer for this team here.
Expand Down
12 changes: 1 addition & 11 deletions src/cudecomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,9 +631,6 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) {
handle->nccl_comm.reset();
handle->nccl_local_comm.reset();

for (auto& stream : handle->streams) {
CHECK_CUDA(cudaStreamDestroy(stream));
}
#ifdef ENABLE_NVSHMEM
if (handle->nvshmem_runtime) { handle->nvshmem_runtime->finalize(); }
handle->nvshmem_allocations.clear();
Expand Down Expand Up @@ -793,14 +790,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes
#endif
}

if (handle->streams.empty()) {
handle->streams.resize(handle->device_p2p_ce_count + 1);
int greatest_priority;
CHECK_CUDA(cudaDeviceGetStreamPriorityRange(nullptr, &greatest_priority));
for (auto& stream : handle->streams) {
CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, greatest_priority));
}
}
if (handle->streams.empty()) { handle->streams.resize(handle->device_p2p_ce_count + 1); }

// Create CUDA events for scheduling
grid_desc->events.resize(handle->nranks);
Expand Down
7 changes: 1 addition & 6 deletions src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@

namespace cudecomp {

graphCache::graphCache() { CHECK_CUDA(cudaStreamCreateWithFlags(&graph_stream_, cudaStreamNonBlocking)); }

graphCache::~graphCache() {
CHECK_CUDA(cudaStreamDestroy(graph_stream_));
this->clear();
}
graphCache::~graphCache() { this->clear(); }

void graphCache::replay(const graphCache::key_type& key, cudaStream_t stream) const {
CHECK_CUDA(cudaGraphLaunch(graph_cache_.at(key), stream));
Expand Down
Loading