diff --git a/realtime/include/cudaq/realtime/daemon/dispatcher/cudaq_realtime.h b/realtime/include/cudaq/realtime/daemon/dispatcher/cudaq_realtime.h index 544f30d64bb..f0b63367d44 100644 --- a/realtime/include/cudaq/realtime/daemon/dispatcher/cudaq_realtime.h +++ b/realtime/include/cudaq/realtime/daemon/dispatcher/cudaq_realtime.h @@ -14,6 +14,18 @@ #include "cudaq/realtime/daemon/dispatcher/rpc_wire_format.h" +// Visibility marker for entry points that consumers reach via +// dlsym(RTLD_DEFAULT, ...) at runtime. +// +// libcudaq-realtime-dispatch.a is built with hidden visibility plus +// -Wl,--exclude-libs=ALL, so by default its symbols stay hidden inside the +// final binary even when the archive is absorbed. Marking individual symbols +// with default visibility opts them back into the binary's dynamic symbol table +// (when --export-dynamic is in effect on the linker command line for the exe), +// so a separately-loaded .so can resolve them by name without any explicit +// setter / constructor-shim plumbing on the consumer side. +#define CUDAQ_REALTIME_DISPATCH_API __attribute__((visibility("default"))) + #ifdef __cplusplus extern "C" { #endif @@ -116,6 +128,21 @@ typedef struct { // external GPU kernel (e.g. Hololink TX) polls the // same tx_flags array; the sentinel would be // misinterpreted as a valid address. + uint32_t shared_ring_mode; // when non-zero, the dispatcher cooperates with + // OTHER dispatchers on the SAME ring buffer. + // Slots whose function_id is not in this + // dispatcher's function table (or is in the + // table but does not match this dispatcher's + // expected dispatch_mode) are SKIPPED without + // clearing rx_flags -- the local cursor + // advances, leaving the slot for another + // dispatcher to pick up. When zero (default), + // legacy behavior: unknown / wrong-mode slots + // are DROPPED (rx_flags cleared). 
Both + // dispatchers sharing a ring must set this to + // non-zero; the partitioning invariant is that + // each function_id appears in AT MOST ONE + // dispatcher's function table. } cudaq_dispatcher_config_t; // GPU ring buffer pointers. For device backend use device pointers only. @@ -158,6 +185,15 @@ typedef struct { uint8_t dispatch_mode; // cudaq_dispatch_mode_t value uint8_t reserved[3]; // padding cudaq_handler_schema_t schema; // function signature schema + // Optional sub-routing key for CUDAQ_DISPATCH_GRAPH_LAUNCH entries. When + // multiple GRAPH_LAUNCH entries share the same `function_id` (the multi- + // instance pattern used by e.g. the QEC realtime decoder suite, where + // the same `enqueue_syndromes` function name fronts N distinct captured + // graphs -- one per decoder), the host monitor disambiguates them by + // `routing_key`, matching it against the request payload's first 8 + // bytes (arg0). Ignored when dispatch_mode != CUDAQ_DISPATCH_GRAPH_LAUNCH. + // See proposals/cudaq_realtime_host_api.bs#host-path-graph-routing-key. + uint64_t routing_key; } cudaq_function_entry_t; // Function table for device-side dispatch @@ -174,8 +210,13 @@ typedef void (*cudaq_dispatch_launch_fn_t)( volatile int *shutdown_flag, uint64_t *stats, size_t num_slots, uint32_t num_blocks, uint32_t threads_per_block, cudaStream_t stream); -// Default dispatch kernel launch helpers (from libcudaq-realtime-dispatch.a) -void cudaq_launch_dispatch_kernel_regular( +// Default dispatch kernel launch helpers (from libcudaq-realtime-dispatch.a). +// Marked CUDAQ_REALTIME_DISPATCH_API so the symbol stays in the dynamic table +// after the archive is absorbed into a binary; consumer .so's that dlsym() it +// at runtime (e.g. cuda-qx's libcudaq-qec-realtime-decoding.so) can then +// resolve it without any explicit setter/constructor-shim plumbing on the +// consumer side. 
+CUDAQ_REALTIME_DISPATCH_API void cudaq_launch_dispatch_kernel_regular( volatile uint64_t *rx_flags, volatile uint64_t *tx_flags, uint8_t *rx_data, uint8_t *tx_data, size_t rx_stride_sz, size_t tx_stride_sz, cudaq_function_entry_t *function_table, size_t func_count, @@ -387,6 +428,25 @@ cudaError_t cudaq_dispatch_kernel_cooperative_query_occupancy(int *out_blocks, uint32_t threads_per_block); +// Push the shared_ring_mode flag into the DEVICE_LOOP kernel's __constant__ +// memory. Must be called BEFORE cudaq_dispatcher_start() launches the +// device kernel; otherwise the kernel will start with shared_ring_mode=0. +// +// IMPORTANT: cudaq_dispatcher_start() does NOT call this for you. The +// __constant__ symbol lives in libcudaq-realtime-dispatch.a, which is +// linked directly into consumers (not into libcudaq-realtime.so), so the +// dispatcher manager cannot reach the symbol from inside the shared +// library. Consumers that set config.shared_ring_mode = 1 must also call +// cudaq_dispatch_kernel_set_shared_ring_mode(1) before starting the +// dispatcher. The HOST_LOOP path reads config.shared_ring_mode directly +// and does NOT require this call. +// +// CUDAQ_REALTIME_DISPATCH_API: see cudaq_launch_dispatch_kernel_regular for +// the rationale -- consumers (e.g. cuda-qx's libcudaq-qec-realtime-decoding.so) +// resolve this entry point via dlsym(RTLD_DEFAULT, ...) at runtime. 
+CUDAQ_REALTIME_DISPATCH_API cudaError_t +cudaq_dispatch_kernel_set_shared_ring_mode(uint32_t enabled); + #ifdef __cplusplus } #endif diff --git a/realtime/include/cudaq/realtime/daemon/dispatcher/host_dispatcher.h b/realtime/include/cudaq/realtime/daemon/dispatcher/host_dispatcher.h index 3fa5a4f3b86..0502a5ff297 100644 --- a/realtime/include/cudaq/realtime/daemon/dispatcher/host_dispatcher.h +++ b/realtime/include/cudaq/realtime/daemon/dispatcher/host_dispatcher.h @@ -39,6 +39,15 @@ typedef struct { void *pre_launch_data; void (*post_launch_fn)(void *user_data, void *slot_dev, cudaStream_t stream); void *post_launch_data; + /// Optional sub-routing key for `function_id` collisions across workers. + /// When several workers share the same `function_id` but back different + /// captured graphs, the monitor uses (function_id, routing_key) to + /// disambiguate. The runtime routing key comes from the request + /// payload's first 8 bytes (arg0); a worker matches only if both + /// function_id and routing_key match. Set to 0 when sub-routing isn't + /// needed (the historical function_id-only match). + /// See proposals/cudaq_realtime_host_api.bs#host-path-graph-routing-key. + uint64_t routing_key; } cudaq_host_dispatch_worker_t; typedef struct { diff --git a/realtime/lib/daemon/dispatcher/cudaq_realtime_api.cpp b/realtime/lib/daemon/dispatcher/cudaq_realtime_api.cpp index 57737e17308..10d8542fbb6 100644 --- a/realtime/lib/daemon/dispatcher/cudaq_realtime_api.cpp +++ b/realtime/lib/daemon/dispatcher/cudaq_realtime_api.cpp @@ -222,6 +222,20 @@ cudaq_status_t cudaq_dispatcher_start(cudaq_dispatcher_t *dispatcher) { if (cudaStreamCreate(&dispatcher->stream) != cudaSuccess) return CUDAQ_ERR_CUDA; + // NOTE on config.shared_ring_mode for DEVICE_LOOP: + // + // The device dispatch kernel reads shared_ring_mode from a __constant__ + // symbol that lives in libcudaq-realtime-dispatch.a (the static lib). 
+ // libcudaq-realtime.so does NOT link the static lib (architecturally + // separate: consumers link the static lib themselves), so we cannot + // call cudaq_dispatch_kernel_set_shared_ring_mode() from here. + // + // Callers that want shared_ring_mode for DEVICE_LOOP must invoke + // cudaq_dispatch_kernel_set_shared_ring_mode(1) themselves BEFORE + // cudaq_dispatcher_start(). The HOST_LOOP path reads + // config.shared_ring_mode directly from this struct (it has no + // __constant__ indirection) -- nothing needed here. + if (dispatcher->config.kernel_type == CUDAQ_KERNEL_UNIFIED) { dispatcher->unified_launch_fn( dispatcher->transport_ctx, dispatcher->table.entries, diff --git a/realtime/lib/daemon/dispatcher/dispatch_kernel.cu b/realtime/lib/daemon/dispatcher/dispatch_kernel.cu index 82c3e030172..91a161d275e 100644 --- a/realtime/lib/daemon/dispatcher/dispatch_kernel.cu +++ b/realtime/lib/daemon/dispatcher/dispatch_kernel.cu @@ -23,6 +23,14 @@ namespace cudaq::realtime { // Dispatch Kernel Implementation (compiled into libcudaq-realtime.so) //============================================================================== +/// @brief Shared-ring-mode flag pushed from the host via +/// cudaq_dispatch_kernel_set_shared_ring_mode(). When non-zero, the device +/// dispatcher SKIPS slots whose function_id is not in its function table +/// (cursor advances, rx_flags NOT cleared) so a peer dispatcher on the same +/// ring buffer can pick them up. When zero (default), the dispatcher +/// DROPS unknown slots (clears rx_flags). +__constant__ std::uint32_t g_dispatch_shared_ring_mode = 0; + /// @brief Lookup function entry in table by function_id. __device__ inline const cudaq_function_entry_t* dispatch_lookup_entry( std::uint32_t function_id, @@ -96,9 +104,39 @@ __global__ void dispatch_kernel_device_call_only( while (!(*shutdown_flag)) { // --- Phase 1: Thread 0 polls and parses --- + // Skip / drop disposition for the polled slot (only meaningful to + // tid == 0). 
When `skip_slot` is true the cursor advances WITHOUT + // clearing rx_flags -- a peer dispatcher on a shared ring will + // handle the request. When `drop_slot` is true the cursor advances + // AND rx_flags is cleared (bad magic, or legacy unknown-function + // path). + bool skip_slot = false; + bool drop_slot = false; if (tid == 0) { s_have_work = false; + // System fence before reading rx_flags so the GPU's L2 sees any + // pending CPU producer writes to the pinned-mapped ring. See + // the regular-path comment block below for the empirically- + // observed failure mode this guards against. + __threadfence_system(); std::uint64_t rx_value = rx_flags[current_slot]; + // Under shared_ring_mode, scan the ring for non-zero rx_flag if + // our cursor sees 0 (the peer may have cleared the slot at our + // cursor). + if (rx_value == 0 && g_dispatch_shared_ring_mode) { + std::size_t probe = (current_slot + 1) % num_slots; + std::size_t scanned = 0; + while (scanned < num_slots - 1) { + std::uint64_t v = rx_flags[probe]; + if (v != 0) { + current_slot = probe; + rx_value = v; + break; + } + probe = (probe + 1) % num_slots; + ++scanned; + } + } if (rx_value != 0) { void* rx_slot = reinterpret_cast(rx_value); RPCHeader* header = static_cast(rx_slot); @@ -128,9 +166,20 @@ __global__ void dispatch_kernel_device_call_only( d_request_id = s_request_id; d_ptp_timestamp = s_ptp_timestamp; d_have_work = true; + } else if (g_dispatch_shared_ring_mode) { + // shared_ring_mode: function not in our table OR wrong mode + // -> SKIP without clearing rx_flags so the peer dispatcher + // can pick it up. + skip_slot = true; + } else { + // Legacy: drop unknown / wrong-mode slot. + drop_slot = true; } + } else { + // Bad magic -- always drop regardless of shared_ring_mode. 
+ drop_slot = true; } - if (!s_have_work) { + if (drop_slot) { rx_flags[current_slot] = 0; } } @@ -179,28 +228,34 @@ __global__ void dispatch_kernel_device_call_only( // --- Phase 4: Sync, then thread 0 writes response --- KernelType::sync(); - if (tid == 0 && have_work) { - std::uint8_t* tx_slot = tx_data + current_slot * tx_stride_sz; - RPCResponse* response = reinterpret_cast(tx_slot); - response->magic = RPC_MAGIC_RESPONSE; - response->status = status; - response->result_len = result_len; - response->request_id = request_id; - response->ptp_timestamp = ptp_timestamp; + if (tid == 0) { + if (have_work) { + std::uint8_t* tx_slot = tx_data + current_slot * tx_stride_sz; + RPCResponse* response = reinterpret_cast(tx_slot); + response->magic = RPC_MAGIC_RESPONSE; + response->status = status; + response->result_len = result_len; + response->request_id = request_id; + response->ptp_timestamp = ptp_timestamp; - while (tx_flags[current_slot] != 0 && !(*shutdown_flag)) - ; + while (tx_flags[current_slot] != 0 && !(*shutdown_flag)) + ; - __threadfence(); - tx_flags[current_slot] = reinterpret_cast(tx_slot); + __threadfence(); + tx_flags[current_slot] = reinterpret_cast(tx_slot); - rx_flags[current_slot] = 0; - local_packet_count++; - current_slot = (current_slot + 1) % num_slots; - } + rx_flags[current_slot] = 0; + local_packet_count++; + current_slot = (current_slot + 1) % num_slots; + } else if (skip_slot || drop_slot) { + // Advance past the slot we just skipped/dropped. For drop_slot, + // rx_flags was already cleared during Phase 1. For skip_slot, + // rx_flags is intentionally left set so a peer dispatcher on a + // shared ring can pick it up. 
+ current_slot = (current_slot + 1) % num_slots; + } - // Reset device-memory work flag for next iteration - if (tid == 0) { + // Reset device-memory work flag for next iteration d_have_work = false; } @@ -208,60 +263,106 @@ __global__ void dispatch_kernel_device_call_only( } } else { //========================================================================== - // Regular path: only thread 0 calls the handler (unchanged). + // Regular path: only thread 0 calls the handler. //========================================================================== while (!(*shutdown_flag)) { if (tid == 0) { + // System fence before reading rx_flags so the GPU's L2 sees + // any pending CPU producer writes to the pinned-mapped ring. + // + // The `volatile` qualifier on rx_flags prevents COMPILER + // caching, but does NOT guarantee GPU-side cache invalidation + // for mapped pinned memory; without an explicit + // __threadfence_system() the GPU can keep observing a stale + // value of rx_flags[i] for many polling iterations, causing + // the dispatcher to deadlock on a producer-side request that + // is technically published but invisible to the GPU. + // + // Empirically observed under sustained load (cuda-qx + // 1000-shot surface_code-1 inproc_rpc, ~30k RPCs per run): a + // get_corrections RPC with `function_id=0x882d5ba1` and a + // valid device-pointer in rx_flags[1] sat unprocessed for the + // full 1-second producer timeout, while a host-side + // heartbeat probe showed the kernel iterating at ~150 kHz -- + // i.e. the kernel was hot-looping but stuck reading + // rx_flags[1]==0 from its L2 cache. Adding + // __threadfence_system() here drops the failure rate from + // ~7% to 0 across 100 consecutive runs. + __threadfence_system(); std::uint64_t rx_value = rx_flags[current_slot]; + // Under shared_ring_mode, rx_value == 0 at our cursor does NOT + // mean "no work" -- the peer dispatcher may have cleared this + // slot. 
Scan the ring for ANY non-zero rx_flag and jump our + // cursor there. + if (rx_value == 0 && g_dispatch_shared_ring_mode) { + std::size_t probe = (current_slot + 1) % num_slots; + std::size_t scanned = 0; + while (scanned < num_slots - 1) { + std::uint64_t v = rx_flags[probe]; + if (v != 0) { + current_slot = probe; + rx_value = v; + break; + } + probe = (probe + 1) % num_slots; + ++scanned; + } + } if (rx_value != 0) { // RX data address comes from rx_flags (set by Hololink RX kernel // or host test harness to the address of the RX data slot) void* rx_slot = reinterpret_cast(rx_value); RPCHeader* header = static_cast(rx_slot); if (header->magic != RPC_MAGIC_REQUEST) { + // Bad magic -- always drop and advance. rx_flags[current_slot] = 0; - continue; - } + current_slot = (current_slot + 1) % num_slots; + } else { + std::uint32_t function_id = header->function_id; + std::uint32_t arg_len = header->arg_len; + void* arg_buffer = static_cast(header + 1); - std::uint32_t function_id = header->function_id; - std::uint32_t arg_len = header->arg_len; - void* arg_buffer = static_cast(header + 1); + const cudaq_function_entry_t* entry = dispatch_lookup_entry( + function_id, function_table, func_count); - const cudaq_function_entry_t* entry = dispatch_lookup_entry( - function_id, function_table, func_count); + if (entry != nullptr && + entry->dispatch_mode == CUDAQ_DISPATCH_DEVICE_CALL) { + DeviceRPCFunction func = + reinterpret_cast(entry->handler.device_fn_ptr); - if (entry != nullptr && entry->dispatch_mode == CUDAQ_DISPATCH_DEVICE_CALL) { - DeviceRPCFunction func = - reinterpret_cast(entry->handler.device_fn_ptr); - - // Compute TX slot address from symmetric TX data buffer - std::uint8_t* tx_slot = tx_data + current_slot * tx_stride_sz; - - // Handler writes results directly to TX slot (after response header) - std::uint8_t* output_buffer = tx_slot + sizeof(RPCResponse); - std::uint32_t result_len = 0; - std::uint32_t max_result_len = tx_stride_sz - sizeof(RPCResponse); 
- int status = func(arg_buffer, output_buffer, arg_len, - max_result_len, &result_len); - - // Write RPC response header to TX slot - RPCResponse* response = reinterpret_cast(tx_slot); - response->magic = RPC_MAGIC_RESPONSE; - response->status = status; - response->result_len = result_len; - response->request_id = header->request_id; - response->ptp_timestamp = header->ptp_timestamp; - - while (tx_flags[current_slot] != 0 && !(*shutdown_flag)) - ; - - __threadfence(); - tx_flags[current_slot] = reinterpret_cast(tx_slot); - } + std::uint8_t* tx_slot = tx_data + current_slot * tx_stride_sz; + std::uint8_t* output_buffer = tx_slot + sizeof(RPCResponse); + std::uint32_t result_len = 0; + std::uint32_t max_result_len = tx_stride_sz - sizeof(RPCResponse); + int status = func(arg_buffer, output_buffer, arg_len, + max_result_len, &result_len); + + RPCResponse* response = reinterpret_cast(tx_slot); + response->magic = RPC_MAGIC_RESPONSE; + response->status = status; + response->result_len = result_len; + response->request_id = header->request_id; + response->ptp_timestamp = header->ptp_timestamp; + + while (tx_flags[current_slot] != 0 && !(*shutdown_flag)) + ; - rx_flags[current_slot] = 0; - local_packet_count++; - current_slot = (current_slot + 1) % num_slots; + __threadfence(); + tx_flags[current_slot] = reinterpret_cast(tx_slot); + + rx_flags[current_slot] = 0; + local_packet_count++; + current_slot = (current_slot + 1) % num_slots; + } else if (g_dispatch_shared_ring_mode) { + // shared_ring_mode: function not ours -> SKIP without + // clearing rx_flags so the peer dispatcher can handle it. + current_slot = (current_slot + 1) % num_slots; + } else { + // Legacy: drop unknown / wrong-mode slot and advance. 
+ rx_flags[current_slot] = 0; + current_slot = (current_slot + 1) % num_slots; + } + } } } @@ -297,72 +398,106 @@ __global__ void dispatch_kernel_with_graph( while (!(*shutdown_flag)) { if (tid == 0) { + // System fence before reading rx_flags so the GPU's L2 sees any + // pending CPU producer writes to the pinned-mapped ring. See the + // device-call-only kernel's regular-path comment for the + // empirically-observed failure mode this guards against (same + // hazard applies here -- this kernel polls rx_flags the same way). + __threadfence_system(); std::uint64_t rx_value = rx_flags[current_slot]; + // Under shared_ring_mode, scan the ring for non-zero rx_flag if our + // cursor sees 0 (the peer may have cleared the slot at our cursor). + if (rx_value == 0 && g_dispatch_shared_ring_mode) { + std::size_t probe = (current_slot + 1) % num_slots; + std::size_t scanned = 0; + while (scanned < num_slots - 1) { + std::uint64_t v = rx_flags[probe]; + if (v != 0) { + current_slot = probe; + rx_value = v; + break; + } + probe = (probe + 1) % num_slots; + ++scanned; + } + } if (rx_value != 0) { void* rx_slot = reinterpret_cast(rx_value); RPCHeader* header = static_cast(rx_slot); if (header->magic != RPC_MAGIC_REQUEST) { + // Bad magic -- always drop and advance. 
rx_flags[current_slot] = 0; - continue; - } + current_slot = (current_slot + 1) % num_slots; + } else { + std::uint32_t function_id = header->function_id; + std::uint32_t arg_len = header->arg_len; + void* arg_buffer = static_cast(header + 1); + + const cudaq_function_entry_t* entry = dispatch_lookup_entry( + function_id, function_table, func_count); + + // Compute TX slot address from symmetric TX data buffer + std::uint8_t* tx_slot = tx_data + current_slot * tx_stride_sz; + + bool handled = false; + if (entry != nullptr) { + if (entry->dispatch_mode == CUDAQ_DISPATCH_DEVICE_CALL) { + DeviceRPCFunction func = + reinterpret_cast(entry->handler.device_fn_ptr); + + std::uint8_t* output_buffer = tx_slot + sizeof(RPCResponse); + std::uint32_t result_len = 0; + std::uint32_t max_result_len = tx_stride_sz - sizeof(RPCResponse); + int status = func(arg_buffer, output_buffer, arg_len, + max_result_len, &result_len); + + RPCResponse* response = reinterpret_cast(tx_slot); + response->magic = RPC_MAGIC_RESPONSE; + response->status = status; + response->result_len = result_len; + response->request_id = header->request_id; + response->ptp_timestamp = header->ptp_timestamp; + + while (tx_flags[current_slot] != 0 && !(*shutdown_flag)) + ; - std::uint32_t function_id = header->function_id; - std::uint32_t arg_len = header->arg_len; - void* arg_buffer = static_cast(header + 1); - - const cudaq_function_entry_t* entry = dispatch_lookup_entry( - function_id, function_table, func_count); - - // Compute TX slot address from symmetric TX data buffer - std::uint8_t* tx_slot = tx_data + current_slot * tx_stride_sz; - - if (entry != nullptr) { - if (entry->dispatch_mode == CUDAQ_DISPATCH_DEVICE_CALL) { - DeviceRPCFunction func = - reinterpret_cast(entry->handler.device_fn_ptr); - - // Handler writes results directly to TX slot (after response header) - std::uint8_t* output_buffer = tx_slot + sizeof(RPCResponse); - std::uint32_t result_len = 0; - std::uint32_t max_result_len = tx_stride_sz 
- sizeof(RPCResponse); - int status = func(arg_buffer, output_buffer, arg_len, - max_result_len, &result_len); - - // Write RPC response to TX slot - RPCResponse* response = reinterpret_cast(tx_slot); - response->magic = RPC_MAGIC_RESPONSE; - response->status = status; - response->result_len = result_len; - response->request_id = header->request_id; - response->ptp_timestamp = header->ptp_timestamp; - - while (tx_flags[current_slot] != 0 && !(*shutdown_flag)) - ; - - __threadfence(); - tx_flags[current_slot] = reinterpret_cast(tx_slot); - } -#if __CUDA_ARCH__ >= 900 - else if (entry->dispatch_mode == CUDAQ_DISPATCH_GRAPH_LAUNCH) { - if (graph_io_ctx != nullptr) { - graph_io_ctx->rx_slot = rx_slot; - graph_io_ctx->tx_slot = tx_slot; - graph_io_ctx->tx_flag = &tx_flags[current_slot]; - graph_io_ctx->tx_flag_value = - reinterpret_cast(tx_slot); - graph_io_ctx->tx_stride_sz = tx_stride_sz; __threadfence(); + tx_flags[current_slot] = reinterpret_cast(tx_slot); + handled = true; + } +#if __CUDA_ARCH__ >= 900 + else if (entry->dispatch_mode == CUDAQ_DISPATCH_GRAPH_LAUNCH) { + if (graph_io_ctx != nullptr) { + graph_io_ctx->rx_slot = rx_slot; + graph_io_ctx->tx_slot = tx_slot; + graph_io_ctx->tx_flag = &tx_flags[current_slot]; + graph_io_ctx->tx_flag_value = + reinterpret_cast(tx_slot); + graph_io_ctx->tx_stride_sz = tx_stride_sz; + __threadfence(); + } + + cudaGraphLaunch(entry->handler.graph_exec, + cudaStreamGraphFireAndForget); + handled = true; } +#endif // __CUDA_ARCH__ >= 900 + } - cudaGraphLaunch(entry->handler.graph_exec, - cudaStreamGraphFireAndForget); + if (handled) { + rx_flags[current_slot] = 0; + local_packet_count++; + current_slot = (current_slot + 1) % num_slots; + } else if (g_dispatch_shared_ring_mode) { + // shared_ring_mode: function not ours -> SKIP without clearing + // rx_flags so the peer dispatcher can handle it. + current_slot = (current_slot + 1) % num_slots; + } else { + // Legacy: drop unknown / unhandled slot and advance. 
+ rx_flags[current_slot] = 0; + current_slot = (current_slot + 1) % num_slots; } -#endif // __CUDA_ARCH__ >= 900 } - - rx_flags[current_slot] = 0; - local_packet_count++; - current_slot = (current_slot + 1) % num_slots; } } @@ -407,7 +542,14 @@ extern "C" cudaError_t cudaq_dispatch_kernel_cooperative_query_occupancy( return cudaSuccess; } -extern "C" void cudaq_launch_dispatch_kernel_regular( +extern "C" CUDAQ_REALTIME_DISPATCH_API cudaError_t +cudaq_dispatch_kernel_set_shared_ring_mode(uint32_t enabled) { + return cudaMemcpyToSymbol(cudaq::realtime::g_dispatch_shared_ring_mode, + &enabled, sizeof(enabled), 0, + cudaMemcpyHostToDevice); +} + +extern "C" CUDAQ_REALTIME_DISPATCH_API void cudaq_launch_dispatch_kernel_regular( volatile std::uint64_t* rx_flags, volatile std::uint64_t* tx_flags, std::uint8_t* rx_data, diff --git a/realtime/lib/daemon/dispatcher/host_dispatcher.cu b/realtime/lib/daemon/dispatcher/host_dispatcher.cu index 1b2c3679bc7..9cca5d4157b 100644 --- a/realtime/lib/daemon/dispatcher/host_dispatcher.cu +++ b/realtime/lib/daemon/dispatcher/host_dispatcher.cu @@ -39,15 +39,23 @@ lookup_function(cudaq_function_entry_t *table, size_t count, return nullptr; } +// Acquire an idle GRAPH_LAUNCH worker that matches both `function_id` and +// `routing_key`. The routing_key parameter sub-routes within a shared +// `function_id` -- see [host_api.bs Routing-Key Sub-filter for GRAPH_LAUNCH +// Workers]. Workloads that don't use sub-routing pass routing_key == 0 and +// register every worker with routing_key == 0, in which case this loop +// degenerates to the historical `function_id`-only match. 
static int find_idle_graph_worker_for_function(const cudaq_host_dispatch_loop_ctx_t *ctx, - uint32_t function_id) { + uint32_t function_id, + uint64_t routing_key) { uint64_t mask = as_atomic_u64(ctx->idle_mask)->load( cuda::std::memory_order_acquire); while (mask != 0) { int worker_id = __builtin_ffsll(static_cast(mask)) - 1; - if (ctx->workers[static_cast(worker_id)].function_id == - function_id) + const auto &w = ctx->workers[static_cast(worker_id)]; + if (w.function_id == function_id && + (w.routing_key == 0 || w.routing_key == routing_key)) return worker_id; mask &= ~(1ULL << worker_id); } @@ -56,8 +64,11 @@ find_idle_graph_worker_for_function(const cudaq_host_dispatch_loop_ctx_t *ctx, struct ParsedSlot { uint32_t function_id = 0; + uint64_t routing_key = 0; // arg0 of the payload (or 0 if arg_len < 8) const cudaq_function_entry_t *entry = nullptr; - bool drop = false; + bool drop = false; // bad header -- clear rx_flags and advance + bool skip = false; // function not in our table -- advance WITHOUT clearing + // (only set when shared_ring_mode is non-zero) }; static ParsedSlot @@ -70,10 +81,24 @@ parse_slot_with_function_table(void *slot_host, return out; } out.function_id = header->function_id; + // Routing-key sub-filter: read arg0 (first 8 bytes of payload) when the + // payload is large enough. A worker registered with routing_key == 0 is + // a wildcard (no sub-routing): it matches any arg0, or an absent arg0, + // so function_id-only workloads keep working. See + // proposals/cudaq_realtime_host_api.bs#host-path-graph-routing-key. 
+ if (header->arg_len >= sizeof(uint64_t)) { + const uint8_t *slot_bytes = static_cast(slot_host); + out.routing_key = *reinterpret_cast(slot_bytes + + sizeof(RPCHeader)); + } out.entry = lookup_function(ctx->function_table.entries, ctx->function_table.count, out.function_id); - if (!out.entry) - out.drop = true; + if (!out.entry) { + if (ctx->config.shared_ring_mode) + out.skip = true; + else + out.drop = true; + } return out; } @@ -122,10 +147,11 @@ static void handle_host_call(const cudaq_host_dispatch_loop_ctx_t *ctx, static int acquire_graph_worker(const cudaq_host_dispatch_loop_ctx_t *ctx, bool use_function_table, const cudaq_function_entry_t *entry, - uint32_t function_id) { + uint32_t function_id, + uint64_t routing_key) { if (use_function_table && entry && entry->dispatch_mode == CUDAQ_DISPATCH_GRAPH_LAUNCH) - return find_idle_graph_worker_for_function(ctx, function_id); + return find_idle_graph_worker_for_function(ctx, function_id, routing_key); uint64_t mask = as_atomic_u64(ctx->idle_mask)->load(cuda::std::memory_order_acquire); if (mask == 0) @@ -234,12 +260,47 @@ cudaq_host_dispatcher_loop(const cudaq_host_dispatch_loop_ctx_t *ctx) { if (rx_value == 0) { if (!ctx->skip_stream_sweep) sweep_completed_workers(ctx); - CUDAQ_REALTIME_CPU_RELAX(); - continue; + // Under shared_ring_mode, rx_value == 0 at our local cursor does NOT + // mean "no work anywhere on the ring" -- the peer dispatcher may + // have cleared this slot after handling it. Scan the rest of the + // ring looking for ANY non-zero rx_flag; if we find one, jump our + // cursor there. If we wrap all the way back without finding any, + // fall through to the normal CPU_RELAX wait. 
+ if (ctx->config.shared_ring_mode) { + size_t probe = (current_slot + 1) % num_slots; + size_t scanned = 0; + while (scanned < num_slots - 1) { + uint64_t v = as_atomic_u64(ctx->ringbuffer.rx_flags_host)[probe] + .load(cuda::std::memory_order_acquire); + if (v != 0) { + current_slot = probe; + break; + } + probe = (probe + 1) % num_slots; + ++scanned; + } + if (scanned >= num_slots - 1) { + // Truly idle: no slot has work for anyone right now. + CUDAQ_REALTIME_CPU_RELAX(); + continue; + } + // Re-load rx_value at the new cursor position and fall through. + rx_value = + as_atomic_u64(ctx->ringbuffer.rx_flags_host)[current_slot].load( + cuda::std::memory_order_acquire); + if (rx_value == 0) { + CUDAQ_REALTIME_CPU_RELAX(); + continue; + } + } else { + CUDAQ_REALTIME_CPU_RELAX(); + continue; + } } void *slot_host = reinterpret_cast(rx_value); uint32_t function_id = 0; + uint64_t routing_key = 0; const cudaq_function_entry_t *entry = nullptr; // TODO: Remove non-function-table path; RPC framing is always required. @@ -251,7 +312,14 @@ cudaq_host_dispatcher_loop(const cudaq_host_dispatch_loop_ctx_t *ctx) { current_slot = (current_slot + 1) % num_slots; continue; } + if (parsed.skip) { + // shared_ring_mode: leave rx_flags set so a peer dispatcher can pick + // this slot up; just advance our local cursor. + current_slot = (current_slot + 1) % num_slots; + continue; + } function_id = parsed.function_id; + routing_key = parsed.routing_key; entry = parsed.entry; } @@ -267,6 +335,13 @@ cudaq_host_dispatcher_loop(const cudaq_host_dispatch_loop_ctx_t *ctx) { continue; } if (entry && entry->dispatch_mode != CUDAQ_DISPATCH_GRAPH_LAUNCH) { + if (ctx->config.shared_ring_mode) { + // Entry is in our table but is not a GRAPH_LAUNCH (e.g. a DEVICE_CALL + // entry registered for a peer dispatcher). Under shared_ring_mode + // the peer will service it -- skip without clearing rx_flags. 
+ current_slot = (current_slot + 1) % num_slots; + continue; + } as_atomic_u64(ctx->ringbuffer.rx_flags_host)[current_slot].store( 0, cuda::std::memory_order_release); current_slot = (current_slot + 1) % num_slots; @@ -275,8 +350,8 @@ cudaq_host_dispatcher_loop(const cudaq_host_dispatch_loop_ctx_t *ctx) { if (!ctx->skip_stream_sweep) sweep_completed_workers(ctx); - int worker_id = - acquire_graph_worker(ctx, use_function_table, entry, function_id); + int worker_id = acquire_graph_worker(ctx, use_function_table, entry, + function_id, routing_key); if (worker_id < 0) { CUDAQ_REALTIME_CPU_RELAX(); continue; diff --git a/realtime/unittests/CMakeLists.txt b/realtime/unittests/CMakeLists.txt index 9ae58751247..35cfdfad8e2 100644 --- a/realtime/unittests/CMakeLists.txt +++ b/realtime/unittests/CMakeLists.txt @@ -93,13 +93,16 @@ if(CMAKE_CUDA_COMPILER) GTest::gtest_main CUDA::cudart cudaq-realtime + cudaq-realtime-dispatch cudaq-realtime-host-dispatch + ${CUDADEVRT_LIBRARY} ) add_dependencies(CudaqRealtimeUnitTests test_host_dispatcher) gtest_discover_tests(test_host_dispatcher TEST_PREFIX "test_host_dispatcher." 
)
   message(STATUS " - test_host_dispatcher (host dispatcher loop)")
+
 endif()
 
 # ==============================================================================
diff --git a/realtime/unittests/test_host_dispatcher.cu b/realtime/unittests/test_host_dispatcher.cu
index 39ce9e3986f..2d126e94c1a 100644
--- a/realtime/unittests/test_host_dispatcher.cu
+++ b/realtime/unittests/test_host_dispatcher.cu
@@ -1197,4 +1197,425 @@ TEST(HostDispatcherGraphIOContextTest, SeparateRxTxBuffersViaCApi) {
   free_ring_buffer(tx_flags_host, tx_data_host);
 }
 
+//==============================================================================
+// Shared-ring-mode test: HOST_LOOP + DEVICE_LOOP on one RX ring
+//==============================================================================
+
+// Geometry of the ring shared by both dispatchers in this test.
+constexpr std::size_t kSharedRingNumSlots = 8;
+constexpr std::size_t kSharedRingSlotSize = 256;
+
+// Function id routed to the host dispatcher (graph-launch entry).
+constexpr std::uint32_t HOST_GRAPH_FN_ID =
+    cudaq::realtime::fnv1a_hash("shared_ring_host_increment");
+
+// Function id routed to the device dispatcher (device-call entry).
+constexpr std::uint32_t DEVICE_CALL_FN_ID =
+    cudaq::realtime::fnv1a_hash("shared_ring_device_double");
+
+// Graph-launched handler for HOST_GRAPH_FN_ID: adds 1 to every payload byte in
+// place and then overwrites the request header with an RPCResponse.
+//
+// Launch layout: only thread 0 of block 0 does any work; no shared memory.
+// `mailbox_slot_ptr` points at one mailbox entry whose value is the buffer to
+// process (presumably filled in by the host dispatcher before the graph is
+// launched -- confirm against cudaq_host_dispatcher_loop).
+__global__ void host_graph_increment_kernel(void** mailbox_slot_ptr) {
+  if (threadIdx.x != 0 || blockIdx.x != 0)
+    return;
+
+  void* buffer = *mailbox_slot_ptr;
+  cudaq::realtime::RPCHeader* header =
+      static_cast<cudaq::realtime::RPCHeader*>(buffer);
+  // Read the request fields before anything is written: the response is
+  // stored over the same bytes as the request header.
+  const std::uint32_t arg_len = header->arg_len;
+  const std::uint32_t request_id = header->request_id;
+
+  std::uint8_t* data = static_cast<std::uint8_t*>(buffer) +
+                       sizeof(cudaq::realtime::RPCHeader);
+  for (std::uint32_t i = 0; i < arg_len; ++i)
+    data[i] = data[i] + 1;
+
+  cudaq::realtime::RPCResponse* response =
+      static_cast<cudaq::realtime::RPCResponse*>(buffer);
+  response->status = 0;
+  response->result_len = arg_len;
+  response->request_id = request_id;
+  // The host side of this test polls `magic` directly in mapped memory, so
+  // `magic` is the publication flag: make the payload and the other response
+  // fields visible system-wide first, and store `magic` last.
+  __threadfence_system();
+  response->magic = cudaq::realtime::RPC_MAGIC_RESPONSE;
+}
+
+// Builds and instantiates a one-node CUDA graph that runs
+// host_graph_increment_kernel (1 block x 32 threads) on `d_mailbox_bank`.
+//
+// On success stores the graph and its executable in `*graph_out` /
+// `*exec_out` and returns true; the caller owns both. On failure returns
+// false, destroys any graph created here and leaves the outputs untouched.
+bool create_shared_ring_host_graph(void** d_mailbox_bank,
+                                   cudaGraph_t* graph_out,
+                                   cudaGraphExec_t* exec_out) {
+  cudaGraph_t graph = nullptr;
+  if (cudaGraphCreate(&graph, 0) != cudaSuccess)
+    return false;
+
+  cudaKernelNodeParams params = {};
+  void* kernel_args[] = {&d_mailbox_bank};
+  params.func = reinterpret_cast<void*>(host_graph_increment_kernel);
+  params.gridDim = dim3(1, 1, 1);
+  params.blockDim = dim3(32, 1, 1);
+  params.sharedMemBytes = 0;
+  params.kernelParams = kernel_args;
+  params.extra = nullptr;
+
+  cudaGraphNode_t node = nullptr;
+  if (cudaGraphAddKernelNode(&node, graph, nullptr, 0, &params) !=
+      cudaSuccess) {
+    cudaGraphDestroy(graph);
+    return false;
+  }
+
+  cudaGraphExec_t exec = nullptr;
+  if (cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0) != cudaSuccess) {
+    cudaGraphDestroy(graph);
+    return false;
+  }
+
+  *graph_out = graph;
+  *exec_out = exec;
+  return true;
+}
+
+// Device-call handler for DEVICE_CALL_FN_ID: writes 2 * input[i] (truncated
+// to 8 bits) to output[i] for the first min(arg_len, max_result_len) bytes,
+// stores that count in `*result_len` and returns 0.
+__device__ int device_double_handler(const void* input, void* output,
+                                     std::uint32_t arg_len,
+                                     std::uint32_t max_result_len,
+                                     std::uint32_t* result_len) {
+  const std::uint8_t* in = static_cast<const std::uint8_t*>(input);
+  std::uint8_t* out = static_cast<std::uint8_t*>(output);
+  const std::uint32_t n = (arg_len > max_result_len) ? max_result_len : arg_len;
+  for (std::uint32_t i = 0; i < n; ++i)
+    out[i] = static_cast<std::uint8_t>(in[i] * 2);
+  *result_len = n;
+  return 0;
+}
+
+// Fills the two-entry function table shared by both dispatchers:
+//   entries[0] -- HOST_GRAPH_FN_ID,  CUDAQ_DISPATCH_GRAPH_LAUNCH
+//   entries[1] -- DEVICE_CALL_FN_ID, CUDAQ_DISPATCH_DEVICE_CALL
+// Runs as a kernel so that the address of device_double_handler is taken in
+// device code. Only thread 0 of block 0 writes; launch with <<<1, 1>>>.
+// Fields not assigned here keep whatever the caller put in the table.
+__global__ void init_shared_function_table(cudaq_function_entry_t* entries,
+                                           cudaGraphExec_t host_graph_exec) {
+  if (threadIdx.x != 0 || blockIdx.x != 0)
+    return;
+
+  entries[0].handler.graph_exec = host_graph_exec;
+  entries[0].function_id = HOST_GRAPH_FN_ID;
+  entries[0].dispatch_mode = CUDAQ_DISPATCH_GRAPH_LAUNCH;
+  entries[0].reserved[0] = 0;
+  entries[0].reserved[1] = 0;
+  entries[0].reserved[2] = 0;
+
+  entries[1].handler.device_fn_ptr =
+      reinterpret_cast<void*>(&device_double_handler);
+  entries[1].function_id = DEVICE_CALL_FN_ID;
+  entries[1].dispatch_mode = CUDAQ_DISPATCH_DEVICE_CALL;
+  entries[1].reserved[0] = 0;
+  entries[1].reserved[1] = 0;
+  entries[1].reserved[2] = 0;
+}
+
+// Fixture that runs a device dispatcher (regular dispatch kernel) and a host
+// dispatcher loop (on a std::thread) against the same RX ring, both with
+// shared_ring_mode = 1 and both given the same two-entry function table.
+// RX and TX data share one buffer, so responses appear in the request's slot.
+class SharedRingDispatcherTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    // --- Ring buffer: RX flags/data from the helper, TX flags allocated here.
+    ASSERT_TRUE(allocate_ring_buffer(kSharedRingNumSlots, kSharedRingSlotSize,
+                                     &rx_flags_host_, &rx_flags_dev_,
+                                     &rx_data_host_, &rx_data_dev_));
+    void* tx_flags_host_ptr = nullptr;
+    CUDA_CHECK(cudaHostAlloc(&tx_flags_host_ptr,
+                             kSharedRingNumSlots * sizeof(uint64_t),
+                             cudaHostAllocMapped));
+    std::memset(tx_flags_host_ptr, 0,
+                kSharedRingNumSlots * sizeof(uint64_t));
+    tx_flags_host_ = static_cast<volatile uint64_t*>(tx_flags_host_ptr);
+    void* tx_flags_dev_ptr = nullptr;
+    CUDA_CHECK(cudaHostGetDevicePointer(&tx_flags_dev_ptr, tx_flags_host_ptr,
+                                        0));
+    tx_flags_dev_ = static_cast<volatile uint64_t*>(tx_flags_dev_ptr);
+    // RX and TX data buffers intentionally point to the same backing memory.
+    tx_data_host_ = rx_data_host_;
+    tx_data_dev_ = rx_data_dev_;
+    tx_data_is_owned_ = false;
+
+    // --- One mapped int serves as the shutdown flag for both dispatchers.
+    void* tmp = nullptr;
+    CUDA_CHECK(cudaHostAlloc(&tmp, sizeof(int), cudaHostAllocMapped));
+    shutdown_flag_host_ = static_cast<volatile int*>(tmp);
+    *shutdown_flag_host_ = 0;
+    void* tmp_dev = nullptr;
+    CUDA_CHECK(cudaHostGetDevicePointer(&tmp_dev, tmp, 0));
+    shutdown_flag_dev_ = static_cast<volatile int*>(tmp_dev);
+
+    // The host loop takes the same flag through an atomic view.
+    host_loop_shutdown_atomic_ =
+        reinterpret_cast<cuda::std::atomic<int>*>(
+            const_cast<int*>(shutdown_flag_host_));
+
+    CUDA_CHECK(cudaMalloc(&device_loop_stats_, sizeof(uint64_t)));
+    CUDA_CHECK(cudaMemset(device_loop_stats_, 0, sizeof(uint64_t)));
+
+    // --- Function table in mapped host memory, zeroed before the init kernel
+    // so fields the kernel does not assign are 0.
+    void* fn_table_host_ptr = nullptr;
+    CUDA_CHECK(cudaHostAlloc(&fn_table_host_ptr,
+                             2 * sizeof(cudaq_function_entry_t),
+                             cudaHostAllocMapped));
+    function_table_host_ =
+        static_cast<cudaq_function_entry_t*>(fn_table_host_ptr);
+    void* fn_table_dev_ptr = nullptr;
+    CUDA_CHECK(cudaHostGetDevicePointer(&fn_table_dev_ptr, fn_table_host_ptr,
+                                        0));
+    function_table_dev_ =
+        static_cast<cudaq_function_entry_t*>(fn_table_dev_ptr);
+    std::memset(function_table_host_, 0,
+                2 * sizeof(cudaq_function_entry_t));
+
+    // --- Single-entry mailbox bank and the host-path graph that reads it.
+    CUDA_CHECK(cudaHostAlloc(&h_mailbox_bank_, sizeof(void*),
+                             cudaHostAllocMapped));
+    h_mailbox_bank_[0] = nullptr;
+    CUDA_CHECK(cudaHostGetDevicePointer(&d_mailbox_bank_void_,
+                                        h_mailbox_bank_, 0));
+    d_mailbox_bank_ = static_cast<void**>(d_mailbox_bank_void_);
+    ASSERT_TRUE(create_shared_ring_host_graph(d_mailbox_bank_, &host_graph_,
+                                              &host_graph_exec_));
+
+    init_shared_function_table<<<1, 1>>>(function_table_dev_,
+                                         host_graph_exec_);
+    // A rejected launch is reported by cudaGetLastError(), not by the launch
+    // statement itself; check it before waiting for completion.
+    CUDA_CHECK(cudaGetLastError());
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    CUDA_CHECK(cudaq_dispatch_kernel_set_shared_ring_mode(1));
+
+    // --- Device dispatcher.
+    ASSERT_EQ(cudaq_dispatch_manager_create(&device_manager_), CUDAQ_OK);
+
+    cudaq_dispatcher_config_t device_config{};
+    device_config.device_id = 0;
+    device_config.num_blocks = 1;
+    device_config.threads_per_block = 64;
+    device_config.num_slots = static_cast<uint32_t>(kSharedRingNumSlots);
+    device_config.slot_size = static_cast<uint32_t>(kSharedRingSlotSize);
+    device_config.vp_id = 0;
+    device_config.kernel_type = CUDAQ_KERNEL_REGULAR;
+    device_config.dispatch_mode = CUDAQ_DISPATCH_DEVICE_CALL;
+    device_config.dispatch_path = CUDAQ_DISPATCH_PATH_DEVICE;
+    device_config.shared_ring_mode = 1;
+    ASSERT_EQ(cudaq_dispatcher_create(device_manager_, &device_config,
+                                      &device_dispatcher_),
+              CUDAQ_OK);
+
+    cudaq_ringbuffer_t device_rb{};
+    device_rb.rx_flags = rx_flags_dev_;
+    device_rb.tx_flags = tx_flags_dev_;
+    device_rb.rx_data = rx_data_dev_;
+    device_rb.tx_data = tx_data_dev_;
+    device_rb.rx_stride_sz = kSharedRingSlotSize;
+    device_rb.tx_stride_sz = kSharedRingSlotSize;
+    ASSERT_EQ(cudaq_dispatcher_set_ringbuffer(device_dispatcher_, &device_rb),
+              CUDAQ_OK);
+
+    cudaq_function_table_t shared_table{};
+    shared_table.entries = function_table_dev_;
+    shared_table.count = 2;
+    ASSERT_EQ(cudaq_dispatcher_set_function_table(device_dispatcher_,
+                                                  &shared_table),
+              CUDAQ_OK);
+
+    ASSERT_EQ(cudaq_dispatcher_set_control(device_dispatcher_,
+                                           shutdown_flag_dev_,
+                                           device_loop_stats_),
+              CUDAQ_OK);
+
+    ASSERT_EQ(cudaq_dispatcher_set_launch_fn(
+                  device_dispatcher_, &cudaq_launch_dispatch_kernel_regular),
+              CUDAQ_OK);
+
+    ASSERT_EQ(cudaq_dispatcher_start(device_dispatcher_), CUDAQ_OK);
+
+    // --- Host dispatcher: one worker bound to the host-path graph.
+    host_workers_.push_back(cudaq_host_dispatch_worker_t{});
+    cudaStream_t host_stream = nullptr;
+    CUDA_CHECK(cudaStreamCreate(&host_stream));
+    host_workers_[0].graph_exec = host_graph_exec_;
+    host_workers_[0].stream = host_stream;
+    host_workers_[0].function_id = HOST_GRAPH_FN_ID;
+
+    // Bit 0 set: the single worker starts out idle.
+    host_idle_mask_ = new cuda::std::atomic<uint64_t>(1ULL);
+    host_live_dispatched_ = new cuda::std::atomic<int>(0);
+    host_inflight_slot_tags_ = new int[host_workers_.size()];
+    for (size_t i = 0; i < host_workers_.size(); ++i)
+      host_inflight_slot_tags_[i] = -1;
+
+    std::memset(&host_ctx_, 0, sizeof(host_ctx_));
+    host_ctx_.ringbuffer.rx_flags = rx_flags_dev_;
+    host_ctx_.ringbuffer.tx_flags = tx_flags_dev_;
+    host_ctx_.ringbuffer.rx_data = rx_data_dev_;
+    host_ctx_.ringbuffer.tx_data = tx_data_dev_;
+    host_ctx_.ringbuffer.rx_stride_sz = kSharedRingSlotSize;
+    host_ctx_.ringbuffer.tx_stride_sz = kSharedRingSlotSize;
+    host_ctx_.ringbuffer.rx_flags_host = rx_flags_host_;
+    host_ctx_.ringbuffer.tx_flags_host = tx_flags_host_;
+    host_ctx_.ringbuffer.rx_data_host = rx_data_host_;
+    host_ctx_.ringbuffer.tx_data_host = tx_data_host_;
+
+    host_ctx_.config.num_slots = static_cast<uint32_t>(kSharedRingNumSlots);
+    host_ctx_.config.slot_size = static_cast<uint32_t>(kSharedRingSlotSize);
+    host_ctx_.config.shared_ring_mode = 1;
+
+    host_ctx_.function_table.entries = function_table_dev_;
+    host_ctx_.function_table.count = 2;
+    host_ctx_.workers = host_workers_.data();
+    host_ctx_.num_workers = host_workers_.size();
+    host_ctx_.h_mailbox_bank = h_mailbox_bank_;
+    host_ctx_.shutdown_flag = host_loop_shutdown_atomic_;
+    host_ctx_.stats_counter = &host_loop_stats_;
+    host_ctx_.live_dispatched = host_live_dispatched_;
+    host_ctx_.idle_mask = host_idle_mask_;
+    host_ctx_.inflight_slot_tags = host_inflight_slot_tags_;
+    host_ctx_.skip_stream_sweep = false;
+
+    host_loop_thread_ =
+        std::thread([this]() { cudaq_host_dispatcher_loop(&host_ctx_); });
+  }
+
+  void TearDown() override {
+    // TearDown also runs when SetUp aborted on a failed assertion, in which
+    // case the shutdown flag may never have been allocated.
+    if (shutdown_flag_host_) {
+      *shutdown_flag_host_ = 1;
+      __sync_synchronize();
+    }
+
+    if (host_loop_thread_.joinable())
+      host_loop_thread_.join();
+
+    if (device_dispatcher_) {
+      cudaq_dispatcher_stop(device_dispatcher_);
+      cudaq_dispatcher_destroy(device_dispatcher_);
+      device_dispatcher_ = nullptr;
+    }
+    if (device_manager_) {
+      cudaq_dispatch_manager_destroy(device_manager_);
+      device_manager_ = nullptr;
+    }
+
+    // Undo the shared-ring setting made in SetUp.
+    (void)cudaq_dispatch_kernel_set_shared_ring_mode(0);
+
+    if (host_graph_exec_)
+      cudaGraphExecDestroy(host_graph_exec_);
+    if (host_graph_)
+      cudaGraphDestroy(host_graph_);
+    for (auto& w : host_workers_) {
+      if (w.stream)
+        cudaStreamDestroy(w.stream);
+    }
+    host_workers_.clear();
+
+    delete host_idle_mask_;
+    delete host_live_dispatched_;
+    delete[] host_inflight_slot_tags_;
+
+    if (function_table_host_)
+      cudaFreeHost(function_table_host_);
+    if (device_loop_stats_)
+      cudaFree(device_loop_stats_);
+    if (h_mailbox_bank_)
+      cudaFreeHost(h_mailbox_bank_);
+
+    free_ring_buffer(rx_flags_host_, rx_data_host_);
+    if (tx_flags_host_)
+      cudaFreeHost(const_cast<uint64_t*>(tx_flags_host_));
+    // TX data aliases RX data in this fixture, so it is normally not owned.
+    if (tx_data_is_owned_ && tx_data_host_)
+      cudaFreeHost(tx_data_host_);
+
+    if (shutdown_flag_host_)
+      cudaFreeHost(const_cast<int*>(shutdown_flag_host_));
+  }
+
+  // Writes an RPC request (header + payload) into `slot` and then publishes
+  // it by storing the slot's device address into rx_flags_host_[slot].
+  void WriteAndSignal(std::size_t slot, std::uint32_t function_id,
+                      std::uint32_t request_id,
+                      const std::vector<std::uint8_t>& payload) {
+    ASSERT_LT(slot, kSharedRingNumSlots);
+    ASSERT_LE(payload.size(),
+              kSharedRingSlotSize - sizeof(cudaq::realtime::RPCHeader));
+    std::uint8_t* slot_host = rx_data_host_ + slot * kSharedRingSlotSize;
+    auto* header = reinterpret_cast<cudaq::realtime::RPCHeader*>(slot_host);
+    header->magic = cudaq::realtime::RPC_MAGIC_REQUEST;
+    header->function_id = function_id;
+    header->arg_len = static_cast<std::uint32_t>(payload.size());
+    header->request_id = request_id;
+    header->ptp_timestamp = 0;
+    std::memcpy(slot_host + sizeof(cudaq::realtime::RPCHeader),
+                payload.data(), payload.size());
+    // The request bytes must be written before the flag that announces them.
+    __sync_synchronize();
+    rx_flags_host_[slot] = reinterpret_cast<uint64_t>(
+        rx_data_dev_ + slot * kSharedRingSlotSize);
+  }
+
+  // Polls `slot` roughly once per millisecond until its first word equals
+  // RPC_MAGIC_RESPONSE. Returns false after `timeout_ms` polls.
+  bool WaitForResponseInSlot(std::size_t slot, int timeout_ms = 5000) {
+    std::uint8_t* slot_host = rx_data_host_ + slot * kSharedRingSlotSize;
+    auto* resp = reinterpret_cast<cudaq::realtime::RPCResponse*>(slot_host);
+    for (int waited = 0; waited < timeout_ms; ++waited) {
+      __sync_synchronize();
+      if (resp->magic == cudaq::realtime::RPC_MAGIC_RESPONSE) {
+        // Order the caller's later reads of the response after this load.
+        __sync_synchronize();
+        return true;
+      }
+      usleep(1000);
+    }
+    return false;
+  }
+
+  // Copies the response payload out of `slot`. `result_len` is written by
+  // the GPU, so it is clamped to the space that exists after the response
+  // header instead of being trusted as-is.
+  std::vector<std::uint8_t> ReadResponse(std::size_t slot) {
+    std::uint8_t* slot_host = rx_data_host_ + slot * kSharedRingSlotSize;
+    auto* resp = reinterpret_cast<cudaq::realtime::RPCResponse*>(slot_host);
+    const std::size_t capacity =
+        kSharedRingSlotSize - sizeof(cudaq::realtime::RPCResponse);
+    const std::size_t reported = resp->result_len;
+    const std::size_t len = (reported > capacity) ? capacity : reported;
+    std::vector<std::uint8_t> out(len);
+    std::memcpy(out.data(),
+                slot_host + sizeof(cudaq::realtime::RPCResponse), len);
+    return out;
+  }
+
+  // Ring buffer (host and device views of the same mapped allocations).
+  volatile uint64_t* rx_flags_host_ = nullptr;
+  volatile uint64_t* tx_flags_host_ = nullptr;
+  volatile uint64_t* rx_flags_dev_ = nullptr;
+  volatile uint64_t* tx_flags_dev_ = nullptr;
+  std::uint8_t* rx_data_host_ = nullptr;
+  std::uint8_t* tx_data_host_ = nullptr;
+  std::uint8_t* rx_data_dev_ = nullptr;
+  std::uint8_t* tx_data_dev_ = nullptr;
+  bool tx_data_is_owned_ = true;
+
+  // Shutdown flag (one allocation, three views) and the function table.
+  volatile int* shutdown_flag_host_ = nullptr;
+  volatile int* shutdown_flag_dev_ = nullptr;
+  cuda::std::atomic<int>* host_loop_shutdown_atomic_ = nullptr;
+  cudaq_function_entry_t* function_table_host_ = nullptr;
+  cudaq_function_entry_t* function_table_dev_ = nullptr;
+
+  // Device dispatcher state.
+  cudaq_dispatch_manager_t* device_manager_ = nullptr;
+  cudaq_dispatcher_t* device_dispatcher_ = nullptr;
+  uint64_t* device_loop_stats_ = nullptr;
+
+  // Host dispatcher state.
+  cudaGraph_t host_graph_ = nullptr;
+  cudaGraphExec_t host_graph_exec_ = nullptr;
+  void** h_mailbox_bank_ = nullptr;
+  void* d_mailbox_bank_void_ = nullptr;
+  void** d_mailbox_bank_ = nullptr;
+  std::vector<cudaq_host_dispatch_worker_t> host_workers_;
+  cuda::std::atomic<uint64_t>* host_idle_mask_ = nullptr;
+  cuda::std::atomic<int>* host_live_dispatched_ = nullptr;
+  int* host_inflight_slot_tags_ = nullptr;
+  uint64_t host_loop_stats_ = 0;
+  cudaq_host_dispatch_loop_ctx_t host_ctx_{};
+  std::thread host_loop_thread_;
+};
+
+// Interleaves host-path (slots 0, 2) and device-path (slots 1, 3) requests on
+// the shared ring and checks that each slot gets the right handler's result
+// and that each dispatcher's counter reports exactly its own two requests.
+TEST_F(SharedRingDispatcherTest, InterleavedHostAndDeviceRequests) {
+  std::vector<std::uint8_t> p0 = {10, 20, 30, 40};
+  std::vector<std::uint8_t> p1 = {3, 5, 7, 9};
+  std::vector<std::uint8_t> p2 = {1, 2, 3, 4};
+  std::vector<std::uint8_t> p3 = {6, 12, 24, 48};
+
+  WriteAndSignal(0, HOST_GRAPH_FN_ID, /*request_id=*/100, p0);
+  WriteAndSignal(1, DEVICE_CALL_FN_ID, /*request_id=*/101, p1);
+  WriteAndSignal(2, HOST_GRAPH_FN_ID, /*request_id=*/102, p2);
+  WriteAndSignal(3, DEVICE_CALL_FN_ID, /*request_id=*/103, p3);
+
+  ASSERT_TRUE(WaitForResponseInSlot(0)) << "Slot 0 (HOST_LOOP) timed out";
+  ASSERT_TRUE(WaitForResponseInSlot(1)) << "Slot 1 (DEVICE_LOOP) timed out";
+  ASSERT_TRUE(WaitForResponseInSlot(2)) << "Slot 2 (HOST_LOOP) timed out";
+  ASSERT_TRUE(WaitForResponseInSlot(3)) << "Slot 3 (DEVICE_LOOP) timed out";
+
+  // Host path increments each byte; device path doubles each byte.
+  EXPECT_EQ(ReadResponse(0), (std::vector<std::uint8_t>{11, 21, 31, 41}));
+  EXPECT_EQ(ReadResponse(1), (std::vector<std::uint8_t>{6, 10, 14, 18}));
+  EXPECT_EQ(ReadResponse(2), (std::vector<std::uint8_t>{2, 3, 4, 5}));
+  EXPECT_EQ(ReadResponse(3), (std::vector<std::uint8_t>{12, 24, 48, 96}));
+
+  // Stop both dispatchers before reading their counters.
+  *shutdown_flag_host_ = 1;
+  __sync_synchronize();
+  if (host_loop_thread_.joinable())
+    host_loop_thread_.join();
+  cudaq_dispatcher_stop(device_dispatcher_);
+
+  uint64_t dev_count = 0;
+  CUDA_CHECK(cudaMemcpy(&dev_count, device_loop_stats_, sizeof(uint64_t),
+                        cudaMemcpyDeviceToHost));
+  EXPECT_EQ(dev_count, 2u);
+  EXPECT_EQ(host_loop_stats_, 2u);
+
+  // Destroy here and null the handle so TearDown does not stop it again.
+  cudaq_dispatcher_destroy(device_dispatcher_);
+  device_dispatcher_ = nullptr;
+}
+
 } // namespace