diff --git a/README.md b/README.md index 785695284..187f0bdae 100644 --- a/README.md +++ b/README.md @@ -385,6 +385,40 @@ slow or metered links, `--layers 20:42` is also supported: the coordinator will load the output head and compute logits locally, trading extra coordinator work for smaller per-token replies. +Reverse topology is also supported when the coordinator owns a final suffix +through the output head. In that layout workers cover the lower layers and +return hidden state upstream, and the coordinator runs the higher layers plus +output locally. For example: + +```sh +# Machine A: worker, owns lower layers. +./ds4 \ + -m gguf/DeepSeek-V4-Pro-Q4K-Layers00-30.gguf \ + --role worker \ + --layers 0:30 \ + --coordinator 169.254.43.68 1234 + +# Machine B: coordinator, owns the upper suffix and output. +./ds4 \ + -m gguf/DeepSeek-V4-Pro-Q4K-Layers-31-output.gguf \ + --role coordinator \ + --layers 31:output \ + --listen 169.254.43.68 1234 +``` + +Reverse `K:42` is intentionally unsupported. Reverse mode only supports +`K:output`, because the coordinator must own the output head. + +You can also start the coordinator owning `K:output` layers with `--local-decode`. +In that mode the route still does distributed prefill, but after prefill the worker +pushes its KV shard to the coordinator, and the coordinator finishes generation locally +using full model residency. This hand-off keeps the distributed prefill speedup while +moving decode back onto one machine and gaining decode speed. + +For example, using an M5 Max 128GB as the coordinator with `--layers 22:output --local-decode` +and running a DGX Spark as a worker with `--layers 0:21` over a 2.5GbE direct link provides +`prefill: 602.78 t/s, generation: 30.10 t/s`. 
+ ### Network Link Comparison The table below shows the same two M5 Max hosts, the same 91 GB Flash quant, @@ -468,21 +502,26 @@ control TCP connection open to the coordinator and send a `HELLO` with their model ID, model family, quant profile, layer slice, context capacity, and data port. The coordinator uses these registrations to build a route that covers all layers. Work then moves over low-latency TCP data connections: the coordinator -computes the first slice, sends a `WORK` frame with session ID, token positions, -rolling token-prefix hashes before and after the span, route information, and -hidden-state payload, and each worker computes its slice. Middle workers can -forward directly to the next worker. The final worker returns logits to the -coordinator, or ACKs for non-final prefill chunks so the prefill pipeline can -stay full. `RESULT` frames echo the request ID and the post-span hash. A worker -status error is handled differently from a socket failure: KV/hash mismatch can -be recovered by replaying the token history on the same route, while transport -failure drops the route and waits for a replacement worker. For persistent KV, -the coordinator opens worker data connections and sends snapshot save/load -messages for each worker-owned layer range; the disk payload remains a single -agent/server cache file. The protocol has no -encryption or authentication, and is not release-stable yet; coordinator and -workers should be built from the same commit and used on trusted machines and -trusted networks. +computes the local prefix first in forward topology, sends a `WORK` frame with +session ID, token positions, rolling token-prefix hashes before and after the +span, route information, and hidden-state payload, and each worker computes its +slice. In reverse topology the first worker starts from layer 0 with token +input only, returns hidden state upstream, and the coordinator finishes the +higher layers plus output locally. 
Middle workers can forward directly to the +next worker. The final worker returns logits in the usual forward path, or +returns hidden state when the coordinator owns the output path. Forward +non-final prefill chunks may use ACK-only replies so the prefill pipeline can +stay full; reverse pipelined prefill returns hidden state for every chunk +because the coordinator must finish each chunk locally. `RESULT` frames echo +the request ID and the post-span hash. A worker status error is handled +differently from a socket failure: KV/hash mismatch can be recovered by +replaying the token history on the same route, while transport failure drops +the route and waits for a replacement worker. For persistent KV, the +coordinator opens worker data connections and sends snapshot save/load messages +for each worker-owned layer range; the disk payload remains a single +agent/server cache file. The protocol has no encryption or authentication, and +is not release-stable yet; coordinator and workers should be built from the +same commit and used on trusted machines and trusted networks. ## Reducing heat, power usage and fan noise diff --git a/ds4.c b/ds4.c index 640511eb0..d91c4277f 100644 --- a/ds4.c +++ b/ds4.c @@ -25554,6 +25554,15 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { e->distributed = opt->distributed; e->power_percent = opt->power_percent > 0 ? 
opt->power_percent : 100; e->prefill_chunk = opt->prefill_chunk; + const bool distributed_reverse_coordinator = + e->distributed.role == DS4_DISTRIBUTED_COORDINATOR && + e->distributed.layers.set && + e->distributed.layers.has_output && + e->distributed.layers.start > 0u; + if (e->prefill_chunk == 0 && distributed_reverse_coordinator) { + e->prefill_chunk = 2048u; + e->distributed.prefill_chunk = 2048u; + } e->ssd_streaming_cache_experts = opt->ssd_streaming_cache_experts; e->ssd_streaming_cache_bytes = opt->ssd_streaming_cache_bytes; e->ssd_streaming_preload_experts = opt->ssd_streaming_preload_experts; @@ -25590,8 +25599,15 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { uint32_t load_layer_end = opt->load_layer_end; bool load_output = opt->load_output; bool load_output_optional = false; + const bool distributed_reverse_coordinator_full_resident = + opt->distributed.role == DS4_DISTRIBUTED_COORDINATOR && + opt->distributed.local_decode && + opt->distributed.layers.set && + opt->distributed.layers.has_output && + opt->distributed.layers.start > 0u; if (opt->distributed.role != DS4_DISTRIBUTED_NONE && - opt->distributed.layers.set) + opt->distributed.layers.set && + !distributed_reverse_coordinator_full_resident) { load_slice = true; load_layer_start = opt->distributed.layers.start; @@ -27153,6 +27169,18 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, #endif } +int ds4_session_eval_local_only(ds4_session *s, int token, char *err, size_t errlen) { + if (!s) { + if (errlen) snprintf(err, errlen, "invalid session"); + return 1; + } + ds4_dist_session *saved = s->distributed; + s->distributed = NULL; + const int rc = ds4_session_eval_internal(s, token, true, err, errlen); + s->distributed = saved; + return rc; +} + int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen) { return ds4_session_eval_internal(s, token, true, err, errlen); } diff --git a/ds4.h b/ds4.h index 9d040c92b..f6d1a235d 
100644 --- a/ds4.h +++ b/ds4.h @@ -87,6 +87,7 @@ typedef struct { uint32_t prefill_chunk; uint32_t prefill_window; uint32_t activation_bits; + bool local_decode; bool replay_check; bool debug; } ds4_distributed_options; @@ -261,6 +262,11 @@ int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k); int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out); int ds4_session_copy_logits(ds4_session *s, float *out, int cap); int ds4_session_set_logits(ds4_session *s, const float *logits, int n); +/* Internal runtime helper: run the normal local eval path even if the session + * still carries a distributed coordinator attachment. Frontends should keep + * using ds4_session_eval(); ds4_distributed.c uses this helper once + * coordinator-owned reverse-topology local decode is active. */ +int ds4_session_eval_local_only(ds4_session *s, int token, char *err, size_t errlen); int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen); int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, int max_tokens, int eos_token, diff --git a/ds4_distributed.c b/ds4_distributed.c index d31c8e2a6..10b94c047 100644 --- a/ds4_distributed.c +++ b/ds4_distributed.c @@ -222,6 +222,11 @@ typedef struct ds4_dist_worker_entry { struct ds4_dist_worker_entry *next; } ds4_dist_worker_entry; +typedef enum { + DS4_DIST_TOPOLOGY_FORWARD = 0, + DS4_DIST_TOPOLOGY_REVERSE = 1, +} ds4_dist_topology; + typedef struct { ds4_engine *engine; uint32_t model_id; @@ -229,8 +234,10 @@ typedef struct { uint32_t local_start; uint32_t local_end; uint32_t ctx_size; + ds4_dist_topology topology; bool local_has_output; bool local_can_output_head; + bool local_decode_requested; bool replay_check; bool debug; bool use_control_for_work; @@ -383,6 +390,13 @@ typedef struct { uint64_t tensor_bytes; } ds4_dist_kv_shard_file; +typedef struct { + uint32_t layer_start; + uint32_t layer_end; + const ds4_dist_route_entry *entry; + bool is_local; +} 
ds4_dist_kv_route_owner; + struct ds4_dist_session { ds4_dist_coordinator_state state; int listen_fd; @@ -395,6 +409,9 @@ struct ds4_dist_session { uint64_t session_id; uint64_t request_id; uint64_t snapshot_request_id; + bool local_decode_active; + bool local_decode_remote_flushable; + uint32_t local_decode_remote_pos; }; typedef struct { @@ -403,10 +420,23 @@ typedef struct { float logprob; } ds4_dist_logprob; +typedef struct { + uint32_t chunk_index; + uint32_t pos; + uint32_t n_tokens; + uint32_t payload_bytes; + bool reset_session; + bool output_logits; + void *payload; +} ds4_dist_prefill_reverse_slot; + typedef struct { ds4_dist_coordinator_state *state; + ds4_session *session; + const ds4_tokens *prompt; int fd; ds4_session *progress_session; + float *logits; uint64_t first_request_id; uint64_t *expected_hashes; uint32_t count; @@ -414,17 +444,30 @@ typedef struct { uint32_t chunk_cap; uint32_t progress_base; uint32_t progress_total; + uint32_t progress_received; uint32_t progress_completed; bool progress_done; uint64_t hc_values; bool allow_hidden; + bool reset_first_chunk; + bool reverse_apply_local_suffix; uint32_t final_kind; void *final_payload; uint32_t final_payload_bytes; int rc; + double local_eval_sec; char err[256]; pthread_mutex_t progress_mu; pthread_cond_t progress_cv; + ds4_dist_prefill_reverse_slot *reverse_slots; + uint32_t reverse_slot_count; + uint32_t reverse_head; + uint32_t reverse_tail; + uint32_t reverse_queued; + bool reverse_producer_done; + pthread_mutex_t reverse_mu; + pthread_cond_t reverse_can_enqueue; + pthread_cond_t reverse_can_dequeue; } ds4_dist_prefill_result_reader; typedef struct { @@ -465,6 +508,17 @@ typedef struct { * Small Utilities And Forward Declarations * ========================================================================= */ +static int dist_session_ensure_route(ds4_dist_session *d, char *err, size_t errlen); +static int dist_save_remote_shard_to_file( + ds4_dist_session *d, + const ds4_dist_route_entry 
*entry, + const ds4_tokens *tokens, + uint64_t token_hash, + FILE *fp, + uint64_t *payload_bytes_out, + char *err, + size_t errlen); + static uint32_t dist_prefill_send_depth(uint32_t chunk_count) { uint32_t depth = 2; const char *env = getenv("DS4_DIST_PREFILL_SEND_DEPTH"); @@ -525,12 +579,63 @@ static int dist_coordinator_prefill_prompt( char *err, size_t errlen); static int dist_validate_options(const ds4_dist_options *opt, char *err, size_t errlen); +static int dist_validate_layers_for_model(const ds4_dist_options *opt, uint32_t n_layers, char *err, size_t errlen); static uint32_t dist_resolved_layer_end(const ds4_dist_options *opt, uint32_t n_layers) { if (opt->layers.has_output) return n_layers - 1u; return opt->layers.end; } +static const char *dist_topology_name(ds4_dist_topology topology) { + switch (topology) { + case DS4_DIST_TOPOLOGY_FORWARD: return "forward"; + case DS4_DIST_TOPOLOGY_REVERSE: return "reverse"; + } + return "unknown"; +} + +static int dist_infer_coordinator_topology( + const ds4_dist_options *opt, + uint32_t n_layers, + ds4_dist_topology *out, + char *err, + size_t errlen) { + if (!opt || opt->role != DS4_DISTRIBUTED_COORDINATOR || !opt->layers.set) return 0; + if (n_layers == 0) { + if (errlen) snprintf(err, errlen, "model reports no layers"); + return 1; + } + + const uint32_t local_end = dist_resolved_layer_end(opt, n_layers); + if (opt->layers.start == 0u) { + if (out) *out = DS4_DIST_TOPOLOGY_FORWARD; + return 0; + } + if (opt->layers.has_output && local_end + 1u == n_layers) { + if (out) *out = DS4_DIST_TOPOLOGY_REVERSE; + return 0; + } + if (!opt->layers.has_output && local_end + 1u == n_layers) { + if (errlen) { + snprintf(err, + errlen, + "coordinator reverse suffix %u:%u is unsupported; use %u:output", + opt->layers.start, + local_end, + opt->layers.start); + } + return 1; + } + if (errlen) { + snprintf(err, + errlen, + "coordinator middle-only layer range %u:%u is unsupported; use 0:K or K:output", + opt->layers.start, + 
local_end); + } + return 1; +} + static const char *dist_role_name(ds4_distributed_role role) { switch (role) { case DS4_DISTRIBUTED_NONE: return "none"; @@ -1848,9 +1953,13 @@ static bool dist_coordinator_debug_enabled(const ds4_dist_coordinator_state *sta * Coordinator Worker Registry And Route Planning * ========================================================================= * - * A route is a contiguous chain that starts after the coordinator's local - * slice. The last hop can either return logits directly or return the final - * hidden state so the coordinator can run its local output head. + * A route is a contiguous worker chain covering the coordinator's missing + * side of the model: + * - forward topology covers layers after the coordinator slice + * - reverse topology covers layers before the coordinator slice + * + * The wire format stays the same in both cases: ordered worker entries plus + * a final upstream return target. */ static void dist_coordinator_add_worker( @@ -1941,13 +2050,50 @@ static int dist_worker_route_cmp(const void *a, const void *b) { return 0; } +static bool dist_coordinator_route_span( + const ds4_dist_coordinator_state *state, + uint32_t *start, + uint32_t *end) { + if (!state || state->n_layers == 0) return false; + switch (state->topology) { + case DS4_DIST_TOPOLOGY_FORWARD: + if (state->local_end + 1u >= state->n_layers) return false; + if (start) *start = state->local_end + 1u; + if (end) *end = state->n_layers - 1u; + return true; + case DS4_DIST_TOPOLOGY_REVERSE: + if (state->local_start == 0u) return false; + if (start) *start = 0u; + if (end) *end = state->local_start - 1u; + return true; + } + return false; +} + +static bool dist_coordinator_route_final_worker_may_output_logits( + const ds4_dist_coordinator_state *state) { + return state && state->topology == DS4_DIST_TOPOLOGY_FORWARD; +} + static bool dist_worker_route_candidate_ok( const ds4_dist_coordinator_state *state, const ds4_dist_worker_entry *w, - uint32_t 
last) { - const bool needs_hidden = w->layer_end < last || !w->has_output; + uint32_t required_end, + bool final_worker_may_output_logits) { + if (!state || !w) return false; + if (w->layer_end > required_end) return false; + + const bool is_final = w->layer_end == required_end; + const bool final_outputs_logits = is_final && + final_worker_may_output_logits && + w->has_output && + required_end + 1u == state->n_layers; + const bool needs_hidden = !is_final || !final_outputs_logits; if (needs_hidden && !w->has_hidden) return false; - if (w->layer_end >= last && !w->has_output && !state->local_can_output_head) return false; + if (is_final && + required_end + 1u == state->n_layers && + !final_outputs_logits && + !state->local_can_output_head) return false; return true; } @@ -1956,7 +2102,8 @@ static bool dist_route_search_workers( ds4_dist_worker_entry **workers, uint32_t n, uint32_t next, - uint32_t last, + uint32_t required_end, + bool final_worker_may_output_logits, ds4_dist_worker_entry **path, uint32_t *path_len, uint32_t *missing_layer) { @@ -1966,16 +2113,17 @@ static bool dist_route_search_workers( if (w->layer_start < next) continue; if (w->layer_start > next) break; saw_start = true; - if (!dist_worker_route_candidate_ok(state, w, last)) continue; + if (!dist_worker_route_candidate_ok(state, w, required_end, final_worker_may_output_logits)) continue; path[(*path_len)++] = w; - if (w->layer_end >= last) return true; + if (w->layer_end == required_end) return true; uint32_t child_missing = w->layer_end + 1u; if (dist_route_search_workers(state, workers, n, child_missing, - last, + required_end, + final_worker_may_output_logits, path, path_len, &child_missing)) { @@ -1988,6 +2136,108 @@ static bool dist_route_search_workers( return false; } +static void dist_format_range_end( + uint32_t layer_end, + bool has_output, + char *buf, + size_t buflen) { + if (!buf || buflen == 0) return; + if (has_output) snprintf(buf, buflen, "output"); + else snprintf(buf, buflen, 
"%u", layer_end); +} + +static bool dist_route_diag_local_overlap( + const ds4_dist_coordinator_state *state, + const ds4_dist_worker_entry *w, + char *err, + size_t errlen) { + if (!state || !w) return false; + if (w->layer_end < state->local_start || w->layer_start > state->local_end) return false; + + uint32_t overlap_start = w->layer_start > state->local_start ? w->layer_start : state->local_start; + uint32_t overlap_end = w->layer_end < state->local_end ? w->layer_end : state->local_end; + char worker_end[32]; + char local_end[32]; + dist_format_range_end(w->layer_end, w->has_output != 0, worker_end, sizeof(worker_end)); + dist_format_range_end(state->local_end, state->local_has_output, local_end, sizeof(local_end)); + if (errlen) { + snprintf(err, + errlen, + "distributed %s route invalid: worker %s:%u layers=%u:%s overlap coordinator local range %u:%s at layers %u:%u", + dist_topology_name(state->topology), + w->peer_host, + w->listen_port, + w->layer_start, + worker_end, + state->local_start, + local_end, + overlap_start, + overlap_end); + } + return true; +} + +static bool dist_route_diag_worker_overlap( + const ds4_dist_coordinator_state *state, + const ds4_dist_worker_entry *prev, + const ds4_dist_worker_entry *cur, + char *err, + size_t errlen) { + if (!state || !prev || !cur) return false; + if (cur->layer_start > prev->layer_end) return false; + + uint32_t overlap_start = cur->layer_start; + uint32_t overlap_end = prev->layer_end < cur->layer_end ? 
prev->layer_end : cur->layer_end; + char prev_end[32]; + char cur_end[32]; + dist_format_range_end(prev->layer_end, prev->has_output != 0, prev_end, sizeof(prev_end)); + dist_format_range_end(cur->layer_end, cur->has_output != 0, cur_end, sizeof(cur_end)); + if (errlen) { + snprintf(err, + errlen, + "distributed %s route invalid: worker %s:%u layers=%u:%s overlap worker %s:%u layers=%u:%s at layers %u:%u", + dist_topology_name(state->topology), + prev->peer_host, + prev->listen_port, + prev->layer_start, + prev_end, + cur->peer_host, + cur->listen_port, + cur->layer_start, + cur_end, + overlap_start, + overlap_end); + } + return true; +} + +static void dist_route_diagnose_failure( + const ds4_dist_coordinator_state *state, + ds4_dist_worker_entry **workers, + uint32_t n, + uint32_t required_start, + uint32_t required_end, + uint32_t missing_layer, + char *err, + size_t errlen) { + if (!state || !err || errlen == 0) return; + + ds4_dist_worker_entry *prev = NULL; + for (uint32_t i = 0; i < n; i++) { + ds4_dist_worker_entry *w = workers[i]; + if (w->layer_end < required_start || w->layer_start > required_end) continue; + if (dist_route_diag_local_overlap(state, w, err, errlen)) return; + if (prev && dist_route_diag_worker_overlap(state, prev, w, err, errlen)) return; + prev = w; + } + + snprintf(err, + errlen, + "distributed %s route incomplete: missing layer %u", + dist_topology_name(state->topology), + missing_layer); +} + static void dist_coordinator_report_plan(ds4_dist_coordinator_state *state) { if (!dist_coordinator_debug_enabled(state)) return; pthread_mutex_lock(&state->mu); @@ -2007,26 +2257,34 @@ static void dist_coordinator_report_plan(ds4_dist_coordinator_state *state) { qsort(workers, n, sizeof(workers[0]), dist_worker_route_cmp); const uint32_t last = state->n_layers - 1u; - bool complete = state->local_start == 0; - bool has_output = state->local_end == last && - (state->local_has_output || state->local_can_output_head); - uint32_t next = 
state->local_end + 1u; - if (state->local_end >= last) next = state->n_layers; + const char *topology_name = dist_topology_name(state->topology); + uint32_t required_start = 0, required_end = 0; + const bool needs_remote = + dist_coordinator_route_span(state, &required_start, &required_end); + const bool final_worker_may_output_logits = + dist_coordinator_route_final_worker_may_output_logits(state); + bool complete = !needs_remote; + bool has_output = false; + uint32_t missing = required_start; uint32_t path_len = 0; - uint32_t missing = next; - if (complete && !has_output) { + if (!needs_remote) { + has_output = state->local_end == last && + (state->local_has_output || state->local_can_output_head); + } else { complete = dist_route_search_workers(state, workers, n, - next, - last, + required_start, + required_end, + final_worker_may_output_logits, path, &path_len, &missing); - if (complete && path_len != 0) { + if (complete && state->topology == DS4_DIST_TOPOLOGY_FORWARD && path_len != 0) { ds4_dist_worker_entry *final = path[path_len - 1u]; has_output = final->has_output || state->local_can_output_head; - next = state->n_layers; + } else if (complete && state->topology == DS4_DIST_TOPOLOGY_REVERSE) { + has_output = state->local_has_output; } } @@ -2035,44 +2293,76 @@ static void dist_coordinator_report_plan(ds4_dist_coordinator_state *state) { char local_end[32]; if (state->local_has_output) snprintf(local_end, sizeof(local_end), "output"); else snprintf(local_end, sizeof(local_end), "%u", state->local_end); - used += (size_t)snprintf(plan + used, used < sizeof(plan) ? sizeof(plan) - used : 0, - "local %u:%s", - state->local_start, - local_end); - for (i = 0; i < path_len; i++) { - ds4_dist_worker_entry *w = path[i]; + if (state->topology == DS4_DIST_TOPOLOGY_FORWARD) { + used += (size_t)snprintf(plan + used, used < sizeof(plan) ? 
sizeof(plan) - used : 0, + "local %u:%s", + state->local_start, + local_end); + for (i = 0; i < path_len; i++) { + ds4_dist_worker_entry *w = path[i]; + if (used < sizeof(plan)) { + char end[32]; + if (w->has_output) snprintf(end, sizeof(end), "output"); + else snprintf(end, sizeof(end), "%u", w->layer_end); + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + " -> %s:%u Q%u %u:%s", + w->peer_host, + w->listen_port, + w->quant_bits, + w->layer_start, + end); + } + } + if (complete && path_len != 0 && + !path[path_len - 1u]->has_output && state->local_can_output_head && + used < sizeof(plan)) { + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + " -> local output"); + } + if (complete && path_len == 0 && + state->local_end == last && !state->local_has_output && + state->local_can_output_head && used < sizeof(plan)) { + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + " -> local output"); + } + } else { + for (i = 0; i < path_len; i++) { + ds4_dist_worker_entry *w = path[i]; + if (used < sizeof(plan)) { + char end[32]; + if (w->has_output) snprintf(end, sizeof(end), "output"); + else snprintf(end, sizeof(end), "%u", w->layer_end); + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + "%s%s:%u Q%u %u:%s", + i == 0 ? 
"" : " -> ", + w->peer_host, + w->listen_port, + w->quant_bits, + w->layer_start, + end); + } + } if (used < sizeof(plan)) { - char end[32]; - if (w->has_output) snprintf(end, sizeof(end), "output"); - else snprintf(end, sizeof(end), "%u", w->layer_end); used += (size_t)snprintf(plan + used, sizeof(plan) - used, - " -> %s:%u Q%u %u:%s", - w->peer_host, - w->listen_port, - w->quant_bits, - w->layer_start, - end); - } - } - if (complete && path_len != 0 && - !path[path_len - 1u]->has_output && state->local_can_output_head && - used < sizeof(plan)) { - used += (size_t)snprintf(plan + used, sizeof(plan) - used, - " -> local output"); - } - if (complete && path_len == 0 && - state->local_end == last && !state->local_has_output && - state->local_can_output_head && used < sizeof(plan)) { - used += (size_t)snprintf(plan + used, sizeof(plan) - used, - " -> local output"); - } - complete = complete && has_output && next == state->n_layers; + "%slocal %u:%s", + path_len != 0 ? " -> " : "", + state->local_start, + local_end); + } + } + complete = complete && has_output; pthread_mutex_unlock(&state->mu); if (complete) { - fprintf(stderr, "ds4: distributed coordinator: complete route ready: %s\n", plan); + fprintf(stderr, + "ds4: distributed coordinator: complete %s route ready: %s\n", + topology_name, + plan); } else { - fprintf(stderr, "ds4: distributed coordinator: route incomplete; next needed layer %u\n", missing); + fprintf(stderr, + "ds4: distributed coordinator: %s route incomplete; missing layer %u\n", + topology_name, + missing); } free(path); free(workers); @@ -2224,37 +2514,51 @@ static bool dist_coordinator_build_route_plan( qsort(workers, n, sizeof(workers[0]), dist_worker_route_cmp); const uint32_t last = state->n_layers - 1u; - if (state->local_start != 0) { + uint32_t required_start = 0, required_end = 0; + const bool needs_remote = + dist_coordinator_route_span(state, &required_start, &required_end); + const bool final_worker_may_output_logits = + 
dist_coordinator_route_final_worker_may_output_logits(state); + if (!needs_remote && + state->local_end == last && + (state->local_has_output || state->local_can_output_head)) { + if (generation) *generation = state->generation; pthread_mutex_unlock(&state->mu); free(workers); free(path); - if (errlen) snprintf(err, errlen, "coordinator route does not start at layer 0"); - return false; + return true; } - if (state->local_end == last && - (state->local_has_output || state->local_can_output_head)) { - if (generation) *generation = state->generation; + if (!needs_remote) { + const char *topology_name = dist_topology_name(state->topology); pthread_mutex_unlock(&state->mu); free(workers); free(path); - return true; + if (errlen) snprintf(err, errlen, "distributed %s route has no output owner", topology_name); + return false; } - uint32_t next = state->local_end + 1u; uint32_t path_len = 0; - uint32_t missing = next; + uint32_t missing = required_start; if (!dist_route_search_workers(state, workers, n, - next, - last, + required_start, + required_end, + final_worker_may_output_logits, path, &path_len, &missing)) { + dist_route_diagnose_failure(state, /* run while state->mu is still held: this reads registry-owned worker entries */ + workers, + n, + required_start, + required_end, + missing, + err, + errlen); pthread_mutex_unlock(&state->mu); free(workers); free(path); - if (errlen) snprintf(err, errlen, "distributed route incomplete: missing layer %u", missing); return false; } @@ -2267,7 +2571,11 @@ static bool dist_coordinator_build_route_plan( entry.port = w->listen_port; entry.layer_start = w->layer_start; entry.layer_end = w->layer_end; - entry.flags = w->has_output ? DS4_DIST_ROUTE_F_OUTPUT_LOGITS : 0u; + entry.flags = (i + 1u == path_len && + final_worker_may_output_logits && + w->has_output) + ? 
DS4_DIST_ROUTE_F_OUTPUT_LOGITS + : 0u; if (state->use_control_for_work && plan->count == 0) { entry.fd = dup(w->fd); if (entry.fd < 0) { @@ -2524,14 +2832,20 @@ static int dist_coordinator_send_remote_work_on_fd( work.n_tokens = n_tokens; work.layer_start = first->layer_start; work.layer_end = first->layer_end; - work.flags = DS4_DIST_WORK_F_INPUT_HC; + const bool input_hc_present = hidden_hc != NULL || hidden_hc_bytes != 0u; + if (input_hc_present && (!hidden_hc || hidden_hc_bytes == 0u)) { + if (errlen) snprintf(err, errlen, "distributed input hidden-state metadata mismatch"); + return 1; + } + work.flags = input_hc_present ? DS4_DIST_WORK_F_INPUT_HC : 0u; if (reset_session) work.flags |= DS4_DIST_WORK_F_RESET_SESSION; if (ack_only) work.flags |= DS4_DIST_WORK_F_ACK_ONLY; if ((first->flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) != 0) { work.flags |= DS4_DIST_WORK_F_OUTPUT_LOGITS; } uint32_t wire_hidden_hc_bytes = 0; - if (!dist_activation_wire_bytes_from_f32_bytes(state->activation_bits, + if (input_hc_present && + !dist_activation_wire_bytes_from_f32_bytes(state->activation_bits, hidden_hc_bytes, &wire_hidden_hc_bytes)) { if (errlen) snprintf(err, errlen, "invalid distributed hidden-state size"); @@ -2539,7 +2853,7 @@ static int dist_coordinator_send_remote_work_on_fd( } work.token_bytes = n_tokens * sizeof(uint32_t); work.input_hc_bytes = wire_hidden_hc_bytes; - work.input_hc_bits = state->activation_bits; + work.input_hc_bits = input_hc_present ? 
state->activation_bits : 0u; work.route_count = plan->count; work.route_index = 0; work.route_bytes = plan->blob_bytes; @@ -2551,9 +2865,8 @@ static int dist_coordinator_send_remote_work_on_fd( return 0; } -static int dist_coordinator_eval_remote_on_fd( +static int dist_coordinator_request_remote_on_fd( ds4_dist_coordinator_state *state, - ds4_session *session, const ds4_dist_route_plan *plan, int fd, const int *tokens, @@ -2564,9 +2877,12 @@ static int dist_coordinator_eval_remote_on_fd( uint64_t prefix_hash, uint64_t expected_result_hash, bool reset_session, + bool ack_only, const float *hidden_hc, uint32_t hidden_hc_bytes, - float *logits, + uint32_t *kind, + void **payload, + uint32_t *payload_bytes, char *err, size_t errlen) { const bool profile = dist_decode_profile_enabled() && n_tokens == 1; @@ -2583,24 +2899,24 @@ static int dist_coordinator_eval_remote_on_fd( prefix_hash, expected_result_hash, reset_session, - false, + ack_only, hidden_hc, hidden_hc_bytes, err, errlen); const double send_t1 = profile ? dist_now_sec() : 0.0; - uint32_t kind = 0, payload_bytes = 0; + uint32_t result_kind = 0, result_payload_bytes = 0; uint64_t result_hash = 0; - void *payload = NULL; + void *result_payload = NULL; if (rc == 0) { const double recv_t0 = profile ? dist_now_sec() : 0.0; rc = dist_recv_result_alloc(fd, state, request_id, - &kind, + &result_kind, &result_hash, - &payload, - &payload_bytes, + &result_payload, + &result_payload_bytes, err, errlen); const double recv_t1 = profile ? 
dist_now_sec() : 0.0; @@ -2611,16 +2927,72 @@ static int dist_coordinator_eval_remote_on_fd( pos0, (send_t1 - send_t0) * 1000.0, (recv_t1 - recv_t0) * 1000.0, - kind, - (double)payload_bytes / (1024.0 * 1024.0)); + result_kind, + (double)result_payload_bytes / (1024.0 * 1024.0)); } } if (rc != 0) return rc; if (result_hash != expected_result_hash) { - free(payload); + free(result_payload); if (errlen) snprintf(err, errlen, "distributed result prefix hash mismatch"); return 1; } + if (kind) *kind = result_kind; + if (payload) *payload = result_payload; + else free(result_payload); + if (payload_bytes) *payload_bytes = result_payload_bytes; + if (profile) { + const double total_t1 = dist_now_sec(); + fprintf(stderr, + "ds4: dist decode profile: remote request=%llu total=%.3fms\n", + (unsigned long long)request_id, + (total_t1 - total_t0) * 1000.0); + } + return 0; +} + +static int dist_coordinator_eval_remote_on_fd( + ds4_dist_coordinator_state *state, + ds4_session *session, + const ds4_dist_route_plan *plan, + int fd, + const int *tokens, + uint32_t n_tokens, + uint32_t pos0, + uint64_t session_id, + uint64_t request_id, + uint64_t prefix_hash, + uint64_t expected_result_hash, + bool reset_session, + const float *hidden_hc, + uint32_t hidden_hc_bytes, + float *logits, + char *err, + size_t errlen) { + const bool profile = dist_decode_profile_enabled() && n_tokens == 1; + const double total_t0 = profile ? 
dist_now_sec() : 0.0; + uint32_t kind = 0, payload_bytes = 0; + void *payload = NULL; + int rc = dist_coordinator_request_remote_on_fd(state, + plan, + fd, + tokens, + n_tokens, + pos0, + session_id, + request_id, + prefix_hash, + expected_result_hash, + reset_session, + false, + hidden_hc, + hidden_hc_bytes, + &kind, + &payload, + &payload_bytes, + err, + errlen); + if (rc != 0) return rc; const uint32_t logits_bytes = (uint32_t)((uint64_t)ds4_engine_vocab_size(state->engine) * sizeof(float)); if (kind == DS4_DIST_RESULT_LOGITS && payload_bytes == logits_bytes) { @@ -2667,6 +3039,47 @@ static int dist_coordinator_eval_remote_on_fd( return 1; } +static int dist_coordinator_eval_local_suffix( + ds4_dist_coordinator_state *state, + ds4_session *session, + const int *tokens, + uint32_t n_tokens, + uint32_t pos0, + bool reset_session, + const float *input_hc, + uint32_t input_hc_bytes, + bool output_logits, + float *logits, + char *err, + size_t errlen) { + const uint64_t hc_values = ds4_engine_hidden_f32_values(state->engine); + const uint64_t expected_bytes64 = (uint64_t)n_tokens * hc_values * sizeof(float); + if (expected_bytes64 > UINT32_MAX) { + if (errlen) snprintf(err, errlen, "distributed coordinator hidden-state chunk is too large"); + return 1; + } + if (!input_hc || input_hc_bytes != (uint32_t)expected_bytes64) { + if (errlen) snprintf(err, errlen, "distributed route returned invalid hidden-state size"); + return 1; + } + if (reset_session && + ds4_session_layer_slice_reset(session, err, errlen) != 0) { + return 1; + } + return ds4_session_eval_layer_slice(session, + tokens, + n_tokens, + pos0, + state->local_start, + state->local_end, + input_hc, + NULL, + output_logits, + logits, + err, + errlen); +} + static int dist_coordinator_eval_span( ds4_dist_coordinator_state *state, ds4_session *session, @@ -2703,34 +3116,88 @@ static int dist_coordinator_eval_span( } const uint64_t result_hash = dist_token_hash_update_span(prefix_hash, tokens, n_tokens); const 
uint32_t hidden_bytes = (uint32_t)hidden_bytes64; - float *hidden = NULL; - if (plan->count != 0) { - hidden = malloc(hidden_bytes); - if (!hidden) { - if (errlen) snprintf(err, errlen, "out of memory allocating coordinator hidden-state"); - return 1; - } - } - if (reset_session && - ds4_session_layer_slice_reset(session, err, errlen) != 0) { - free(hidden); - return 1; - } - - const bool local_logits = plan->count == 0; int remote_fd = -1; if (plan->count != 0) { const ds4_dist_route_entry *first = &plan->entry[0]; remote_fd = first->fd; if (remote_fd < 0) { if (errlen) snprintf(err, errlen, "distributed route has no live first-hop connection"); - free(hidden); return 1; } } - const double local_t0 = profile ? dist_now_sec() : 0.0; - int rc = ds4_session_eval_layer_slice(session, + int rc = 0; + double local_t0 = 0.0, local_t1 = 0.0; + double remote_t0 = 0.0, remote_t1 = 0.0; + if (state->topology == DS4_DIST_TOPOLOGY_REVERSE) { + if (plan->count == 0) { + if (errlen) snprintf(err, errlen, "reverse distributed route has no remote prefix worker"); + return 1; + } + uint32_t kind = 0, payload_bytes = 0; + void *payload = NULL; + remote_t0 = profile ? dist_now_sec() : 0.0; + rc = dist_coordinator_request_remote_on_fd(state, + plan, + remote_fd, + tokens, + n_tokens, + pos0, + session_id, + request_id, + prefix_hash, + result_hash, + reset_session, + false, + NULL, + 0u, + &kind, + &payload, + &payload_bytes, + err, + errlen); + remote_t1 = profile ? dist_now_sec() : 0.0; + if (rc == 0) { + if (kind != DS4_DIST_RESULT_HIDDEN_STATE) { + free(payload); + if (errlen) snprintf(err, errlen, "reverse distributed decode requires final hidden-state"); + rc = 1; + } else { + local_t0 = profile ? dist_now_sec() : 0.0; + rc = dist_coordinator_eval_local_suffix(state, + session, + tokens, + n_tokens, + pos0, + reset_session, + payload, + payload_bytes, + true, + logits, + err, + errlen); + local_t1 = profile ? 
dist_now_sec() : 0.0; + free(payload); + } + } + } else { + float *hidden = NULL; + if (plan->count != 0) { + hidden = malloc(hidden_bytes); + if (!hidden) { + if (errlen) snprintf(err, errlen, "out of memory allocating coordinator hidden-state"); + return 1; + } + } + if (reset_session && + ds4_session_layer_slice_reset(session, err, errlen) != 0) { + free(hidden); + return 1; + } + + const bool local_logits = plan->count == 0; + local_t0 = profile ? dist_now_sec() : 0.0; + rc = ds4_session_eval_layer_slice(session, tokens, n_tokens, pos0, @@ -2742,28 +3209,29 @@ static int dist_coordinator_eval_span( local_logits ? logits : NULL, err, errlen); - const double local_t1 = profile ? dist_now_sec() : 0.0; - double remote_t0 = 0.0, remote_t1 = 0.0; - if (rc == 0 && plan->count != 0) { - remote_t0 = profile ? dist_now_sec() : 0.0; - rc = dist_coordinator_eval_remote_on_fd(state, - session, - plan, - remote_fd, - tokens, - n_tokens, - pos0, - session_id, - request_id, - prefix_hash, - result_hash, - reset_session, - hidden, - hidden_bytes, - logits, - err, - errlen); - remote_t1 = profile ? dist_now_sec() : 0.0; + local_t1 = profile ? dist_now_sec() : 0.0; + if (rc == 0 && plan->count != 0) { + remote_t0 = profile ? dist_now_sec() : 0.0; + rc = dist_coordinator_eval_remote_on_fd(state, + session, + plan, + remote_fd, + tokens, + n_tokens, + pos0, + session_id, + request_id, + prefix_hash, + result_hash, + reset_session, + hidden, + hidden_bytes, + logits, + err, + errlen); + remote_t1 = profile ? dist_now_sec() : 0.0; + } + free(hidden); } if (profile) { const double span_t1 = dist_now_sec(); @@ -2779,7 +3247,6 @@ static int dist_coordinator_eval_span( (double)hidden_bytes / (1024.0 * 1024.0), rc); } - free(hidden); return rc; } @@ -3231,7 +3698,7 @@ static void *dist_prefill_sender_main(void *arg) { slot->result_hash, slot->reset_session, slot->ack_only, - slot->hidden, + slot->hidden_bytes ? 
slot->hidden : NULL, slot->hidden_bytes, send_err, sizeof(send_err)); @@ -3276,6 +3743,15 @@ static void dist_prefill_reader_signal_progress( pthread_mutex_unlock(&reader->progress_mu); } +static void dist_prefill_reader_note_received( + ds4_dist_prefill_result_reader *reader, + uint32_t received) { + pthread_mutex_lock(&reader->progress_mu); + if (received > reader->progress_received) reader->progress_received = received; + pthread_cond_broadcast(&reader->progress_cv); + pthread_mutex_unlock(&reader->progress_mu); +} + static void dist_prefill_reader_emit_progress( ds4_dist_prefill_result_reader *reader, uint32_t *reported) { @@ -3326,7 +3802,7 @@ static bool dist_prefill_reader_wait_flow_window( for (;;) { pthread_mutex_lock(&reader->progress_mu); - const uint32_t completed = reader->progress_completed; + const uint32_t completed = reader->progress_received; const bool done = reader->progress_done; const bool has_room = submitted < completed + window; if (done || has_room) { @@ -3334,10 +3810,146 @@ static bool dist_prefill_reader_wait_flow_window( dist_prefill_reader_emit_progress(reader, reported); return !done && has_room; } - pthread_cond_wait(&reader->progress_cv, &reader->progress_mu); - pthread_mutex_unlock(&reader->progress_mu); - dist_prefill_reader_emit_progress(reader, reported); + pthread_cond_wait(&reader->progress_cv, &reader->progress_mu); + pthread_mutex_unlock(&reader->progress_mu); + dist_prefill_reader_emit_progress(reader, reported); + } +} + +static int dist_prefill_reverse_queue_init( + ds4_dist_prefill_result_reader *reader, + uint32_t slot_count, + char *err, + size_t errlen) { + if (!reader) return 1; + if (slot_count == 0) slot_count = 1; + reader->reverse_slots = calloc(slot_count, sizeof(reader->reverse_slots[0])); + if (!reader->reverse_slots) { + if (errlen) snprintf(err, errlen, "out of memory allocating reverse prefill queue"); + return 1; + } + reader->reverse_slot_count = slot_count; + pthread_mutex_init(&reader->reverse_mu, 
NULL); + pthread_cond_init(&reader->reverse_can_enqueue, NULL); + pthread_cond_init(&reader->reverse_can_dequeue, NULL); + return 0; +} + +static void dist_prefill_reverse_queue_destroy(ds4_dist_prefill_result_reader *reader) { + if (!reader) return; + if (reader->reverse_slots) { + for (uint32_t i = 0; i < reader->reverse_slot_count; i++) { + free(reader->reverse_slots[i].payload); + } + free(reader->reverse_slots); + reader->reverse_slots = NULL; + } + if (reader->reverse_slot_count != 0) { + pthread_cond_destroy(&reader->reverse_can_dequeue); + pthread_cond_destroy(&reader->reverse_can_enqueue); + pthread_mutex_destroy(&reader->reverse_mu); + reader->reverse_slot_count = 0; + } +} + +static int dist_prefill_reverse_enqueue( + ds4_dist_prefill_result_reader *reader, + uint32_t chunk_index, + uint32_t pos, + uint32_t n_tokens, + bool reset_session, + bool output_logits, + void *payload, + uint32_t payload_bytes, + char *err, + size_t errlen) { + pthread_mutex_lock(&reader->reverse_mu); + while (reader->reverse_queued == reader->reverse_slot_count && reader->rc == 0) { + pthread_cond_wait(&reader->reverse_can_enqueue, &reader->reverse_mu); + } + if (reader->rc != 0) { + if (errlen) snprintf(err, errlen, "%s", + reader->err[0] ? 
reader->err : "reverse prefill apply queue stopped"); + pthread_mutex_unlock(&reader->reverse_mu); + return 1; + } + ds4_dist_prefill_reverse_slot *slot = &reader->reverse_slots[reader->reverse_tail]; + slot->chunk_index = chunk_index; + slot->pos = pos; + slot->n_tokens = n_tokens; + slot->payload_bytes = payload_bytes; + slot->reset_session = reset_session; + slot->output_logits = output_logits; + slot->payload = payload; + reader->reverse_tail = (reader->reverse_tail + 1u) % reader->reverse_slot_count; + reader->reverse_queued++; + pthread_cond_signal(&reader->reverse_can_dequeue); + pthread_mutex_unlock(&reader->reverse_mu); + return 0; +} + +static void dist_prefill_reverse_finish(ds4_dist_prefill_result_reader *reader) { + pthread_mutex_lock(&reader->reverse_mu); + reader->reverse_producer_done = true; + pthread_cond_broadcast(&reader->reverse_can_dequeue); + pthread_mutex_unlock(&reader->reverse_mu); +} + +static void dist_prefill_reverse_cancel(ds4_dist_prefill_result_reader *reader) { + pthread_mutex_lock(&reader->reverse_mu); + reader->reverse_producer_done = true; + pthread_cond_broadcast(&reader->reverse_can_enqueue); + pthread_cond_broadcast(&reader->reverse_can_dequeue); + pthread_mutex_unlock(&reader->reverse_mu); +} + +static void *dist_prefill_reverse_apply_main(void *arg) { + ds4_dist_prefill_result_reader *reader = arg; + for (;;) { + pthread_mutex_lock(&reader->reverse_mu); + while (reader->reverse_queued == 0 && !reader->reverse_producer_done && reader->rc == 0) { + pthread_cond_wait(&reader->reverse_can_dequeue, &reader->reverse_mu); + } + if (reader->rc != 0 || (reader->reverse_queued == 0 && reader->reverse_producer_done)) { + pthread_mutex_unlock(&reader->reverse_mu); + break; + } + ds4_dist_prefill_reverse_slot slot = reader->reverse_slots[reader->reverse_head]; + memset(&reader->reverse_slots[reader->reverse_head], 0, sizeof(reader->reverse_slots[reader->reverse_head])); + reader->reverse_head = (reader->reverse_head + 1u) % 
reader->reverse_slot_count; + reader->reverse_queued--; + pthread_cond_signal(&reader->reverse_can_enqueue); + pthread_mutex_unlock(&reader->reverse_mu); + + const double local_t0 = dist_now_sec(); + int local_rc = dist_coordinator_eval_local_suffix(reader->state, + reader->session, + reader->prompt->v + slot.pos, + slot.n_tokens, + reader->progress_base + slot.pos, + slot.reset_session, + slot.payload, + slot.payload_bytes, + slot.output_logits, + slot.output_logits ? reader->logits : NULL, + reader->err, + sizeof(reader->err)); + const double local_t1 = dist_now_sec(); + reader->local_eval_sec += local_t1 - local_t0; + free(slot.payload); + if (local_rc != 0) { + reader->rc = local_rc; + shutdown(reader->fd, SHUT_RDWR); + dist_prefill_reverse_cancel(reader); + dist_prefill_reader_signal_progress(reader, slot.chunk_index, true); + return NULL; + } + dist_prefill_reader_signal_progress(reader, slot.chunk_index + 1u, false); } + if (reader->rc == 0) { + dist_prefill_reader_signal_progress(reader, reader->count, true); + } + return NULL; } static void *dist_prefill_result_reader_main(void *arg) { @@ -3347,6 +3959,20 @@ static void *dist_prefill_result_reader_main(void *arg) { reader->final_kind = 0; reader->final_payload = NULL; reader->final_payload_bytes = 0; + reader->local_eval_sec = 0.0; + reader->progress_received = 0; + + pthread_t reverse_tid; + bool reverse_started = false; + if (reader->reverse_apply_local_suffix) { + if (pthread_create(&reverse_tid, NULL, dist_prefill_reverse_apply_main, reader) != 0) { + reader->rc = 1; + snprintf(reader->err, sizeof(reader->err), "failed to start reverse prefill apply worker"); + dist_prefill_reader_signal_progress(reader, 0, true); + return NULL; + } + reverse_started = true; + } const uint32_t logits_bytes = (uint32_t)((uint64_t)ds4_engine_vocab_size(reader->state->engine) * sizeof(float)); @@ -3387,10 +4013,13 @@ static void *dist_prefill_result_reader_main(void *arg) { const uint32_t chunk = remaining < 
reader->chunk_cap ? remaining : reader->chunk_cap; const uint64_t hidden_bytes64 = (uint64_t)chunk * reader->hc_values * sizeof(float); const bool final_chunk = i + 1u == reader->count; - const bool valid_ack = !final_chunk && + const bool valid_ack = !reader->reverse_apply_local_suffix && + !final_chunk && kind == DS4_DIST_RESULT_ACK && payload_bytes == 0; - const bool valid_logits = kind == DS4_DIST_RESULT_LOGITS && payload_bytes == logits_bytes; + const bool valid_logits = !reader->reverse_apply_local_suffix && + kind == DS4_DIST_RESULT_LOGITS && + payload_bytes == logits_bytes; const bool valid_hidden = reader->allow_hidden && hidden_bytes64 <= UINT32_MAX && kind == DS4_DIST_RESULT_HIDDEN_STATE && @@ -3405,14 +4034,47 @@ static void *dist_prefill_result_reader_main(void *arg) { dist_prefill_reader_signal_progress(reader, i, true); return NULL; } - if (final_chunk) { + if (reader->reverse_apply_local_suffix) { + const bool reset_session = reader->reset_first_chunk && i == 0; + const bool output_logits = final_chunk; + if (dist_prefill_reverse_enqueue(reader, + i, + pos0, + chunk, + reset_session, + output_logits, + payload, + payload_bytes, + reader->err, + sizeof(reader->err)) != 0) { + reader->rc = 1; + free(payload); + shutdown(reader->fd, SHUT_RDWR); + dist_prefill_reader_signal_progress(reader, i, true); + break; + } + payload = NULL; + dist_prefill_reader_note_received(reader, i + 1u); + } else if (final_chunk) { reader->final_kind = kind; reader->final_payload = payload; reader->final_payload_bytes = payload_bytes; payload = NULL; } free(payload); - dist_prefill_reader_signal_progress(reader, i + 1u, final_chunk); + if (!reader->reverse_apply_local_suffix) { + dist_prefill_reader_note_received(reader, i + 1u); + dist_prefill_reader_signal_progress(reader, i + 1u, final_chunk); + } + } + if (reader->reverse_apply_local_suffix) { + dist_prefill_reverse_finish(reader); + if (reverse_started) pthread_join(reverse_tid, NULL); + if (reader->rc != 0) { + 
dist_prefill_reverse_cancel(reader); + dist_prefill_reader_signal_progress(reader, reader->progress_completed, true); + } + return NULL; } dist_prefill_reader_signal_progress(reader, reader->count, true); return NULL; @@ -3430,6 +4092,9 @@ static bool dist_coordinator_can_pipeline_prefill( if (chunk_cap == 0 || n_tokens <= chunk_cap) return false; if (plan->count == 0) return false; if (plan->entry[0].fd < 0) return false; + if (state->topology == DS4_DIST_TOPOLOGY_REVERSE) { + return state->local_can_output_head; + } const ds4_dist_route_entry *final = &plan->entry[plan->count - 1u]; if ((final->flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) == 0) { return final->layer_end + 1u == state->n_layers && @@ -3458,7 +4123,20 @@ static int dist_coordinator_prefill_chunk_cap( return 1; } } - if (requested == 0) requested = prefill_cap; + if (requested == 0) { + requested = prefill_cap; + if (state && + state->topology == DS4_DIST_TOPOLOGY_REVERSE && + requested > 2048u) { + /* Reverse prefill has two serialized GPU stages: worker prefix and + * coordinator suffix. Extremely large chunks leave too much + * fill/drain overhead on the table, while tiny chunks drown in + * per-chunk launch cost. Keep the CLI/env override, but use a + * smaller default chunk on reverse routes so the pipeline has + * enough chunks to overlap. 
*/ + requested = 2048u; + } + } if (requested > prefill_cap) { if (errlen) { snprintf(err, @@ -3569,8 +4247,11 @@ static int dist_coordinator_prefill_prompt_pipelined( ds4_dist_prefill_result_reader reader; memset(&reader, 0, sizeof(reader)); reader.state = state; + reader.session = session; + reader.prompt = prompt; reader.fd = plan->entry[0].fd; reader.progress_session = session; + reader.logits = logits; reader.first_request_id = *request_id; reader.count = chunk_count; reader.total_tokens = total; @@ -3578,12 +4259,22 @@ static int dist_coordinator_prefill_prompt_pipelined( reader.progress_base = span_start; reader.progress_total = (uint32_t)prompt->len; reader.hc_values = hc_values; - reader.allow_hidden = + reader.reverse_apply_local_suffix = state->topology == DS4_DIST_TOPOLOGY_REVERSE; + reader.reset_first_chunk = reset_first_chunk; + reader.allow_hidden = reader.reverse_apply_local_suffix || (plan->entry[plan->count - 1u].flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) == 0; pthread_mutex_init(&reader.progress_mu, NULL); pthread_cond_init(&reader.progress_cv, NULL); + if (reader.reverse_apply_local_suffix && + dist_prefill_reverse_queue_init(&reader, flow_window, err, errlen) != 0) { + pthread_cond_destroy(&reader.progress_cv); + pthread_mutex_destroy(&reader.progress_mu); + dist_prefill_sender_destroy(&sender); + return 1; + } reader.expected_hashes = calloc(chunk_count, sizeof(reader.expected_hashes[0])); if (!reader.expected_hashes) { + dist_prefill_reverse_queue_destroy(&reader); pthread_cond_destroy(&reader.progress_cv); pthread_mutex_destroy(&reader.progress_mu); dist_prefill_sender_destroy(&sender); @@ -3604,6 +4295,7 @@ static int dist_coordinator_prefill_prompt_pipelined( pthread_t reader_tid; if (pthread_create(&reader_tid, NULL, dist_prefill_result_reader_main, &reader) != 0) { free(reader.expected_hashes); + dist_prefill_reverse_queue_destroy(&reader); pthread_cond_destroy(&reader.progress_cv); pthread_mutex_destroy(&reader.progress_mu); 
dist_prefill_sender_destroy(&sender); @@ -3615,6 +4307,7 @@ static int dist_coordinator_prefill_prompt_pipelined( dist_prefill_sender_cancel(&sender); pthread_join(reader_tid, NULL); free(reader.expected_hashes); + dist_prefill_reverse_queue_destroy(&reader); pthread_cond_destroy(&reader.progress_cv); pthread_mutex_destroy(&reader.progress_mu); dist_prefill_sender_destroy(&sender); @@ -3634,6 +4327,7 @@ static int dist_coordinator_prefill_prompt_pipelined( flow_window); int rc = 0; + const bool reverse_topology = state->topology == DS4_DIST_TOPOLOGY_REVERSE; double local_eval_sec = 0.0; const double pipeline_t0 = dist_now_sec(); uint32_t pos = span_start; @@ -3659,37 +4353,40 @@ static int dist_coordinator_prefill_prompt_pipelined( rc = 1; break; } - if (pos == span_start && - reset_first_chunk && - ds4_session_layer_slice_reset(session, err, errlen) != 0) { - rc = 1; - break; + if (!reverse_topology) { + if (pos == span_start && + reset_first_chunk && + ds4_session_layer_slice_reset(session, err, errlen) != 0) { + rc = 1; + break; + } + const double local_t0 = dist_now_sec(); + rc = ds4_session_eval_layer_slice(session, + prompt->v + pos, + chunk, + pos, + state->local_start, + state->local_end, + NULL, + slot->hidden, + false, + NULL, + err, + errlen); + const double local_t1 = dist_now_sec(); + local_eval_sec += local_t1 - local_t0; + if (rc != 0) break; } - const double local_t0 = dist_now_sec(); - rc = ds4_session_eval_layer_slice(session, - prompt->v + pos, - chunk, - pos, - state->local_start, - state->local_end, - NULL, - slot->hidden, - false, - NULL, - err, - errlen); - const double local_t1 = dist_now_sec(); - local_eval_sec += local_t1 - local_t0; - if (rc != 0) break; slot->pos = pos; slot->n_tokens = chunk; - slot->hidden_bytes = hidden_bytes; + slot->hidden_bytes = reverse_topology ? 
0u : hidden_bytes; slot->request_id = *request_id; slot->prefix_hash = next_prefix_hash; slot->result_hash = reader.expected_hashes[submitted_chunks]; slot->reset_session = reset_first_chunk && pos == span_start; - slot->ack_only = !getenv("DS4_DIST_DISABLE_PREFILL_ACK_ONLY") && + slot->ack_only = !reverse_topology && + !getenv("DS4_DIST_DISABLE_PREFILL_ACK_ONLY") && pos + chunk < span_end; rc = dist_prefill_sender_enqueue_slot(&sender, err, errlen); if (rc != 0) break; @@ -3722,12 +4419,13 @@ static int dist_coordinator_prefill_prompt_pipelined( if (rc == 0 && reader.rc == 0) { const double total_sec = pipeline_t1 - pipeline_t0; DIST_COORD_DEBUG(state, - "ds4: distributed coordinator: pipelined prefill done tokens=%u chunks=%u total=%.3fs %.2f t/s local=%.3fs send=%.3fs %.2f MiB/s\n", + "ds4: distributed coordinator: pipelined prefill done topology=%s tokens=%u chunks=%u total=%.3fs %.2f t/s local=%.3fs send=%.3fs %.2f MiB/s\n", + dist_topology_name(state->topology), total, chunk_count, total_sec, total_sec > 0.0 ? (double)total / total_sec : 0.0, - local_eval_sec, + local_eval_sec + reader.local_eval_sec, sender.send_sec, sender.send_sec > 0.0 ? 
((double)sender.send_bytes / (1024.0 * 1024.0)) / sender.send_sec @@ -3738,6 +4436,7 @@ static int dist_coordinator_prefill_prompt_pipelined( int reader_rc = reader.rc; free(reader.final_payload); free(reader.expected_hashes); + dist_prefill_reverse_queue_destroy(&reader); dist_prefill_sender_destroy(&sender); pthread_cond_destroy(&reader.progress_cv); pthread_mutex_destroy(&reader.progress_mu); @@ -3745,12 +4444,17 @@ static int dist_coordinator_prefill_prompt_pipelined( } dist_prefill_sender_destroy(&sender); free(reader.expected_hashes); + dist_prefill_reverse_queue_destroy(&reader); pthread_cond_destroy(&reader.progress_cv); pthread_mutex_destroy(&reader.progress_mu); if (rc != 0) { free(reader.final_payload); return 1; } + if (reader.reverse_apply_local_suffix) { + free(reader.final_payload); + return 0; + } const uint32_t logits_bytes = (uint32_t)((uint64_t)ds4_engine_vocab_size(state->engine) * sizeof(float)); if (reader.final_kind == DS4_DIST_RESULT_LOGITS && @@ -4853,57 +5557,309 @@ static int dist_kv_write_layer_header( return 0; } -static uint32_t dist_kv_route_shard_count(const ds4_dist_session *d) { +static uint32_t dist_kv_route_owner_count(const ds4_dist_session *d) { return d ? 
1u + d->plan.count : 0u; } -static void dist_kv_route_shard( - const ds4_dist_session *d, - uint32_t shard, - uint32_t *layer_start, - uint32_t *layer_end, - const ds4_dist_route_entry **entry) { - if (entry) *entry = NULL; - if (shard == 0) { - if (layer_start) *layer_start = d->state.local_start; - if (layer_end) *layer_end = d->state.local_end; - return; - } - const ds4_dist_route_entry *e = &d->plan.entry[shard - 1u]; - if (layer_start) *layer_start = e->layer_start; - if (layer_end) *layer_end = e->layer_end; - if (entry) *entry = e; +static int dist_kv_route_owner_cmp(const void *ap, const void *bp) { + const ds4_dist_kv_route_owner *a = ap; + const ds4_dist_kv_route_owner *b = bp; + if (a->layer_start < b->layer_start) return -1; + if (a->layer_start > b->layer_start) return 1; + if (a->layer_end < b->layer_end) return -1; + if (a->layer_end > b->layer_end) return 1; + if (a->is_local != b->is_local) return a->is_local ? -1 : 1; + return 0; } -static int dist_kv_route_validate( +static int dist_kv_route_build_owners( const ds4_dist_session *d, + ds4_dist_kv_route_owner *owners, + uint32_t owner_count, char *err, size_t errlen) { - if (!d || d->state.n_layers == 0 || - d->state.local_start != 0 || - d->state.local_end >= d->state.n_layers) { - if (errlen) snprintf(err, errlen, "distributed KV route does not start at layer 0"); + if (!d || !owners || d->state.n_layers == 0 || + d->state.local_start > d->state.local_end || + d->state.local_end >= d->state.n_layers || + owner_count != dist_kv_route_owner_count(d)) { + if (errlen) snprintf(err, errlen, "invalid distributed KV route"); return 1; } - uint32_t prev = d->state.local_end; + + owners[0].layer_start = d->state.local_start; + owners[0].layer_end = d->state.local_end; + owners[0].entry = NULL; + owners[0].is_local = true; for (uint32_t i = 0; i < d->plan.count; i++) { - const ds4_dist_route_entry *e = &d->plan.entry[i]; - if (prev == UINT32_MAX || - e->layer_start != prev + 1u || - e->layer_end < 
e->layer_start || - e->layer_end >= d->state.n_layers) { + owners[i + 1u].layer_start = d->plan.entry[i].layer_start; + owners[i + 1u].layer_end = d->plan.entry[i].layer_end; + owners[i + 1u].entry = &d->plan.entry[i]; + owners[i + 1u].is_local = false; + } + qsort(owners, owner_count, sizeof(owners[0]), dist_kv_route_owner_cmp); + + uint32_t prev_end = UINT32_MAX; + for (uint32_t i = 0; i < owner_count; i++) { + const ds4_dist_kv_route_owner *owner = &owners[i]; + if (owner->layer_end < owner->layer_start || + owner->layer_end >= d->state.n_layers) { + if (errlen) snprintf(err, errlen, "distributed KV route has invalid layer range"); + return 1; + } + if (i == 0) { + if (owner->layer_start != 0) { + if (errlen) snprintf(err, errlen, "distributed KV route does not cover layer 0"); + return 1; + } + } else if (owner->layer_start != prev_end + 1u) { if (errlen) snprintf(err, errlen, "distributed KV route is not contiguous"); return 1; } - prev = e->layer_end; + prev_end = owner->layer_end; } - if (prev + 1u != d->state.n_layers) { + if (prev_end + 1u != d->state.n_layers) { if (errlen) snprintf(err, errlen, "distributed KV route does not cover all layers"); return 1; } return 0; } +static bool dist_session_supports_local_decode(const ds4_dist_session *d) { + return d && + d->state.local_decode_requested && + d->state.topology == DS4_DIST_TOPOLOGY_REVERSE && + d->state.local_has_output && + d->state.local_start > 0u; +} + +static int dist_session_activate_local_decode( + ds4_dist_session *d, + ds4_session *owner, + const ds4_tokens *tokens, + char *err, + size_t errlen) { + if (!d || !owner || !tokens || tokens->len <= 0) { + if (errlen) snprintf(err, errlen, "invalid local decode activation request"); + return 1; + } + if (d->local_decode_active) return 0; + if (!dist_session_supports_local_decode(d)) { + if (errlen) snprintf(err, errlen, "distributed route does not support coordinator local decode"); + return 1; + } + if (!d->plan_ready && 
dist_session_ensure_route(d, err, errlen) != 0) return 1; + + const uint32_t owner_count = dist_kv_route_owner_count(d); + ds4_dist_kv_route_owner *owners = calloc(owner_count, sizeof(owners[0])); + if (!owners) { + if (errlen) snprintf(err, errlen, "out of memory activating coordinator local decode"); + return 1; + } + if (dist_kv_route_build_owners(d, owners, owner_count, err, errlen) != 0) { + free(owners); + return 1; + } + + const uint64_t token_hash = dist_token_hash_prefix(tokens->v, (uint32_t)tokens->len); + const double t0 = dist_now_sec(); + for (uint32_t i = 0; i < owner_count; i++) { + const ds4_dist_kv_route_owner *owner_desc = &owners[i]; + if (owner_desc->is_local) continue; + FILE *tmp = dist_tmpfile_or_err("coordinator local-decode shard", err, errlen); + if (!tmp) { + free(owners); + return 1; + } + uint64_t shard_bytes = 0; + int rc = dist_save_remote_shard_to_file(d, + owner_desc->entry, + tokens, + token_hash, + tmp, + &shard_bytes, + err, + errlen); + if (rc == 0 && shard_bytes != 0) { + rc = dist_rewind_file(tmp, "coordinator local-decode shard", err, errlen); + } + if (rc == 0) { + rc = ds4_session_load_layer_payload(owner, + tmp, + shard_bytes, + tokens->v, + (uint32_t)tokens->len, + owner_desc->layer_start, + owner_desc->layer_end, + err, + errlen); + } + fclose(tmp); + if (rc != 0) { + free(owners); + return 1; + } + } + free(owners); + d->local_decode_active = true; + d->local_decode_remote_flushable = true; + d->local_decode_remote_pos = (uint32_t)tokens->len; + DIST_COORD_DEBUG(&d->state, + "ds4: distributed coordinator: activated reverse-topology local decode tokens=%d local=%u:%u total=%.3fs\n", + tokens->len, + d->state.local_start, + d->state.local_end, + dist_now_sec() - t0); + return 0; +} + +static int dist_session_restore_distributed_checkpoint( + ds4_dist_session *d, + ds4_session *owner, + const ds4_tokens *checkpoint, + float *logits, + char *err, + size_t errlen) { + if (!d || !owner || !checkpoint || checkpoint->len <= 0 
|| !logits) { + if (errlen) snprintf(err, errlen, "invalid distributed checkpoint restore request"); + return 1; + } + if (!d->local_decode_active) return 0; + if (dist_coordinator_rebuild_from_transcript(&d->state, + owner, + &d->plan, + checkpoint, + d->session_id, + &d->request_id, + logits, + &d->plan_generation, + false, + err, + errlen) != 0) { + char rebuild_err[256]; + snprintf(rebuild_err, + sizeof(rebuild_err), + "%s", + err && err[0] ? err : "distributed checkpoint replay failed"); + if (dist_coordinator_rebuild_from_transcript(&d->state, + owner, + &d->plan, + checkpoint, + d->session_id, + &d->request_id, + logits, + &d->plan_generation, + true, + err, + errlen) != 0) { + if (errlen && (!err || !err[0])) { + snprintf(err, errlen, "%s", rebuild_err); + } + d->plan_ready = false; + d->plan_generation = 0; + return 1; + } + } + d->plan_ready = true; + d->local_decode_active = false; + d->local_decode_remote_flushable = false; + d->local_decode_remote_pos = 0; + DIST_COORD_DEBUG(&d->state, + "ds4: distributed coordinator: restored distributed checkpoint from local decode transcript tokens=%d\n", + checkpoint->len); + return 0; +} + +static int dist_session_flush_local_decode_remote( + ds4_dist_session *d, + ds4_session *owner, + char *err, + size_t errlen) { + if (!d || !owner) { + if (errlen) snprintf(err, errlen, "invalid local decode remote flush request"); + return 1; + } + if (!d->local_decode_active || !d->local_decode_remote_flushable) { + if (errlen) snprintf(err, errlen, "local decode remote flush is inactive"); + return 1; + } + if (d->state.topology != DS4_DIST_TOPOLOGY_REVERSE || d->plan.count == 0) { + if (errlen) snprintf(err, errlen, "local decode remote flush requires a reverse remote prefix route"); + return 1; + } + const int remote_fd = d->plan.entry[0].fd; + if (remote_fd < 0) { + if (errlen) snprintf(err, errlen, "reverse distributed route has no live remote prefix worker"); + return 1; + } + + const ds4_tokens *timeline = 
ds4_session_tokens(owner); + if (!timeline || timeline->len < 0 || (uint64_t)timeline->len > UINT32_MAX) { + if (errlen) snprintf(err, errlen, "local decode remote flush has no valid token timeline"); + return 1; + } + const uint32_t token_count = (uint32_t)timeline->len; + if (d->local_decode_remote_pos > token_count) { + if (errlen) snprintf(err, errlen, "local decode remote flush position exceeds current transcript"); + return 1; + } + + uint32_t chunk_cap = 0; + if (dist_coordinator_prefill_chunk_cap(&d->state, owner, &chunk_cap, err, errlen) != 0) { + return 1; + } + + uint32_t pos = d->local_decode_remote_pos; + while (pos < token_count) { + const uint32_t remaining = token_count - pos; + const uint32_t chunk = remaining < chunk_cap ? remaining : chunk_cap; + uint64_t prefix_hash = 0; + if (dist_session_token_hash_prefix(owner, pos, &prefix_hash, err, errlen) != 0) { + return 1; + } + const uint64_t result_hash = + dist_token_hash_update_span(prefix_hash, timeline->v + pos, chunk); + uint32_t kind = 0, payload_bytes = 0; + int rc = dist_coordinator_request_remote_on_fd(&d->state, + &d->plan, + remote_fd, + timeline->v + pos, + chunk, + pos, + d->session_id, + d->request_id++, + prefix_hash, + result_hash, + false, + true, + NULL, + 0u, + &kind, + NULL, + &payload_bytes, + err, + errlen); + if (rc != 0) return rc; + if (kind != DS4_DIST_RESULT_ACK || payload_bytes != 0) { + if (errlen) snprintf(err, errlen, "unexpected local decode remote flush result"); + return 1; + } + pos += chunk; + } + d->local_decode_remote_pos = token_count; + return 0; +} + +static int dist_session_eval_local_decode_token( + ds4_session *owner, + int token, + char *err, + size_t errlen) { + if (ds4_session_eval_local_only(owner, token, err, errlen) != 0) { + return 1; + } + return 0; +} + static void dist_kv_shards_close(ds4_dist_kv_shard_file *shards, uint32_t count) { if (!shards) return; for (uint32_t i = 0; i < count; i++) { @@ -5086,11 +6042,22 @@ int 
ds4_dist_session_save_payload( if (errlen) snprintf(err, errlen, "invalid distributed payload save"); return 1; } - if (dist_session_ensure_route(d, err, errlen) != 0) return 1; - if (dist_kv_route_validate(d, err, errlen) != 0) return 1; + if (!d->local_decode_active && + dist_session_ensure_route(d, err, errlen) != 0) return 1; + const uint32_t shard_count = dist_kv_route_owner_count(d); + ds4_dist_kv_route_owner *owners = calloc(shard_count, sizeof(owners[0])); + if (!owners) { + if (errlen) snprintf(err, errlen, "out of memory saving distributed KV route"); + return 1; + } + if (dist_kv_route_build_owners(d, owners, shard_count, err, errlen) != 0) { + free(owners); + return 1; + } const ds4_tokens *tokens = ds4_session_tokens(owner); if (!tokens || tokens->len < 0 || (uint64_t)tokens->len > UINT32_MAX) { + free(owners); if (errlen) snprintf(err, errlen, "distributed session has no valid token timeline"); return 1; } @@ -5099,20 +6066,22 @@ int ds4_dist_session_save_payload( const uint32_t vocab = (uint32_t)ds4_engine_vocab_size(d->state.engine); float *logits = malloc((size_t)vocab * sizeof(logits[0])); if (!logits) { + free(owners); if (errlen) snprintf(err, errlen, "out of memory saving distributed logits"); return 1; } if (ds4_session_copy_logits(owner, logits, (int)vocab) != (int)vocab) { + free(owners); free(logits); if (errlen) snprintf(err, errlen, "failed to copy distributed logits"); return 1; } - const uint32_t shard_count = dist_kv_route_shard_count(d); ds4_dist_kv_shard_file *shards = calloc(shard_count, sizeof(shards[0])); uint32_t *n_comp = calloc(d->state.n_layers, sizeof(n_comp[0])); uint32_t *n_index_comp = calloc(d->state.n_layers, sizeof(n_index_comp[0])); if (!shards || !n_comp || !n_index_comp) { + free(owners); free(logits); free(shards); free(n_comp); @@ -5126,12 +6095,12 @@ int ds4_dist_session_save_payload( bool layout_set = false; for (uint32_t shard = 0; shard < shard_count; shard++) { - uint32_t layer_start = 0, layer_end = 0; - 
const ds4_dist_route_entry *entry = NULL; - dist_kv_route_shard(d, shard, &layer_start, &layer_end, &entry); + const ds4_dist_kv_route_owner *owner_desc = &owners[shard]; + const uint32_t layer_start = owner_desc->layer_start; + const uint32_t layer_end = owner_desc->layer_end; shards[shard].fp = dist_tmpfile_or_err("distributed KV shard", err, errlen); if (!shards[shard].fp) goto cleanup; - if (shard == 0) { + if (owner_desc->is_local || d->local_decode_active) { if (ds4_session_save_layer_payload(owner, shards[shard].fp, layer_start, layer_end, err, errlen) != 0) @@ -5140,7 +6109,7 @@ int ds4_dist_session_save_payload( "distributed local KV shard", err, errlen) != 0) goto cleanup; } else { - if (dist_save_remote_shard_to_file(d, entry, tokens, token_hash, + if (dist_save_remote_shard_to_file(d, owner_desc->entry, tokens, token_hash, shards[shard].fp, &shards[shard].bytes, err, errlen) != 0) @@ -5202,6 +6171,7 @@ int ds4_dist_session_save_payload( cleanup: dist_kv_shards_close(shards, shard_count); + free(owners); free(shards); free(n_comp); free(n_index_comp); @@ -5221,16 +6191,28 @@ int ds4_dist_session_load_payload( return 1; } if (dist_session_ensure_route(d, err, errlen) != 0) return 1; - if (dist_kv_route_validate(d, err, errlen) != 0) return 1; + const uint32_t shard_count = dist_kv_route_owner_count(d); + ds4_dist_kv_route_owner *owners = calloc(shard_count, sizeof(owners[0])); + if (!owners) { + if (errlen) snprintf(err, errlen, "out of memory loading distributed KV route"); + return 1; + } + if (dist_kv_route_build_owners(d, owners, shard_count, err, errlen) != 0) { + free(owners); + return 1; + } uint64_t remaining = payload_bytes; uint32_t h[DS4_SESSION_PAYLOAD_U32_FIELDS]; for (uint32_t i = 0; i < DS4_SESSION_PAYLOAD_U32_FIELDS; i++) { - if (dist_payload_read_u32(fp, &h[i], &remaining, err, errlen) != 0) + if (dist_payload_read_u32(fp, &h[i], &remaining, err, errlen) != 0) { + free(owners); return 1; + } } if (h[0] != DS4_SESSION_PAYLOAD_MAGIC || h[1] 
!= DS4_SESSION_PAYLOAD_VERSION) { + free(owners); if (errlen) snprintf(err, errlen, "unsupported DS4 KV payload version"); return 1; } @@ -5252,6 +6234,7 @@ int ds4_dist_session_load_payload( layout.token_count >= (uint32_t)ds4_session_ctx(owner) || layout.vocab != (uint32_t)ds4_engine_vocab_size(d->state.engine) || !dist_kv_raw_live_valid(&layout)) { + free(owners); if (errlen) snprintf(err, errlen, "DS4 KV payload does not match current distributed runtime"); return 1; } @@ -5262,6 +6245,7 @@ int ds4_dist_session_load_payload( uint32_t *n_comp = calloc(layout.n_layers, sizeof(n_comp[0])); uint32_t *n_index_comp = calloc(layout.n_layers, sizeof(n_index_comp[0])); if ((layout.token_count && !tokens) || !logits || !n_comp || !n_index_comp) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5272,6 +6256,7 @@ int ds4_dist_session_load_payload( for (uint32_t i = 0; i < layout.token_count; i++) { uint32_t tok = 0; if (dist_payload_read_u32(fp, &tok, &remaining, err, errlen) != 0) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5280,6 +6265,7 @@ int ds4_dist_session_load_payload( } if (tok > (uint32_t)INT_MAX || tok >= (uint32_t)ds4_engine_vocab_size(d->state.engine)) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5295,6 +6281,7 @@ int ds4_dist_session_load_payload( if (dist_payload_read_bytes(fp, logits, (uint64_t)layout.vocab * sizeof(logits[0]), &remaining, err, errlen) != 0) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5319,36 +6306,34 @@ int ds4_dist_session_load_payload( } } - const uint32_t shard_count = dist_kv_route_shard_count(d); for (uint32_t shard = 0; shard < shard_count; shard++) { - uint32_t layer_start = 0, layer_end = 0; - const ds4_dist_route_entry *entry = NULL; + const ds4_dist_kv_route_owner *owner_desc = &owners[shard]; FILE *tmp = NULL; uint64_t shard_bytes = 0; - dist_kv_route_shard(d, shard, &layer_start, &layer_end, &entry); if (dist_prepare_shard_from_session_payload(d, fp, 
&remaining, &layout, n_comp, n_index_comp, - layer_start, - layer_end, + owner_desc->layer_start, + owner_desc->layer_end, &tmp, &shard_bytes, err, errlen) != 0) goto cleanup; - if (shard == 0) { + if (owner_desc->is_local) { if (ds4_session_load_layer_payload(owner, tmp, shard_bytes, tokens_arg, layout.token_count, - layer_start, layer_end, + owner_desc->layer_start, + owner_desc->layer_end, err, errlen) != 0) { fclose(tmp); goto cleanup; } } else { - if (dist_load_remote_shard_from_payload(d, entry, + if (dist_load_remote_shard_from_payload(d, owner_desc->entry, tokens_arg, layout.token_count, token_hash, tmp, shard_bytes, @@ -5370,6 +6355,7 @@ int ds4_dist_session_load_payload( rc = 0; cleanup: + free(owners); free(tokens); free(logits); free(n_comp); @@ -5405,6 +6391,10 @@ int ds4_dist_session_create( return 1; } if (dist_validate_options(opt, err, errlen) != 0) return 1; + const uint32_t n_layers = (uint32_t)ds4_engine_layer_count(engine); + if (dist_validate_layers_for_model(opt, n_layers, err, errlen) != 0) return 1; + ds4_dist_topology topology = DS4_DIST_TOPOLOGY_FORWARD; + if (dist_infer_coordinator_topology(opt, n_layers, &topology, err, errlen) != 0) return 1; int listen_fd = dist_open_listener(opt->listen_host, opt->listen_port, err, errlen); if (listen_fd < 0) return 1; @@ -5419,12 +6409,14 @@ int ds4_dist_session_create( d->listen_fd = listen_fd; d->state.engine = engine; d->state.model_id = (uint32_t)ds4_engine_model_id(engine); - d->state.n_layers = (uint32_t)ds4_engine_layer_count(engine); + d->state.n_layers = n_layers; d->state.local_start = opt->layers.start; d->state.local_end = dist_resolved_layer_end(opt, d->state.n_layers); d->state.ctx_size = ctx_size > 0 ? 
(uint32_t)ctx_size : 0u; + d->state.topology = topology; d->state.local_has_output = opt->layers.has_output; d->state.local_can_output_head = ds4_engine_has_output_head(engine); + d->state.local_decode_requested = opt->local_decode; d->state.replay_check = opt->replay_check; d->state.debug = opt->debug; d->state.use_control_for_work = true; @@ -5438,16 +6430,20 @@ int ds4_dist_session_create( * WORK results are outstanding. Keep them out of the WORK request-id stream * so progress callbacks cannot perturb the reader's contiguous expectations. */ d->snapshot_request_id = UINT64_C(1) << 63; + d->local_decode_active = false; + d->local_decode_remote_flushable = false; + d->local_decode_remote_pos = 0; char local_end[32]; if (opt->layers.has_output) snprintf(local_end, sizeof(local_end), "output"); else snprintf(local_end, sizeof(local_end), "%u", opt->layers.end); DIST_COORD_DEBUG(&d->state, - "ds4: distributed coordinator API: listening on %s:%d model_id=%u layers=%u local=%u:%s activation_bits=%u\n", + "ds4: distributed coordinator API: listening on %s:%d model_id=%u layers=%u topology=%s local=%u:%s activation_bits=%u\n", opt->listen_host, opt->listen_port, d->state.model_id, d->state.n_layers, + dist_topology_name(d->state.topology), opt->layers.start, local_end, d->state.activation_bits); @@ -5513,6 +6509,44 @@ int ds4_dist_session_sync( if (errlen) snprintf(err, errlen, "invalid distributed sync request"); return 1; } + if (d->local_decode_active) { + if (checkpoint && + checkpoint->len >= 0 && + checkpoint->len <= prompt->len && + ds4_tokens_starts_with(prompt, checkpoint)) { + if (checkpoint->len == prompt->len) return 0; + const uint64_t plan_generation = d->plan_generation; + if (dist_session_ensure_route(d, err, errlen) != 0) return 1; + if (d->local_decode_remote_flushable && + d->plan_generation == plan_generation && + dist_session_flush_local_decode_remote(d, owner, err, errlen) == 0) { + d->local_decode_active = false; + d->local_decode_remote_flushable 
= false; + d->local_decode_remote_pos = 0; + } else { + if (d->local_decode_remote_flushable) { + DIST_COORD_DEBUG(&d->state, + "ds4: distributed coordinator: deferred local decode flush failed; rebuilding worker KV from transcript: %s\n", + err && err[0] ? err : "route changed"); + } + d->local_decode_remote_flushable = false; + d->local_decode_remote_pos = 0; + if (errlen) err[0] = '\0'; + if (dist_session_restore_distributed_checkpoint(d, + owner, + checkpoint, + logits, + err, + errlen) != 0) { + return 1; + } + } + } else { + d->local_decode_active = false; + d->local_decode_remote_flushable = false; + d->local_decode_remote_pos = 0; + } + } if (dist_session_ensure_route(d, err, errlen) != 0) return 1; if (checkpoint && @@ -5646,6 +6680,15 @@ int ds4_dist_session_eval( if (errlen) snprintf(err, errlen, "invalid distributed decode request"); return 1; } + if (d->local_decode_active) { + return dist_session_eval_local_decode_token(owner, token, err, errlen); + } + if (dist_session_supports_local_decode(d)) { + if (dist_session_activate_local_decode(d, owner, checkpoint, err, errlen) != 0) { + return 1; + } + return dist_session_eval_local_decode_token(owner, token, err, errlen); + } if (dist_session_ensure_route(d, err, errlen) != 0) return 1; ds4_tokens transcript = {0}; @@ -5694,6 +6737,12 @@ int ds4_dist_session_eval( static int dist_run_coordinator(ds4_engine *engine, const ds4_dist_options *opt, const ds4_dist_generation_options *gen) { char err[256]; + const uint32_t n_layers = (uint32_t)ds4_engine_layer_count(engine); + ds4_dist_topology topology = DS4_DIST_TOPOLOGY_FORWARD; + if (dist_infer_coordinator_topology(opt, n_layers, &topology, err, sizeof(err)) != 0) { + fprintf(stderr, "ds4: distributed coordinator: %s\n", err); + return 1; + } int listen_fd = dist_open_listener(opt->listen_host, opt->listen_port, err, sizeof(err)); if (listen_fd < 0) { fprintf(stderr, "ds4: distributed coordinator: %s\n", err); @@ -5704,10 +6753,11 @@ static int 
dist_run_coordinator(ds4_engine *engine, const ds4_dist_options *opt, memset(&state, 0, sizeof(state)); state.engine = engine; state.model_id = (uint32_t)ds4_engine_model_id(engine); - state.n_layers = (uint32_t)ds4_engine_layer_count(engine); + state.n_layers = n_layers; state.local_start = opt->layers.start; state.local_end = dist_resolved_layer_end(opt, state.n_layers); state.ctx_size = gen && gen->ctx_size > 0 ? (uint32_t)gen->ctx_size : 0u; + state.topology = topology; state.local_has_output = opt->layers.has_output; state.local_can_output_head = ds4_engine_has_output_head(engine); state.replay_check = opt->replay_check; @@ -5722,11 +6772,12 @@ static int dist_run_coordinator(ds4_engine *engine, const ds4_dist_options *opt, if (opt->layers.has_output) snprintf(local_end, sizeof(local_end), "output"); else snprintf(local_end, sizeof(local_end), "%u", opt->layers.end); DIST_COORD_DEBUG(&state, - "ds4: distributed coordinator: listening on %s:%d model_id=%u layers=%u local=%u:%s activation_bits=%u\n", + "ds4: distributed coordinator: listening on %s:%d model_id=%u layers=%u topology=%s local=%u:%s activation_bits=%u\n", opt->listen_host, opt->listen_port, state.model_id, state.n_layers, + dist_topology_name(state.topology), opt->layers.start, local_end, state.activation_bits); @@ -7329,6 +8380,7 @@ static int dist_worker_process_work_payload( ds4_dist_route_entry current_route; ds4_dist_route_entry next_route; + ds4_dist_route_entry final_route; const bool has_route = work.route_count != 0; const bool has_next = has_route && work.route_index + 1u < work.route_count; if (has_route) { @@ -7364,6 +8416,20 @@ static int dist_worker_process_work_payload( free(tokens); return dist_worker_upstream_send_work_error(upstream, request_id, err); } + if (!dist_route_get_entry(route_blob, work.route_bytes, work.route_count, + work.route_count - 1u, &final_route, err, sizeof(err))) { + free(route_blob); + free(tokens); + return dist_worker_upstream_send_work_error(upstream, 
request_id, err); + } + if (!input_hc_present && + (final_route.flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) != 0) { + free(route_blob); + free(tokens); + return dist_worker_upstream_send_work_error(upstream, + request_id, + "remote-first distributed route must return hidden-state upstream"); + } } if (has_next && output_logits) { free(route_blob); @@ -8127,6 +9193,9 @@ void ds4_dist_usage(FILE *fp) { " Coordinator TCP listen address. Workers may later use it to force their data listener.\n" " --coordinator HOST PORT\n" " Coordinator TCP address for --role worker.\n" + " --local-decode\n" + " Coordinator-only opt-in: keep a reverse --layers N:output coordinator\n" + " fully resident and switch to local decode after distributed prefill.\n" " --dist-prefill-chunk N\n" " Coordinator prefill pipeline chunk size. Default: session cap, normally 4096.\n" " Non-default values are experimental and can change logits unless validated.\n" @@ -8212,6 +9281,14 @@ ds4_dist_cli_parse_result ds4_dist_parse_cli_arg( opt->coordinator_host = host; return DS4_DIST_CLI_MATCHED; } + if (!strcmp(arg, "--local-decode")) { + if (!opt) { + if (errlen) snprintf(err, errlen, "missing distributed options"); + return DS4_DIST_CLI_ERROR; + } + opt->local_decode = true; + return DS4_DIST_CLI_MATCHED; + } if (!strcmp(arg, "--dist-prefill-chunk")) { if (!opt) { if (errlen) snprintf(err, errlen, "missing distributed options"); @@ -8287,6 +9364,7 @@ static int dist_validate_options(const ds4_dist_options *opt, char *err, size_t if (opt->layers.set || opt->listen_host || opt->listen_port || opt->coordinator_host || opt->coordinator_port || opt->prefill_chunk != 0 || opt->prefill_window != 0 || + opt->local_decode || opt->activation_bits != 0) { if (errlen) snprintf(err, errlen, "distributed options require --role coordinator or --role worker"); return 1; @@ -8308,6 +9386,15 @@ static int dist_validate_options(const ds4_dist_options *opt, char *err, size_t } if (opt->role == DS4_DISTRIBUTED_COORDINATOR) { + 
if (opt->local_decode && + (!opt->layers.has_output || opt->layers.start == 0u)) { + if (errlen) { + snprintf(err, + errlen, + "--local-decode requires reverse --role coordinator --layers N:output"); + } + return 1; + } if (!opt->listen_host || opt->listen_port <= 0) { if (errlen) snprintf(err, errlen, "--role coordinator requires --listen HOST PORT"); return 1; @@ -8320,6 +9407,14 @@ static int dist_validate_options(const ds4_dist_options *opt, char *err, size_t } if (opt->role == DS4_DISTRIBUTED_WORKER) { + if (opt->local_decode) { + if (errlen) { + snprintf(err, + errlen, + "--local-decode requires --role coordinator"); + } + return 1; + } if (!opt->coordinator_host || opt->coordinator_port <= 0) { if (errlen) snprintf(err, errlen, "--role worker requires --coordinator HOST PORT"); return 1; @@ -8356,10 +9451,27 @@ int ds4_dist_prepare_engine_options( if (engine && opt) { engine->distributed = *opt; if (ds4_dist_enabled(opt)) { - engine->load_slice = true; - engine->load_layer_start = opt->layers.start; - engine->load_layer_end = opt->layers.has_output ? UINT32_MAX : opt->layers.end; - engine->load_output = opt->layers.has_output; + const bool reverse_coordinator = + opt->role == DS4_DISTRIBUTED_COORDINATOR && + opt->layers.set && + opt->layers.has_output && + opt->layers.start > 0u; + if (engine->prefill_chunk == 0 && reverse_coordinator) { + engine->prefill_chunk = 2048u; + engine->distributed.prefill_chunk = 2048u; + } + const bool reverse_coordinator_full_resident = + opt->role == DS4_DISTRIBUTED_COORDINATOR && + opt->local_decode && + opt->layers.set && + opt->layers.has_output && + opt->layers.start > 0u; + if (!reverse_coordinator_full_resident) { + engine->load_slice = true; + engine->load_layer_start = opt->layers.start; + engine->load_layer_end = opt->layers.has_output ? 
UINT32_MAX : opt->layers.end; + engine->load_output = opt->layers.has_output; + } } } return 0; @@ -8381,13 +9493,61 @@ static int dist_validate_layers_for_model(const ds4_dist_options *opt, uint32_t if (errlen) snprintf(err, errlen, "layer range ends past final model layer %u", last); return 1; } - if (opt->role == DS4_DISTRIBUTED_COORDINATOR && opt->layers.start != 0) { - if (errlen) snprintf(err, errlen, "coordinator layer range must start at layer 0"); - return 1; + if (opt->role == DS4_DISTRIBUTED_COORDINATOR) { + return dist_infer_coordinator_topology(opt, n_layers, NULL, err, errlen); } return 0; } +int ds4_dist_test_validate_coordinator_layers( + const ds4_dist_options *opt, + uint32_t n_layers, + char *err, + size_t errlen) { + return dist_validate_layers_for_model(opt, n_layers, err, errlen); +} + +int ds4_dist_test_infer_coordinator_topology( + const ds4_dist_options *opt, + uint32_t n_layers, + int *topology_out, + char *err, + size_t errlen) { + ds4_dist_topology topology = DS4_DIST_TOPOLOGY_FORWARD; + int rc = dist_infer_coordinator_topology(opt, n_layers, &topology, err, errlen); + if (rc == 0 && topology_out) *topology_out = (int)topology; + return rc; +} + +bool ds4_dist_test_build_route_plan( + ds4_dist_coordinator_state *state, + ds4_dist_route_plan *plan, + char *err, + size_t errlen) { + if (!state || !plan) { + if (errlen) snprintf(err, errlen, "invalid distributed test route request"); + return false; + } + return dist_coordinator_build_route_plan(state, plan, NULL, err, errlen); +} + +void ds4_dist_test_route_plan_free(ds4_dist_route_plan *plan) { + dist_route_plan_free(plan); +} + +uint32_t ds4_dist_test_kv_route_owner_count(const ds4_dist_session *d) { + return dist_kv_route_owner_count(d); +} + +int ds4_dist_test_kv_route_build_owners( + const ds4_dist_session *d, + ds4_dist_kv_route_owner *owners, + uint32_t owner_count, + char *err, + size_t errlen) { + return dist_kv_route_build_owners(d, owners, owner_count, err, errlen); +} + int 
ds4_dist_run(ds4_engine *engine, const ds4_dist_options *opt, const ds4_dist_generation_options *gen) { if (!engine || !opt) { fprintf(stderr, "ds4: distributed runtime requires an open engine and options\n"); diff --git a/ds4_help.c b/ds4_help.c index d32e088cf..0378b8c7f 100644 --- a/ds4_help.c +++ b/ds4_help.c @@ -167,7 +167,7 @@ static void print_model_runtime(FILE *fp, const help_colors *c, opt(fp, c, "--ssd-streaming-cache-experts N|NGB", "SSD streaming: routed expert cache as expert count or GiB, e.g. 32GB. Metal/ROCm default: 80% working set minus non-routed weights; CUDA default: backend fixed cache."); opt(fp, c, "--ssd-streaming-preload-experts N", "SSD streaming: upfront popularity preload count. Default: auto hot seed capped at 4096; use --ssd-streaming-cold to skip."); opt(fp, c, "--simulate-used-memory NGB", "Diagnostic: lock N GiB before model load to simulate a smaller-memory machine."); - opt(fp, c, "--prefill-chunk N", "Metal graph prefill chunk size. Default: auto (PRO long prompts use 8192; others use 4096)."); + opt(fp, c, "--prefill-chunk N", "Metal graph prefill chunk size. Default: auto (reverse distributed coordinators use 2048; PRO long prompts use 8192; others use 4096)."); if (full) { if (tool != DS4_HELP_BENCH) { opt(fp, c, "--mtp FILE", "Optional MTP support GGUF used for draft-token probes."); @@ -214,12 +214,13 @@ static void print_steering(FILE *fp, const help_colors *c) { static void print_distributed(FILE *fp, const help_colors *c) { title(fp, c, "Distributed Inference"); fputc('\n', fp); - para(fp, c, "Distributed mode runs one logical session across several machines by assigning contiguous model layer ranges to workers. Workers own their layer slice and KV-cache shard; the coordinator owns the prompt, sampling loop, and client/API flow. Start workers first, then start the coordinator. 
The coordinator waits for a complete route and streams hidden states through the workers."); + para(fp, c, "Distributed mode runs one logical session across several machines by assigning contiguous model layer ranges to workers. The coordinator may own either a leading prefix (0:K) or a final suffix through output (K:output). Workers own their layer slice and KV-cache shard; the coordinator owns the prompt, sampling loop, and client/API flow. Start workers first, then start the coordinator. Reverse K:42 is unsupported."); fputc('\n', fp); opt(fp, c, "--role ROLE", "Distributed role: coordinator or worker."); - opt(fp, c, "--layers A:B", "Inclusive layer slice, e.g. 0:20 or 21:output."); + opt(fp, c, "--layers A:B", "Inclusive layer slice, e.g. 0:20, 21:42, or 21:output."); opt(fp, c, "--listen HOST PORT", "Coordinator listen address; workers may use it for their data listener."); opt(fp, c, "--coordinator HOST PORT", "Coordinator address for --role worker."); + opt(fp, c, "--local-decode", "Coordinator-only opt-in for reverse N:output local decode with full local model residency."); opt(fp, c, "--dist-prefill-chunk N", "Coordinator prefill pipeline chunk size. Default: session cap."); opt(fp, c, "--dist-prefill-window N", "Max prefill chunks in flight. Default: workers+2, capped at 8."); opt(fp, c, "--dist-activation-bits N", "Hidden-state transport width: 32, 16, or 8. Default: 32"); diff --git a/tests/ds4_distributed_test_internal.h b/tests/ds4_distributed_test_internal.h new file mode 100644 index 000000000..fc69c4097 --- /dev/null +++ b/tests/ds4_distributed_test_internal.h @@ -0,0 +1,145 @@ +#ifndef DS4_DISTRIBUTED_TEST_INTERNAL_H +#define DS4_DISTRIBUTED_TEST_INTERNAL_H + +/* Test-only view of selected distributed internals. + * + * Why this exists: + * - The distributed topology/route/KV-owner tests need to exercise the real + * planner and validator logic from ds4_distributed.c. 
+ * - Those code paths operate on private runtime structs such as + * ds4_dist_coordinator_state, ds4_dist_worker_entry, ds4_dist_route_plan, + * and ds4_dist_session. + * - We intentionally do not expose those internals, or richer test-adapter + * APIs, through ds4_distributed.h because that would turn test scaffolding + * into public production surface area. + * + * This header therefore provides a narrow test-only bridge: + * - enough struct definitions for tests/ds4_test.c to assemble synthetic + * coordinator/worker/session state + * - thin declarations for internal logic entry points that remain implemented + * in production code + * + * It should only be included by unit tests. Production code should include + * ds4_distributed.h instead. + */ + +#include "../ds4_distributed.h" + +#include +#include +#include +#include +#include + +typedef struct ds4_dist_worker_entry { + int fd; + char peer_host[NI_MAXHOST]; + char peer_port[NI_MAXSERV]; + char model_name[128]; + uint32_t model_id; + uint32_t quant_bits; + uint32_t layer_start; + uint32_t layer_end; + uint32_t has_output; + uint32_t has_hidden; + uint32_t ctx_size; + uint32_t n_layers; + uint32_t listen_port; + struct ds4_dist_worker_entry *next; +} ds4_dist_worker_entry; + +typedef enum { + DS4_DIST_TOPOLOGY_FORWARD = 0, + DS4_DIST_TOPOLOGY_REVERSE = 1, +} ds4_dist_topology; + +typedef struct { + ds4_engine *engine; + uint32_t model_id; + uint32_t n_layers; + uint32_t local_start; + uint32_t local_end; + uint32_t ctx_size; + ds4_dist_topology topology; + bool local_has_output; + bool local_can_output_head; + bool replay_check; + bool debug; + bool use_control_for_work; + uint32_t prefill_chunk; + uint32_t prefill_window; + uint32_t activation_bits; + uint64_t generation; + pthread_mutex_t mu; + ds4_dist_worker_entry *workers; + bool shutting_down; +} ds4_dist_coordinator_state; + +typedef struct { + char host[NI_MAXHOST]; + uint32_t port; + uint32_t layer_start; + uint32_t layer_end; + uint32_t flags; + 
int fd; +} ds4_dist_route_entry; + +typedef struct { + ds4_dist_route_entry *entry; + uint32_t count; + void *blob; + uint32_t blob_bytes; +} ds4_dist_route_plan; + +typedef struct { + uint32_t layer_start; + uint32_t layer_end; + const ds4_dist_route_entry *entry; + bool is_local; +} ds4_dist_kv_route_owner; + +struct ds4_dist_session { + ds4_dist_coordinator_state state; + int listen_fd; + pthread_t accept_tid; + bool accept_started; + struct { + ds4_dist_coordinator_state *state; + int listen_fd; + } accept_ctx; + ds4_dist_route_plan plan; + bool plan_ready; + uint64_t plan_generation; + uint64_t session_id; + uint64_t request_id; + uint64_t snapshot_request_id; +}; + +#define DS4_DIST_ROUTE_F_OUTPUT_LOGITS 0x00000001u + +int ds4_dist_test_validate_coordinator_layers( + const ds4_dist_options *opt, + uint32_t n_layers, + char *err, + size_t errlen); +int ds4_dist_test_infer_coordinator_topology( + const ds4_dist_options *opt, + uint32_t n_layers, + int *topology_out, + char *err, + size_t errlen); +bool ds4_dist_test_build_route_plan( + ds4_dist_coordinator_state *state, + ds4_dist_route_plan *plan, + char *err, + size_t errlen); +void ds4_dist_test_route_plan_free(ds4_dist_route_plan *plan); +uint32_t ds4_dist_test_kv_route_owner_count(const ds4_dist_session *d); +int ds4_dist_test_kv_route_build_owners( + const ds4_dist_session *d, + ds4_dist_kv_route_owner *owners, + uint32_t owner_count, + char *err, + size_t errlen); + +#endif diff --git a/tests/ds4_test.c b/tests/ds4_test.c index ea1e52487..ecc231cd8 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -1,6 +1,7 @@ #define DS4_SERVER_TEST #define DS4_SERVER_TEST_NO_MAIN #include "../ds4_server.c" +#include "ds4_distributed_test_internal.h" #ifndef DS4_NO_GPU #include "../ds4_gpu.h" #include @@ -2176,6 +2177,237 @@ static void test_mtp_verify_depth(void) { } #endif +static ds4_dist_options test_dist_options( + ds4_distributed_role role, + uint32_t start, + uint32_t end, + bool has_output) { + 
ds4_dist_options opt; + memset(&opt, 0, sizeof(opt)); + opt.role = role; + opt.layers.start = start; + opt.layers.end = end; + opt.layers.has_output = has_output; + opt.layers.set = true; + if (role == DS4_DISTRIBUTED_COORDINATOR) { + opt.listen_host = "127.0.0.1"; + opt.listen_port = 1234; + } else if (role == DS4_DISTRIBUTED_WORKER) { + opt.coordinator_host = "127.0.0.1"; + opt.coordinator_port = 1234; + } + return opt; +} + +static void test_dist_free_workers(ds4_dist_worker_entry *workers) { + while (workers) { + ds4_dist_worker_entry *next = workers->next; + free(workers); + workers = next; + } +} + +static ds4_dist_worker_entry *test_dist_worker( + const char *host, + uint32_t port, + uint32_t quant_bits, + uint32_t layer_start, + uint32_t layer_end, + bool has_output, + bool has_hidden, + uint32_t ctx_size) { + ds4_dist_worker_entry *w = calloc(1, sizeof(*w)); + TEST_ASSERT(w != NULL); + if (!w) return NULL; + w->fd = -1; + snprintf(w->peer_host, sizeof(w->peer_host), "%s", + host && host[0] ? 
host : "127.0.0.1"); + snprintf(w->peer_port, sizeof(w->peer_port), "%u", port); + w->model_id = 1; + w->quant_bits = quant_bits; + w->layer_start = layer_start; + w->layer_end = layer_end; + w->has_output = has_output; + w->has_hidden = has_hidden; + w->ctx_size = ctx_size; + w->n_layers = 43; + w->listen_port = port; + return w; +} + +static void test_distributed_topology_logic_group(void) { + char err[256]; + int topology = -1; + ds4_engine_options engine_opt; + + ds4_dist_options forward = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 0, 19, false); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&forward, 43, err, sizeof(err)) == 0); + TEST_ASSERT(ds4_dist_test_infer_coordinator_topology(&forward, 43, &topology, err, sizeof(err)) == 0); + TEST_ASSERT(topology == 0); + memset(&engine_opt, 0, sizeof(engine_opt)); + TEST_ASSERT(ds4_dist_prepare_engine_options(&forward, &engine_opt, err, sizeof(err)) == 0); + TEST_ASSERT(engine_opt.load_slice); + TEST_ASSERT(engine_opt.load_layer_start == 0u); + TEST_ASSERT(engine_opt.load_layer_end == 19u); + TEST_ASSERT(!engine_opt.load_output); + + ds4_dist_options reverse = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 20, 0, true); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&reverse, 43, err, sizeof(err)) == 0); + TEST_ASSERT(ds4_dist_test_infer_coordinator_topology(&reverse, 43, &topology, err, sizeof(err)) == 0); + TEST_ASSERT(topology == 1); + memset(&engine_opt, 0, sizeof(engine_opt)); + TEST_ASSERT(ds4_dist_prepare_engine_options(&reverse, &engine_opt, err, sizeof(err)) == 0); + TEST_ASSERT(engine_opt.load_slice); + TEST_ASSERT(engine_opt.load_layer_start == 20u); + TEST_ASSERT(engine_opt.load_layer_end == UINT32_MAX); + TEST_ASSERT(engine_opt.load_output); + + reverse.local_decode = true; + memset(&engine_opt, 0, sizeof(engine_opt)); + TEST_ASSERT(ds4_dist_prepare_engine_options(&reverse, &engine_opt, err, sizeof(err)) == 0); + TEST_ASSERT(!engine_opt.load_slice); + 
TEST_ASSERT(!engine_opt.load_output); + + ds4_dist_options reverse_partial = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 20, 42, false); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&reverse_partial, 43, err, sizeof(err)) != 0); + + ds4_dist_options middle = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 5, 10, false); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&middle, 43, err, sizeof(err)) != 0); + + ds4_dist_options worker = test_dist_options(DS4_DISTRIBUTED_WORKER, 0, 19, false); + memset(&engine_opt, 0, sizeof(engine_opt)); + TEST_ASSERT(ds4_dist_prepare_engine_options(&worker, &engine_opt, err, sizeof(err)) == 0); + TEST_ASSERT(engine_opt.load_slice); + TEST_ASSERT(engine_opt.load_layer_start == 0u); + TEST_ASSERT(engine_opt.load_layer_end == 19u); + + worker.local_decode = true; + memset(&engine_opt, 0, sizeof(engine_opt)); + TEST_ASSERT(ds4_dist_prepare_engine_options(&worker, &engine_opt, err, sizeof(err)) != 0); + + ds4_dist_coordinator_state state; + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 0; + state.local_end = 19; + state.topology = DS4_DIST_TOPOLOGY_FORWARD; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20001, 2, 20, 42, true, true, 1024); + ds4_dist_route_plan route = {0}; + TEST_ASSERT(ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(route.count == 1); + TEST_ASSERT(route.entry[0].layer_start == 20); + TEST_ASSERT(route.entry[0].layer_end == 42); + TEST_ASSERT((route.entry[0].flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) != 0); + ds4_dist_test_route_plan_free(&route); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); + + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 20; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + 
state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20011, 2, 0, 19, false, true, 1024); + memset(&route, 0, sizeof(route)); + TEST_ASSERT(ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(route.count == 1); + TEST_ASSERT(route.entry[0].layer_start == 0); + TEST_ASSERT(route.entry[0].layer_end == 19); + TEST_ASSERT((route.entry[0].flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) == 0); + + ds4_dist_session session; + memset(&session, 0, sizeof(session)); + session.state.n_layers = 43; + session.state.local_start = 20; + session.state.local_end = 42; + session.plan = route; + ds4_dist_kv_route_owner owners[4]; + uint32_t owner_count = ds4_dist_test_kv_route_owner_count(&session); + TEST_ASSERT(owner_count == 2); + TEST_ASSERT(ds4_dist_test_kv_route_build_owners(&session, + owners, + owner_count, + err, + sizeof(err)) == 0); + TEST_ASSERT(owner_count == 2); + TEST_ASSERT(!owners[0].is_local); + TEST_ASSERT(owners[0].layer_start == 0); + TEST_ASSERT(owners[0].layer_end == 19); + TEST_ASSERT(owners[1].is_local); + TEST_ASSERT(owners[1].layer_start == 20); + TEST_ASSERT(owners[1].layer_end == 42); + route.entry[0].layer_end = 18; + TEST_ASSERT(ds4_dist_test_kv_route_build_owners(&session, + owners, + owner_count, + err, + sizeof(err)) != 0); + route.entry[0].layer_end = 19; + ds4_dist_test_route_plan_free(&route); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); + + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 20; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20012, 2, 1, 19, false, true, 1024); + memset(&route, 0, sizeof(route)); + 
TEST_ASSERT(!ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(strstr(err, "missing layer 0") != NULL); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); + + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 21; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20013, 2, 0, 21, false, true, 1024); + memset(&route, 0, sizeof(route)); + err[0] = '\0'; + TEST_ASSERT(!ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(strstr(err, "overlap coordinator local range 21:output") != NULL); + TEST_ASSERT(strstr(err, "layers 21:21") != NULL); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); + + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 21; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20014, 2, 0, 10, false, true, 1024); + state.workers->next = test_dist_worker("127.0.0.1", 20015, 2, 10, 20, false, true, 1024); + memset(&route, 0, sizeof(route)); + err[0] = '\0'; + TEST_ASSERT(!ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(strstr(err, "overlap worker") != NULL); + TEST_ASSERT(strstr(err, "layers 10:10") != NULL); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); +} + static void test_server_unit_group(void) { ds4_server_unit_tests_run(); } @@ -2203,6 +2435,7 @@ static const ds4_test_entry test_entries[] = { {"--streaming-decode-prefill-correctness", "streaming-decode-prefill-correctness", "streaming 
decode-style cold prefill drift and repeatability", test_streaming_decode_prefill_correctness}, {"--mtp-verify-depth", "mtp-verify-depth", "MTP speculative verify commits autoregressive-identical tokens at draft depth > 2", test_mtp_verify_depth}, #endif + {"--distributed-topology-logic", "distributed-topology-logic", "distributed coordinator topology, route-planning, and KV-owner logic", test_distributed_topology_logic_group}, {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, };