diff --git a/README.md b/README.md index 785695284..fddebfc60 100644 --- a/README.md +++ b/README.md @@ -385,6 +385,30 @@ slow or metered links, `--layers 20:42` is also supported: the coordinator will load the output head and compute logits locally, trading extra coordinator work for smaller per-token replies. +Reverse topology is also supported when the coordinator owns a final suffix +through the output head. In that layout workers cover the lower layers, +return hidden state upstream, and the coordinator runs the higher layers plus +output locally. For example: + +```sh +# Machine A: worker, owns lower layers. +./ds4 \ + -m gguf/DeepSeek-V4-Pro-Q4K-Layers00-30.gguf \ + --role worker \ + --layers 0:30 \ + --coordinator 169.254.43.68 1234 + +# Machine B: coordinator, owns the upper suffix and output. +./ds4 \ + -m gguf/DeepSeek-V4-Pro-Q4K-Layers-31-output.gguf \ + --role coordinator \ + --layers 31:output \ + --listen 169.254.43.68 1234 +``` + +Reverse `K:42` is intentionally unsupported. Reverse mode only supports +`K:output`, because the coordinator must own the output head. + ### Network Link Comparison The table below shows the same two M5 Max hosts, the same 91 GB Flash quant, @@ -468,21 +492,26 @@ control TCP connection open to the coordinator and send a `HELLO` with their model ID, model family, quant profile, layer slice, context capacity, and data port. The coordinator uses these registrations to build a route that covers all layers. Work then moves over low-latency TCP data connections: the coordinator -computes the first slice, sends a `WORK` frame with session ID, token positions, -rolling token-prefix hashes before and after the span, route information, and -hidden-state payload, and each worker computes its slice. Middle workers can -forward directly to the next worker. The final worker returns logits to the -coordinator, or ACKs for non-final prefill chunks so the prefill pipeline can -stay full. 
`RESULT` frames echo the request ID and the post-span hash. A worker -status error is handled differently from a socket failure: KV/hash mismatch can -be recovered by replaying the token history on the same route, while transport -failure drops the route and waits for a replacement worker. For persistent KV, -the coordinator opens worker data connections and sends snapshot save/load -messages for each worker-owned layer range; the disk payload remains a single -agent/server cache file. The protocol has no -encryption or authentication, and is not release-stable yet; coordinator and -workers should be built from the same commit and used on trusted machines and -trusted networks. +computes the local prefix first in forward topology, sends a `WORK` frame with +session ID, token positions, rolling token-prefix hashes before and after the +span, route information, and hidden-state payload, and each worker computes its +slice. In reverse topology the first worker starts from layer 0 with token +input only, returns hidden state upstream, and the coordinator finishes the +higher layers plus output locally. Middle workers can forward directly to the +next worker. The final worker returns logits in the usual forward path, or +returns hidden state when the coordinator owns the output path. Forward +non-final prefill chunks may use ACK-only replies so the prefill pipeline can +stay full; reverse pipelined prefill returns hidden state for every chunk +because the coordinator must finish each chunk locally. `RESULT` frames echo +the request ID and the post-span hash. A worker status error is handled +differently from a socket failure: KV/hash mismatch can be recovered by +replaying the token history on the same route, while transport failure drops +the route and waits for a replacement worker. 
For persistent KV, the +coordinator opens worker data connections and sends snapshot save/load messages +for each worker-owned layer range; the disk payload remains a single +agent/server cache file. The protocol has no encryption or authentication, and +is not release-stable yet; coordinator and workers should be built from the +same commit and used on trusted machines and trusted networks. ## Reducing heat, power usage and fan noise diff --git a/ds4_distributed.c b/ds4_distributed.c index d31c8e2a6..f0b9ad3ce 100644 --- a/ds4_distributed.c +++ b/ds4_distributed.c @@ -222,6 +222,11 @@ typedef struct ds4_dist_worker_entry { struct ds4_dist_worker_entry *next; } ds4_dist_worker_entry; +typedef enum { + DS4_DIST_TOPOLOGY_FORWARD = 0, + DS4_DIST_TOPOLOGY_REVERSE = 1, +} ds4_dist_topology; + typedef struct { ds4_engine *engine; uint32_t model_id; @@ -229,6 +234,7 @@ typedef struct { uint32_t local_start; uint32_t local_end; uint32_t ctx_size; + ds4_dist_topology topology; bool local_has_output; bool local_can_output_head; bool replay_check; @@ -383,6 +389,13 @@ typedef struct { uint64_t tensor_bytes; } ds4_dist_kv_shard_file; +typedef struct { + uint32_t layer_start; + uint32_t layer_end; + const ds4_dist_route_entry *entry; + bool is_local; +} ds4_dist_kv_route_owner; + struct ds4_dist_session { ds4_dist_coordinator_state state; int listen_fd; @@ -405,8 +418,11 @@ typedef struct { typedef struct { ds4_dist_coordinator_state *state; + ds4_session *session; + const ds4_tokens *prompt; int fd; ds4_session *progress_session; + float *logits; uint64_t first_request_id; uint64_t *expected_hashes; uint32_t count; @@ -418,10 +434,13 @@ typedef struct { bool progress_done; uint64_t hc_values; bool allow_hidden; + bool reset_first_chunk; + bool reverse_apply_local_suffix; uint32_t final_kind; void *final_payload; uint32_t final_payload_bytes; int rc; + double local_eval_sec; char err[256]; pthread_mutex_t progress_mu; pthread_cond_t progress_cv; @@ -525,12 +544,63 @@ static int 
dist_coordinator_prefill_prompt( char *err, size_t errlen); static int dist_validate_options(const ds4_dist_options *opt, char *err, size_t errlen); +static int dist_validate_layers_for_model(const ds4_dist_options *opt, uint32_t n_layers, char *err, size_t errlen); static uint32_t dist_resolved_layer_end(const ds4_dist_options *opt, uint32_t n_layers) { if (opt->layers.has_output) return n_layers - 1u; return opt->layers.end; } +static const char *dist_topology_name(ds4_dist_topology topology) { + switch (topology) { + case DS4_DIST_TOPOLOGY_FORWARD: return "forward"; + case DS4_DIST_TOPOLOGY_REVERSE: return "reverse"; + } + return "unknown"; +} + +static int dist_infer_coordinator_topology( + const ds4_dist_options *opt, + uint32_t n_layers, + ds4_dist_topology *out, + char *err, + size_t errlen) { + if (!opt || opt->role != DS4_DISTRIBUTED_COORDINATOR || !opt->layers.set) return 0; + if (n_layers == 0) { + if (errlen) snprintf(err, errlen, "model reports no layers"); + return 1; + } + + const uint32_t local_end = dist_resolved_layer_end(opt, n_layers); + if (opt->layers.start == 0u) { + if (out) *out = DS4_DIST_TOPOLOGY_FORWARD; + return 0; + } + if (opt->layers.has_output && local_end + 1u == n_layers) { + if (out) *out = DS4_DIST_TOPOLOGY_REVERSE; + return 0; + } + if (!opt->layers.has_output && local_end + 1u == n_layers) { + if (errlen) { + snprintf(err, + errlen, + "coordinator reverse suffix %u:%u is unsupported; use %u:output", + opt->layers.start, + local_end, + opt->layers.start); + } + return 1; + } + if (errlen) { + snprintf(err, + errlen, + "coordinator middle-only layer range %u:%u is unsupported; use 0:K or K:output", + opt->layers.start, + local_end); + } + return 1; +} + static const char *dist_role_name(ds4_distributed_role role) { switch (role) { case DS4_DISTRIBUTED_NONE: return "none"; @@ -1848,9 +1918,13 @@ static bool dist_coordinator_debug_enabled(const ds4_dist_coordinator_state *sta * Coordinator Worker Registry And Route Planning * 
========================================================================= * - * A route is a contiguous chain that starts after the coordinator's local - * slice. The last hop can either return logits directly or return the final - * hidden state so the coordinator can run its local output head. + * A route is a contiguous worker chain covering the coordinator's missing + * side of the model: + * - forward topology covers layers after the coordinator slice + * - reverse topology covers layers before the coordinator slice + * + * The wire format stays the same in both cases: ordered worker entries plus + * a final upstream return target. */ static void dist_coordinator_add_worker( @@ -1941,13 +2015,50 @@ static int dist_worker_route_cmp(const void *a, const void *b) { return 0; } +static bool dist_coordinator_route_span( + const ds4_dist_coordinator_state *state, + uint32_t *start, + uint32_t *end) { + if (!state || state->n_layers == 0) return false; + switch (state->topology) { + case DS4_DIST_TOPOLOGY_FORWARD: + if (state->local_end + 1u >= state->n_layers) return false; + if (start) *start = state->local_end + 1u; + if (end) *end = state->n_layers - 1u; + return true; + case DS4_DIST_TOPOLOGY_REVERSE: + if (state->local_start == 0u) return false; + if (start) *start = 0u; + if (end) *end = state->local_start - 1u; + return true; + } + return false; +} + +static bool dist_coordinator_route_final_worker_may_output_logits( + const ds4_dist_coordinator_state *state) { + return state && state->topology == DS4_DIST_TOPOLOGY_FORWARD; +} + static bool dist_worker_route_candidate_ok( const ds4_dist_coordinator_state *state, const ds4_dist_worker_entry *w, - uint32_t last) { - const bool needs_hidden = w->layer_end < last || !w->has_output; + uint32_t required_end, + bool final_worker_may_output_logits) { + if (!state || !w) return false; + if (w->layer_end > required_end) return false; + + const bool is_final = w->layer_end == required_end; + const bool 
final_outputs_logits = is_final && + final_worker_may_output_logits && + w->has_output && + required_end + 1u == state->n_layers; + const bool needs_hidden = !is_final || !final_outputs_logits; if (needs_hidden && !w->has_hidden) return false; - if (w->layer_end >= last && !w->has_output && !state->local_can_output_head) return false; + if (is_final && + required_end + 1u == state->n_layers && + !final_outputs_logits && + !state->local_can_output_head) return false; return true; } @@ -1956,7 +2067,8 @@ static bool dist_route_search_workers( ds4_dist_worker_entry **workers, uint32_t n, uint32_t next, - uint32_t last, + uint32_t required_end, + bool final_worker_may_output_logits, ds4_dist_worker_entry **path, uint32_t *path_len, uint32_t *missing_layer) { @@ -1966,16 +2078,17 @@ static bool dist_route_search_workers( if (w->layer_start < next) continue; if (w->layer_start > next) break; saw_start = true; - if (!dist_worker_route_candidate_ok(state, w, last)) continue; + if (!dist_worker_route_candidate_ok(state, w, required_end, final_worker_may_output_logits)) continue; path[(*path_len)++] = w; - if (w->layer_end >= last) return true; + if (w->layer_end == required_end) return true; uint32_t child_missing = w->layer_end + 1u; if (dist_route_search_workers(state, workers, n, child_missing, - last, + required_end, + final_worker_may_output_logits, path, path_len, &child_missing)) { @@ -1988,6 +2101,108 @@ static bool dist_route_search_workers( return false; } +static void dist_format_range_end( + uint32_t layer_end, + bool has_output, + char *buf, + size_t buflen) { + if (!buf || buflen == 0) return; + if (has_output) snprintf(buf, buflen, "output"); + else snprintf(buf, buflen, "%u", layer_end); +} + +static bool dist_route_diag_local_overlap( + const ds4_dist_coordinator_state *state, + const ds4_dist_worker_entry *w, + char *err, + size_t errlen) { + if (!state || !w) return false; + if (w->layer_end < state->local_start || w->layer_start > state->local_end) 
return false; + + uint32_t overlap_start = w->layer_start > state->local_start ? w->layer_start : state->local_start; + uint32_t overlap_end = w->layer_end < state->local_end ? w->layer_end : state->local_end; + char worker_end[32]; + char local_end[32]; + dist_format_range_end(w->layer_end, w->has_output != 0, worker_end, sizeof(worker_end)); + dist_format_range_end(state->local_end, state->local_has_output, local_end, sizeof(local_end)); + if (errlen) { + snprintf(err, + errlen, + "distributed %s route invalid: worker %s:%u layers=%u:%s overlap coordinator local range %u:%s at layers %u:%u", + dist_topology_name(state->topology), + w->peer_host, + w->listen_port, + w->layer_start, + worker_end, + state->local_start, + local_end, + overlap_start, + overlap_end); + } + return true; +} + +static bool dist_route_diag_worker_overlap( + const ds4_dist_coordinator_state *state, + const ds4_dist_worker_entry *prev, + const ds4_dist_worker_entry *cur, + char *err, + size_t errlen) { + if (!state || !prev || !cur) return false; + if (cur->layer_start > prev->layer_end) return false; + + uint32_t overlap_start = cur->layer_start; + uint32_t overlap_end = prev->layer_end < cur->layer_end ? 
prev->layer_end : cur->layer_end; + char prev_end[32]; + char cur_end[32]; + dist_format_range_end(prev->layer_end, prev->has_output != 0, prev_end, sizeof(prev_end)); + dist_format_range_end(cur->layer_end, cur->has_output != 0, cur_end, sizeof(cur_end)); + if (errlen) { + snprintf(err, + errlen, + "distributed %s route invalid: worker %s:%u layers=%u:%s overlap worker %s:%u layers=%u:%s at layers %u:%u", + dist_topology_name(state->topology), + prev->peer_host, + prev->listen_port, + prev->layer_start, + prev_end, + cur->peer_host, + cur->listen_port, + cur->layer_start, + cur_end, + overlap_start, + overlap_end); + } + return true; +} + +static void dist_route_diagnose_failure( + const ds4_dist_coordinator_state *state, + ds4_dist_worker_entry **workers, + uint32_t n, + uint32_t required_start, + uint32_t required_end, + uint32_t missing_layer, + char *err, + size_t errlen) { + if (!state || !err || errlen == 0) return; + + ds4_dist_worker_entry *prev = NULL; + for (uint32_t i = 0; i < n; i++) { + ds4_dist_worker_entry *w = workers[i]; + if (w->layer_end < required_start || w->layer_start > required_end) continue; + if (dist_route_diag_local_overlap(state, w, err, errlen)) return; + if (prev && dist_route_diag_worker_overlap(state, prev, w, err, errlen)) return; + prev = w; + } + + snprintf(err, + errlen, + "distributed %s route incomplete: missing layer %u", + dist_topology_name(state->topology), + missing_layer); +} + static void dist_coordinator_report_plan(ds4_dist_coordinator_state *state) { if (!dist_coordinator_debug_enabled(state)) return; pthread_mutex_lock(&state->mu); @@ -2007,26 +2222,34 @@ static void dist_coordinator_report_plan(ds4_dist_coordinator_state *state) { qsort(workers, n, sizeof(workers[0]), dist_worker_route_cmp); const uint32_t last = state->n_layers - 1u; - bool complete = state->local_start == 0; - bool has_output = state->local_end == last && - (state->local_has_output || state->local_can_output_head); - uint32_t next = 
state->local_end + 1u; - if (state->local_end >= last) next = state->n_layers; + const char *topology_name = dist_topology_name(state->topology); + uint32_t required_start = 0, required_end = 0; + const bool needs_remote = + dist_coordinator_route_span(state, &required_start, &required_end); + const bool final_worker_may_output_logits = + dist_coordinator_route_final_worker_may_output_logits(state); + bool complete = !needs_remote; + bool has_output = false; + uint32_t missing = required_start; uint32_t path_len = 0; - uint32_t missing = next; - if (complete && !has_output) { + if (!needs_remote) { + has_output = state->local_end == last && + (state->local_has_output || state->local_can_output_head); + } else { complete = dist_route_search_workers(state, workers, n, - next, - last, + required_start, + required_end, + final_worker_may_output_logits, path, &path_len, &missing); - if (complete && path_len != 0) { + if (complete && state->topology == DS4_DIST_TOPOLOGY_FORWARD && path_len != 0) { ds4_dist_worker_entry *final = path[path_len - 1u]; has_output = final->has_output || state->local_can_output_head; - next = state->n_layers; + } else if (complete && state->topology == DS4_DIST_TOPOLOGY_REVERSE) { + has_output = state->local_has_output; } } @@ -2035,44 +2258,76 @@ static void dist_coordinator_report_plan(ds4_dist_coordinator_state *state) { char local_end[32]; if (state->local_has_output) snprintf(local_end, sizeof(local_end), "output"); else snprintf(local_end, sizeof(local_end), "%u", state->local_end); - used += (size_t)snprintf(plan + used, used < sizeof(plan) ? sizeof(plan) - used : 0, - "local %u:%s", - state->local_start, - local_end); - for (i = 0; i < path_len; i++) { - ds4_dist_worker_entry *w = path[i]; + if (state->topology == DS4_DIST_TOPOLOGY_FORWARD) { + used += (size_t)snprintf(plan + used, used < sizeof(plan) ? 
sizeof(plan) - used : 0, + "local %u:%s", + state->local_start, + local_end); + for (i = 0; i < path_len; i++) { + ds4_dist_worker_entry *w = path[i]; + if (used < sizeof(plan)) { + char end[32]; + if (w->has_output) snprintf(end, sizeof(end), "output"); + else snprintf(end, sizeof(end), "%u", w->layer_end); + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + " -> %s:%u Q%u %u:%s", + w->peer_host, + w->listen_port, + w->quant_bits, + w->layer_start, + end); + } + } + if (complete && path_len != 0 && + !path[path_len - 1u]->has_output && state->local_can_output_head && + used < sizeof(plan)) { + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + " -> local output"); + } + if (complete && path_len == 0 && + state->local_end == last && !state->local_has_output && + state->local_can_output_head && used < sizeof(plan)) { + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + " -> local output"); + } + } else { + for (i = 0; i < path_len; i++) { + ds4_dist_worker_entry *w = path[i]; + if (used < sizeof(plan)) { + char end[32]; + if (w->has_output) snprintf(end, sizeof(end), "output"); + else snprintf(end, sizeof(end), "%u", w->layer_end); + used += (size_t)snprintf(plan + used, sizeof(plan) - used, + "%s%s:%u Q%u %u:%s", + i == 0 ? 
"" : " -> ", + w->peer_host, + w->listen_port, + w->quant_bits, + w->layer_start, + end); + } + } if (used < sizeof(plan)) { - char end[32]; - if (w->has_output) snprintf(end, sizeof(end), "output"); - else snprintf(end, sizeof(end), "%u", w->layer_end); used += (size_t)snprintf(plan + used, sizeof(plan) - used, - " -> %s:%u Q%u %u:%s", - w->peer_host, - w->listen_port, - w->quant_bits, - w->layer_start, - end); - } - } - if (complete && path_len != 0 && - !path[path_len - 1u]->has_output && state->local_can_output_head && - used < sizeof(plan)) { - used += (size_t)snprintf(plan + used, sizeof(plan) - used, - " -> local output"); - } - if (complete && path_len == 0 && - state->local_end == last && !state->local_has_output && - state->local_can_output_head && used < sizeof(plan)) { - used += (size_t)snprintf(plan + used, sizeof(plan) - used, - " -> local output"); - } - complete = complete && has_output && next == state->n_layers; + "%slocal %u:%s", + path_len != 0 ? " -> " : "", + state->local_start, + local_end); + } + } + complete = complete && has_output; pthread_mutex_unlock(&state->mu); if (complete) { - fprintf(stderr, "ds4: distributed coordinator: complete route ready: %s\n", plan); + fprintf(stderr, + "ds4: distributed coordinator: complete %s route ready: %s\n", + topology_name, + plan); } else { - fprintf(stderr, "ds4: distributed coordinator: route incomplete; next needed layer %u\n", missing); + fprintf(stderr, + "ds4: distributed coordinator: %s route incomplete; missing layer %u\n", + topology_name, + missing); } free(path); free(workers); @@ -2224,37 +2479,51 @@ static bool dist_coordinator_build_route_plan( qsort(workers, n, sizeof(workers[0]), dist_worker_route_cmp); const uint32_t last = state->n_layers - 1u; - if (state->local_start != 0) { + uint32_t required_start = 0, required_end = 0; + const bool needs_remote = + dist_coordinator_route_span(state, &required_start, &required_end); + const bool final_worker_may_output_logits = + 
dist_coordinator_route_final_worker_may_output_logits(state); + if (!needs_remote && + state->local_end == last && + (state->local_has_output || state->local_can_output_head)) { + if (generation) *generation = state->generation; pthread_mutex_unlock(&state->mu); free(workers); free(path); - if (errlen) snprintf(err, errlen, "coordinator route does not start at layer 0"); - return false; + return true; } - if (state->local_end == last && - (state->local_has_output || state->local_can_output_head)) { - if (generation) *generation = state->generation; + if (!needs_remote) { + const char *topology_name = dist_topology_name(state->topology); pthread_mutex_unlock(&state->mu); free(workers); free(path); - return true; + if (errlen) snprintf(err, errlen, "distributed %s route has no output owner", topology_name); + return false; } - uint32_t next = state->local_end + 1u; uint32_t path_len = 0; - uint32_t missing = next; + uint32_t missing = required_start; if (!dist_route_search_workers(state, workers, n, - next, - last, + required_start, + required_end, + final_worker_may_output_logits, path, &path_len, &missing)) { pthread_mutex_unlock(&state->mu); + dist_route_diagnose_failure(state, + workers, + n, + required_start, + required_end, + missing, + err, + errlen); free(workers); free(path); - if (errlen) snprintf(err, errlen, "distributed route incomplete: missing layer %u", missing); return false; } @@ -2267,7 +2536,11 @@ static bool dist_coordinator_build_route_plan( entry.port = w->listen_port; entry.layer_start = w->layer_start; entry.layer_end = w->layer_end; - entry.flags = w->has_output ? DS4_DIST_ROUTE_F_OUTPUT_LOGITS : 0u; + entry.flags = (i + 1u == path_len && + final_worker_may_output_logits && + w->has_output) + ? 
DS4_DIST_ROUTE_F_OUTPUT_LOGITS + : 0u; if (state->use_control_for_work && plan->count == 0) { entry.fd = dup(w->fd); if (entry.fd < 0) { @@ -2524,14 +2797,20 @@ static int dist_coordinator_send_remote_work_on_fd( work.n_tokens = n_tokens; work.layer_start = first->layer_start; work.layer_end = first->layer_end; - work.flags = DS4_DIST_WORK_F_INPUT_HC; + const bool input_hc_present = hidden_hc != NULL || hidden_hc_bytes != 0u; + if (input_hc_present && (!hidden_hc || hidden_hc_bytes == 0u)) { + if (errlen) snprintf(err, errlen, "distributed input hidden-state metadata mismatch"); + return 1; + } + work.flags = input_hc_present ? DS4_DIST_WORK_F_INPUT_HC : 0u; if (reset_session) work.flags |= DS4_DIST_WORK_F_RESET_SESSION; if (ack_only) work.flags |= DS4_DIST_WORK_F_ACK_ONLY; if ((first->flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) != 0) { work.flags |= DS4_DIST_WORK_F_OUTPUT_LOGITS; } uint32_t wire_hidden_hc_bytes = 0; - if (!dist_activation_wire_bytes_from_f32_bytes(state->activation_bits, + if (input_hc_present && + !dist_activation_wire_bytes_from_f32_bytes(state->activation_bits, hidden_hc_bytes, &wire_hidden_hc_bytes)) { if (errlen) snprintf(err, errlen, "invalid distributed hidden-state size"); @@ -2539,7 +2818,7 @@ static int dist_coordinator_send_remote_work_on_fd( } work.token_bytes = n_tokens * sizeof(uint32_t); work.input_hc_bytes = wire_hidden_hc_bytes; - work.input_hc_bits = state->activation_bits; + work.input_hc_bits = input_hc_present ? 
state->activation_bits : 0u; work.route_count = plan->count; work.route_index = 0; work.route_bytes = plan->blob_bytes; @@ -2551,9 +2830,8 @@ static int dist_coordinator_send_remote_work_on_fd( return 0; } -static int dist_coordinator_eval_remote_on_fd( +static int dist_coordinator_request_remote_on_fd( ds4_dist_coordinator_state *state, - ds4_session *session, const ds4_dist_route_plan *plan, int fd, const int *tokens, @@ -2566,7 +2844,9 @@ static int dist_coordinator_eval_remote_on_fd( bool reset_session, const float *hidden_hc, uint32_t hidden_hc_bytes, - float *logits, + uint32_t *kind, + void **payload, + uint32_t *payload_bytes, char *err, size_t errlen) { const bool profile = dist_decode_profile_enabled() && n_tokens == 1; @@ -2589,18 +2869,18 @@ static int dist_coordinator_eval_remote_on_fd( err, errlen); const double send_t1 = profile ? dist_now_sec() : 0.0; - uint32_t kind = 0, payload_bytes = 0; + uint32_t result_kind = 0, result_payload_bytes = 0; uint64_t result_hash = 0; - void *payload = NULL; + void *result_payload = NULL; if (rc == 0) { const double recv_t0 = profile ? dist_now_sec() : 0.0; rc = dist_recv_result_alloc(fd, state, request_id, - &kind, + &result_kind, &result_hash, - &payload, - &payload_bytes, + &result_payload, + &result_payload_bytes, err, errlen); const double recv_t1 = profile ? 
dist_now_sec() : 0.0; @@ -2611,16 +2891,71 @@ static int dist_coordinator_eval_remote_on_fd( pos0, (send_t1 - send_t0) * 1000.0, (recv_t1 - recv_t0) * 1000.0, - kind, - (double)payload_bytes / (1024.0 * 1024.0)); + result_kind, + (double)result_payload_bytes / (1024.0 * 1024.0)); } } if (rc != 0) return rc; if (result_hash != expected_result_hash) { - free(payload); + free(result_payload); if (errlen) snprintf(err, errlen, "distributed result prefix hash mismatch"); return 1; } + if (kind) *kind = result_kind; + if (payload) *payload = result_payload; + else free(result_payload); + if (payload_bytes) *payload_bytes = result_payload_bytes; + if (profile) { + const double total_t1 = dist_now_sec(); + fprintf(stderr, + "ds4: dist decode profile: remote request=%llu total=%.3fms\n", + (unsigned long long)request_id, + (total_t1 - total_t0) * 1000.0); + } + return 0; +} + +static int dist_coordinator_eval_remote_on_fd( + ds4_dist_coordinator_state *state, + ds4_session *session, + const ds4_dist_route_plan *plan, + int fd, + const int *tokens, + uint32_t n_tokens, + uint32_t pos0, + uint64_t session_id, + uint64_t request_id, + uint64_t prefix_hash, + uint64_t expected_result_hash, + bool reset_session, + const float *hidden_hc, + uint32_t hidden_hc_bytes, + float *logits, + char *err, + size_t errlen) { + const bool profile = dist_decode_profile_enabled() && n_tokens == 1; + const double total_t0 = profile ? 
dist_now_sec() : 0.0; + uint32_t kind = 0, payload_bytes = 0; + void *payload = NULL; + int rc = dist_coordinator_request_remote_on_fd(state, + plan, + fd, + tokens, + n_tokens, + pos0, + session_id, + request_id, + prefix_hash, + expected_result_hash, + reset_session, + hidden_hc, + hidden_hc_bytes, + &kind, + &payload, + &payload_bytes, + err, + errlen); + if (rc != 0) return rc; const uint32_t logits_bytes = (uint32_t)((uint64_t)ds4_engine_vocab_size(state->engine) * sizeof(float)); if (kind == DS4_DIST_RESULT_LOGITS && payload_bytes == logits_bytes) { @@ -2667,6 +3002,46 @@ static int dist_coordinator_eval_remote_on_fd( return 1; } +static int dist_coordinator_eval_local_suffix( + ds4_dist_coordinator_state *state, + ds4_session *session, + const int *tokens, + uint32_t n_tokens, + uint32_t pos0, + bool reset_session, + const float *input_hc, + uint32_t input_hc_bytes, + float *logits, + char *err, + size_t errlen) { + const uint64_t hc_values = ds4_engine_hidden_f32_values(state->engine); + const uint64_t expected_bytes64 = (uint64_t)n_tokens * hc_values * sizeof(float); + if (expected_bytes64 > UINT32_MAX) { + if (errlen) snprintf(err, errlen, "distributed coordinator hidden-state chunk is too large"); + return 1; + } + if (!input_hc || input_hc_bytes != (uint32_t)expected_bytes64) { + if (errlen) snprintf(err, errlen, "distributed route returned invalid hidden-state size"); + return 1; + } + if (reset_session && + ds4_session_layer_slice_reset(session, err, errlen) != 0) { + return 1; + } + return ds4_session_eval_layer_slice(session, + tokens, + n_tokens, + pos0, + state->local_start, + state->local_end, + input_hc, + NULL, + true, + logits, + err, + errlen); +} + static int dist_coordinator_eval_span( ds4_dist_coordinator_state *state, ds4_session *session, @@ -2703,34 +3078,86 @@ static int dist_coordinator_eval_span( } const uint64_t result_hash = dist_token_hash_update_span(prefix_hash, tokens, n_tokens); const uint32_t hidden_bytes = 
(uint32_t)hidden_bytes64; - float *hidden = NULL; - if (plan->count != 0) { - hidden = malloc(hidden_bytes); - if (!hidden) { - if (errlen) snprintf(err, errlen, "out of memory allocating coordinator hidden-state"); - return 1; - } - } - if (reset_session && - ds4_session_layer_slice_reset(session, err, errlen) != 0) { - free(hidden); - return 1; - } - - const bool local_logits = plan->count == 0; int remote_fd = -1; if (plan->count != 0) { const ds4_dist_route_entry *first = &plan->entry[0]; remote_fd = first->fd; if (remote_fd < 0) { if (errlen) snprintf(err, errlen, "distributed route has no live first-hop connection"); - free(hidden); return 1; } } - const double local_t0 = profile ? dist_now_sec() : 0.0; - int rc = ds4_session_eval_layer_slice(session, + int rc = 0; + double local_t0 = 0.0, local_t1 = 0.0; + double remote_t0 = 0.0, remote_t1 = 0.0; + if (state->topology == DS4_DIST_TOPOLOGY_REVERSE) { + if (plan->count == 0) { + if (errlen) snprintf(err, errlen, "reverse distributed route has no remote prefix worker"); + return 1; + } + uint32_t kind = 0, payload_bytes = 0; + void *payload = NULL; + remote_t0 = profile ? dist_now_sec() : 0.0; + rc = dist_coordinator_request_remote_on_fd(state, + plan, + remote_fd, + tokens, + n_tokens, + pos0, + session_id, + request_id, + prefix_hash, + result_hash, + reset_session, + NULL, + 0u, + &kind, + &payload, + &payload_bytes, + err, + errlen); + remote_t1 = profile ? dist_now_sec() : 0.0; + if (rc == 0) { + if (kind != DS4_DIST_RESULT_HIDDEN_STATE) { + free(payload); + if (errlen) snprintf(err, errlen, "reverse distributed decode requires final hidden-state"); + rc = 1; + } else { + local_t0 = profile ? dist_now_sec() : 0.0; + rc = dist_coordinator_eval_local_suffix(state, + session, + tokens, + n_tokens, + pos0, + reset_session, + payload, + payload_bytes, + logits, + err, + errlen); + local_t1 = profile ? 
dist_now_sec() : 0.0; + free(payload); + } + } + } else { + float *hidden = NULL; + if (plan->count != 0) { + hidden = malloc(hidden_bytes); + if (!hidden) { + if (errlen) snprintf(err, errlen, "out of memory allocating coordinator hidden-state"); + return 1; + } + } + if (reset_session && + ds4_session_layer_slice_reset(session, err, errlen) != 0) { + free(hidden); + return 1; + } + + const bool local_logits = plan->count == 0; + local_t0 = profile ? dist_now_sec() : 0.0; + rc = ds4_session_eval_layer_slice(session, tokens, n_tokens, pos0, @@ -2742,28 +3169,29 @@ static int dist_coordinator_eval_span( local_logits ? logits : NULL, err, errlen); - const double local_t1 = profile ? dist_now_sec() : 0.0; - double remote_t0 = 0.0, remote_t1 = 0.0; - if (rc == 0 && plan->count != 0) { - remote_t0 = profile ? dist_now_sec() : 0.0; - rc = dist_coordinator_eval_remote_on_fd(state, - session, - plan, - remote_fd, - tokens, - n_tokens, - pos0, - session_id, - request_id, - prefix_hash, - result_hash, - reset_session, - hidden, - hidden_bytes, - logits, - err, - errlen); - remote_t1 = profile ? dist_now_sec() : 0.0; + local_t1 = profile ? dist_now_sec() : 0.0; + if (rc == 0 && plan->count != 0) { + remote_t0 = profile ? dist_now_sec() : 0.0; + rc = dist_coordinator_eval_remote_on_fd(state, + session, + plan, + remote_fd, + tokens, + n_tokens, + pos0, + session_id, + request_id, + prefix_hash, + result_hash, + reset_session, + hidden, + hidden_bytes, + logits, + err, + errlen); + remote_t1 = profile ? dist_now_sec() : 0.0; + } + free(hidden); } if (profile) { const double span_t1 = dist_now_sec(); @@ -2779,7 +3207,6 @@ static int dist_coordinator_eval_span( (double)hidden_bytes / (1024.0 * 1024.0), rc); } - free(hidden); return rc; } @@ -3231,7 +3658,7 @@ static void *dist_prefill_sender_main(void *arg) { slot->result_hash, slot->reset_session, slot->ack_only, - slot->hidden, + slot->hidden_bytes ? 
slot->hidden : NULL, slot->hidden_bytes, send_err, sizeof(send_err)); @@ -3347,6 +3774,7 @@ static void *dist_prefill_result_reader_main(void *arg) { reader->final_kind = 0; reader->final_payload = NULL; reader->final_payload_bytes = 0; + reader->local_eval_sec = 0.0; const uint32_t logits_bytes = (uint32_t)((uint64_t)ds4_engine_vocab_size(reader->state->engine) * sizeof(float)); @@ -3387,10 +3815,13 @@ static void *dist_prefill_result_reader_main(void *arg) { const uint32_t chunk = remaining < reader->chunk_cap ? remaining : reader->chunk_cap; const uint64_t hidden_bytes64 = (uint64_t)chunk * reader->hc_values * sizeof(float); const bool final_chunk = i + 1u == reader->count; - const bool valid_ack = !final_chunk && + const bool valid_ack = !reader->reverse_apply_local_suffix && + !final_chunk && kind == DS4_DIST_RESULT_ACK && payload_bytes == 0; - const bool valid_logits = kind == DS4_DIST_RESULT_LOGITS && payload_bytes == logits_bytes; + const bool valid_logits = !reader->reverse_apply_local_suffix && + kind == DS4_DIST_RESULT_LOGITS && + payload_bytes == logits_bytes; const bool valid_hidden = reader->allow_hidden && hidden_bytes64 <= UINT32_MAX && kind == DS4_DIST_RESULT_HIDDEN_STATE && @@ -3405,7 +3836,31 @@ static void *dist_prefill_result_reader_main(void *arg) { dist_prefill_reader_signal_progress(reader, i, true); return NULL; } - if (final_chunk) { + if (reader->reverse_apply_local_suffix) { + const uint32_t chunk_pos = reader->progress_base + pos0; + const bool reset_session = reader->reset_first_chunk && i == 0; + const double local_t0 = dist_now_sec(); + int local_rc = dist_coordinator_eval_local_suffix(reader->state, + reader->session, + reader->prompt->v + chunk_pos, + chunk, + chunk_pos, + reset_session, + payload, + payload_bytes, + reader->logits, + reader->err, + sizeof(reader->err)); + const double local_t1 = dist_now_sec(); + reader->local_eval_sec += local_t1 - local_t0; + if (local_rc != 0) { + reader->rc = local_rc; + free(payload); + 
shutdown(reader->fd, SHUT_RDWR); + dist_prefill_reader_signal_progress(reader, i, true); + return NULL; + } + } else if (final_chunk) { reader->final_kind = kind; reader->final_payload = payload; reader->final_payload_bytes = payload_bytes; @@ -3430,6 +3885,9 @@ static bool dist_coordinator_can_pipeline_prefill( if (chunk_cap == 0 || n_tokens <= chunk_cap) return false; if (plan->count == 0) return false; if (plan->entry[0].fd < 0) return false; + if (state->topology == DS4_DIST_TOPOLOGY_REVERSE) { + return state->local_can_output_head; + } const ds4_dist_route_entry *final = &plan->entry[plan->count - 1u]; if ((final->flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) == 0) { return final->layer_end + 1u == state->n_layers && @@ -3569,8 +4027,11 @@ static int dist_coordinator_prefill_prompt_pipelined( ds4_dist_prefill_result_reader reader; memset(&reader, 0, sizeof(reader)); reader.state = state; + reader.session = session; + reader.prompt = prompt; reader.fd = plan->entry[0].fd; reader.progress_session = session; + reader.logits = logits; reader.first_request_id = *request_id; reader.count = chunk_count; reader.total_tokens = total; @@ -3578,7 +4039,9 @@ static int dist_coordinator_prefill_prompt_pipelined( reader.progress_base = span_start; reader.progress_total = (uint32_t)prompt->len; reader.hc_values = hc_values; - reader.allow_hidden = + reader.reverse_apply_local_suffix = state->topology == DS4_DIST_TOPOLOGY_REVERSE; + reader.reset_first_chunk = reset_first_chunk; + reader.allow_hidden = reader.reverse_apply_local_suffix || (plan->entry[plan->count - 1u].flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) == 0; pthread_mutex_init(&reader.progress_mu, NULL); pthread_cond_init(&reader.progress_cv, NULL); @@ -3634,6 +4097,7 @@ static int dist_coordinator_prefill_prompt_pipelined( flow_window); int rc = 0; + const bool reverse_topology = state->topology == DS4_DIST_TOPOLOGY_REVERSE; double local_eval_sec = 0.0; const double pipeline_t0 = dist_now_sec(); uint32_t pos = span_start; @@ 
-3659,37 +4123,40 @@ static int dist_coordinator_prefill_prompt_pipelined( rc = 1; break; } - if (pos == span_start && - reset_first_chunk && - ds4_session_layer_slice_reset(session, err, errlen) != 0) { - rc = 1; - break; + if (!reverse_topology) { + if (pos == span_start && + reset_first_chunk && + ds4_session_layer_slice_reset(session, err, errlen) != 0) { + rc = 1; + break; + } + const double local_t0 = dist_now_sec(); + rc = ds4_session_eval_layer_slice(session, + prompt->v + pos, + chunk, + pos, + state->local_start, + state->local_end, + NULL, + slot->hidden, + false, + NULL, + err, + errlen); + const double local_t1 = dist_now_sec(); + local_eval_sec += local_t1 - local_t0; + if (rc != 0) break; } - const double local_t0 = dist_now_sec(); - rc = ds4_session_eval_layer_slice(session, - prompt->v + pos, - chunk, - pos, - state->local_start, - state->local_end, - NULL, - slot->hidden, - false, - NULL, - err, - errlen); - const double local_t1 = dist_now_sec(); - local_eval_sec += local_t1 - local_t0; - if (rc != 0) break; slot->pos = pos; slot->n_tokens = chunk; - slot->hidden_bytes = hidden_bytes; + slot->hidden_bytes = reverse_topology ? 
0u : hidden_bytes; slot->request_id = *request_id; slot->prefix_hash = next_prefix_hash; slot->result_hash = reader.expected_hashes[submitted_chunks]; slot->reset_session = reset_first_chunk && pos == span_start; - slot->ack_only = !getenv("DS4_DIST_DISABLE_PREFILL_ACK_ONLY") && + slot->ack_only = !reverse_topology && + !getenv("DS4_DIST_DISABLE_PREFILL_ACK_ONLY") && pos + chunk < span_end; rc = dist_prefill_sender_enqueue_slot(&sender, err, errlen); if (rc != 0) break; @@ -3722,12 +4189,13 @@ static int dist_coordinator_prefill_prompt_pipelined( if (rc == 0 && reader.rc == 0) { const double total_sec = pipeline_t1 - pipeline_t0; DIST_COORD_DEBUG(state, - "ds4: distributed coordinator: pipelined prefill done tokens=%u chunks=%u total=%.3fs %.2f t/s local=%.3fs send=%.3fs %.2f MiB/s\n", + "ds4: distributed coordinator: pipelined prefill done topology=%s tokens=%u chunks=%u total=%.3fs %.2f t/s local=%.3fs send=%.3fs %.2f MiB/s\n", + dist_topology_name(state->topology), total, chunk_count, total_sec, total_sec > 0.0 ? (double)total / total_sec : 0.0, - local_eval_sec, + local_eval_sec + reader.local_eval_sec, sender.send_sec, sender.send_sec > 0.0 ? ((double)sender.send_bytes / (1024.0 * 1024.0)) / sender.send_sec @@ -3751,6 +4219,10 @@ static int dist_coordinator_prefill_prompt_pipelined( free(reader.final_payload); return 1; } + if (reader.reverse_apply_local_suffix) { + free(reader.final_payload); + return 0; + } const uint32_t logits_bytes = (uint32_t)((uint64_t)ds4_engine_vocab_size(state->engine) * sizeof(float)); if (reader.final_kind == DS4_DIST_RESULT_LOGITS && @@ -4853,51 +5325,67 @@ static int dist_kv_write_layer_header( return 0; } -static uint32_t dist_kv_route_shard_count(const ds4_dist_session *d) { +static uint32_t dist_kv_route_owner_count(const ds4_dist_session *d) { return d ? 
1u + d->plan.count : 0u; } -static void dist_kv_route_shard( - const ds4_dist_session *d, - uint32_t shard, - uint32_t *layer_start, - uint32_t *layer_end, - const ds4_dist_route_entry **entry) { - if (entry) *entry = NULL; - if (shard == 0) { - if (layer_start) *layer_start = d->state.local_start; - if (layer_end) *layer_end = d->state.local_end; - return; - } - const ds4_dist_route_entry *e = &d->plan.entry[shard - 1u]; - if (layer_start) *layer_start = e->layer_start; - if (layer_end) *layer_end = e->layer_end; - if (entry) *entry = e; +static int dist_kv_route_owner_cmp(const void *ap, const void *bp) { + const ds4_dist_kv_route_owner *a = ap; + const ds4_dist_kv_route_owner *b = bp; + if (a->layer_start < b->layer_start) return -1; + if (a->layer_start > b->layer_start) return 1; + if (a->layer_end < b->layer_end) return -1; + if (a->layer_end > b->layer_end) return 1; + if (a->is_local != b->is_local) return a->is_local ? -1 : 1; + return 0; } -static int dist_kv_route_validate( +static int dist_kv_route_build_owners( const ds4_dist_session *d, + ds4_dist_kv_route_owner *owners, + uint32_t owner_count, char *err, size_t errlen) { - if (!d || d->state.n_layers == 0 || - d->state.local_start != 0 || - d->state.local_end >= d->state.n_layers) { - if (errlen) snprintf(err, errlen, "distributed KV route does not start at layer 0"); + if (!d || !owners || d->state.n_layers == 0 || + d->state.local_start > d->state.local_end || + d->state.local_end >= d->state.n_layers || + owner_count != dist_kv_route_owner_count(d)) { + if (errlen) snprintf(err, errlen, "invalid distributed KV route"); return 1; } - uint32_t prev = d->state.local_end; + + owners[0].layer_start = d->state.local_start; + owners[0].layer_end = d->state.local_end; + owners[0].entry = NULL; + owners[0].is_local = true; for (uint32_t i = 0; i < d->plan.count; i++) { - const ds4_dist_route_entry *e = &d->plan.entry[i]; - if (prev == UINT32_MAX || - e->layer_start != prev + 1u || - e->layer_end < 
e->layer_start || - e->layer_end >= d->state.n_layers) { + owners[i + 1u].layer_start = d->plan.entry[i].layer_start; + owners[i + 1u].layer_end = d->plan.entry[i].layer_end; + owners[i + 1u].entry = &d->plan.entry[i]; + owners[i + 1u].is_local = false; + } + qsort(owners, owner_count, sizeof(owners[0]), dist_kv_route_owner_cmp); + + uint32_t prev_end = UINT32_MAX; + for (uint32_t i = 0; i < owner_count; i++) { + const ds4_dist_kv_route_owner *owner = &owners[i]; + if (owner->layer_end < owner->layer_start || + owner->layer_end >= d->state.n_layers) { + if (errlen) snprintf(err, errlen, "distributed KV route has invalid layer range"); + return 1; + } + if (i == 0) { + if (owner->layer_start != 0) { + if (errlen) snprintf(err, errlen, "distributed KV route does not cover layer 0"); + return 1; + } + } else if (owner->layer_start != prev_end + 1u) { if (errlen) snprintf(err, errlen, "distributed KV route is not contiguous"); return 1; } - prev = e->layer_end; + prev_end = owner->layer_end; } - if (prev + 1u != d->state.n_layers) { + if (prev_end + 1u != d->state.n_layers) { if (errlen) snprintf(err, errlen, "distributed KV route does not cover all layers"); return 1; } @@ -5087,10 +5575,20 @@ int ds4_dist_session_save_payload( return 1; } if (dist_session_ensure_route(d, err, errlen) != 0) return 1; - if (dist_kv_route_validate(d, err, errlen) != 0) return 1; + const uint32_t shard_count = dist_kv_route_owner_count(d); + ds4_dist_kv_route_owner *owners = calloc(shard_count, sizeof(owners[0])); + if (!owners) { + if (errlen) snprintf(err, errlen, "out of memory saving distributed KV route"); + return 1; + } + if (dist_kv_route_build_owners(d, owners, shard_count, err, errlen) != 0) { + free(owners); + return 1; + } const ds4_tokens *tokens = ds4_session_tokens(owner); if (!tokens || tokens->len < 0 || (uint64_t)tokens->len > UINT32_MAX) { + free(owners); if (errlen) snprintf(err, errlen, "distributed session has no valid token timeline"); return 1; } @@ -5099,20 
+5597,22 @@ int ds4_dist_session_save_payload( const uint32_t vocab = (uint32_t)ds4_engine_vocab_size(d->state.engine); float *logits = malloc((size_t)vocab * sizeof(logits[0])); if (!logits) { + free(owners); if (errlen) snprintf(err, errlen, "out of memory saving distributed logits"); return 1; } if (ds4_session_copy_logits(owner, logits, (int)vocab) != (int)vocab) { + free(owners); free(logits); if (errlen) snprintf(err, errlen, "failed to copy distributed logits"); return 1; } - const uint32_t shard_count = dist_kv_route_shard_count(d); ds4_dist_kv_shard_file *shards = calloc(shard_count, sizeof(shards[0])); uint32_t *n_comp = calloc(d->state.n_layers, sizeof(n_comp[0])); uint32_t *n_index_comp = calloc(d->state.n_layers, sizeof(n_index_comp[0])); if (!shards || !n_comp || !n_index_comp) { + free(owners); free(logits); free(shards); free(n_comp); @@ -5126,12 +5626,12 @@ int ds4_dist_session_save_payload( bool layout_set = false; for (uint32_t shard = 0; shard < shard_count; shard++) { - uint32_t layer_start = 0, layer_end = 0; - const ds4_dist_route_entry *entry = NULL; - dist_kv_route_shard(d, shard, &layer_start, &layer_end, &entry); + const ds4_dist_kv_route_owner *owner_desc = &owners[shard]; + const uint32_t layer_start = owner_desc->layer_start; + const uint32_t layer_end = owner_desc->layer_end; shards[shard].fp = dist_tmpfile_or_err("distributed KV shard", err, errlen); if (!shards[shard].fp) goto cleanup; - if (shard == 0) { + if (owner_desc->is_local) { if (ds4_session_save_layer_payload(owner, shards[shard].fp, layer_start, layer_end, err, errlen) != 0) @@ -5140,7 +5640,7 @@ int ds4_dist_session_save_payload( "distributed local KV shard", err, errlen) != 0) goto cleanup; } else { - if (dist_save_remote_shard_to_file(d, entry, tokens, token_hash, + if (dist_save_remote_shard_to_file(d, owner_desc->entry, tokens, token_hash, shards[shard].fp, &shards[shard].bytes, err, errlen) != 0) @@ -5202,6 +5702,7 @@ int ds4_dist_session_save_payload( cleanup: 
dist_kv_shards_close(shards, shard_count); + free(owners); free(shards); free(n_comp); free(n_index_comp); @@ -5221,16 +5722,28 @@ int ds4_dist_session_load_payload( return 1; } if (dist_session_ensure_route(d, err, errlen) != 0) return 1; - if (dist_kv_route_validate(d, err, errlen) != 0) return 1; + const uint32_t shard_count = dist_kv_route_owner_count(d); + ds4_dist_kv_route_owner *owners = calloc(shard_count, sizeof(owners[0])); + if (!owners) { + if (errlen) snprintf(err, errlen, "out of memory loading distributed KV route"); + return 1; + } + if (dist_kv_route_build_owners(d, owners, shard_count, err, errlen) != 0) { + free(owners); + return 1; + } uint64_t remaining = payload_bytes; uint32_t h[DS4_SESSION_PAYLOAD_U32_FIELDS]; for (uint32_t i = 0; i < DS4_SESSION_PAYLOAD_U32_FIELDS; i++) { - if (dist_payload_read_u32(fp, &h[i], &remaining, err, errlen) != 0) + if (dist_payload_read_u32(fp, &h[i], &remaining, err, errlen) != 0) { + free(owners); return 1; + } } if (h[0] != DS4_SESSION_PAYLOAD_MAGIC || h[1] != DS4_SESSION_PAYLOAD_VERSION) { + free(owners); if (errlen) snprintf(err, errlen, "unsupported DS4 KV payload version"); return 1; } @@ -5252,6 +5765,7 @@ int ds4_dist_session_load_payload( layout.token_count >= (uint32_t)ds4_session_ctx(owner) || layout.vocab != (uint32_t)ds4_engine_vocab_size(d->state.engine) || !dist_kv_raw_live_valid(&layout)) { + free(owners); if (errlen) snprintf(err, errlen, "DS4 KV payload does not match current distributed runtime"); return 1; } @@ -5262,6 +5776,7 @@ int ds4_dist_session_load_payload( uint32_t *n_comp = calloc(layout.n_layers, sizeof(n_comp[0])); uint32_t *n_index_comp = calloc(layout.n_layers, sizeof(n_index_comp[0])); if ((layout.token_count && !tokens) || !logits || !n_comp || !n_index_comp) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5272,6 +5787,7 @@ int ds4_dist_session_load_payload( for (uint32_t i = 0; i < layout.token_count; i++) { uint32_t tok = 0; if (dist_payload_read_u32(fp, &tok, 
&remaining, err, errlen) != 0) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5280,6 +5796,7 @@ int ds4_dist_session_load_payload( } if (tok > (uint32_t)INT_MAX || tok >= (uint32_t)ds4_engine_vocab_size(d->state.engine)) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5295,6 +5812,7 @@ int ds4_dist_session_load_payload( if (dist_payload_read_bytes(fp, logits, (uint64_t)layout.vocab * sizeof(logits[0]), &remaining, err, errlen) != 0) { + free(owners); free(tokens); free(logits); free(n_comp); @@ -5319,36 +5837,34 @@ int ds4_dist_session_load_payload( } } - const uint32_t shard_count = dist_kv_route_shard_count(d); for (uint32_t shard = 0; shard < shard_count; shard++) { - uint32_t layer_start = 0, layer_end = 0; - const ds4_dist_route_entry *entry = NULL; + const ds4_dist_kv_route_owner *owner_desc = &owners[shard]; FILE *tmp = NULL; uint64_t shard_bytes = 0; - dist_kv_route_shard(d, shard, &layer_start, &layer_end, &entry); if (dist_prepare_shard_from_session_payload(d, fp, &remaining, &layout, n_comp, n_index_comp, - layer_start, - layer_end, + owner_desc->layer_start, + owner_desc->layer_end, &tmp, &shard_bytes, err, errlen) != 0) goto cleanup; - if (shard == 0) { + if (owner_desc->is_local) { if (ds4_session_load_layer_payload(owner, tmp, shard_bytes, tokens_arg, layout.token_count, - layer_start, layer_end, + owner_desc->layer_start, + owner_desc->layer_end, err, errlen) != 0) { fclose(tmp); goto cleanup; } } else { - if (dist_load_remote_shard_from_payload(d, entry, + if (dist_load_remote_shard_from_payload(d, owner_desc->entry, tokens_arg, layout.token_count, token_hash, tmp, shard_bytes, @@ -5370,6 +5886,7 @@ int ds4_dist_session_load_payload( rc = 0; cleanup: + free(owners); free(tokens); free(logits); free(n_comp); @@ -5405,6 +5922,10 @@ int ds4_dist_session_create( return 1; } if (dist_validate_options(opt, err, errlen) != 0) return 1; + const uint32_t n_layers = (uint32_t)ds4_engine_layer_count(engine); + if 
(dist_validate_layers_for_model(opt, n_layers, err, errlen) != 0) return 1; + ds4_dist_topology topology = DS4_DIST_TOPOLOGY_FORWARD; + if (dist_infer_coordinator_topology(opt, n_layers, &topology, err, errlen) != 0) return 1; int listen_fd = dist_open_listener(opt->listen_host, opt->listen_port, err, errlen); if (listen_fd < 0) return 1; @@ -5419,10 +5940,11 @@ int ds4_dist_session_create( d->listen_fd = listen_fd; d->state.engine = engine; d->state.model_id = (uint32_t)ds4_engine_model_id(engine); - d->state.n_layers = (uint32_t)ds4_engine_layer_count(engine); + d->state.n_layers = n_layers; d->state.local_start = opt->layers.start; d->state.local_end = dist_resolved_layer_end(opt, d->state.n_layers); d->state.ctx_size = ctx_size > 0 ? (uint32_t)ctx_size : 0u; + d->state.topology = topology; d->state.local_has_output = opt->layers.has_output; d->state.local_can_output_head = ds4_engine_has_output_head(engine); d->state.replay_check = opt->replay_check; @@ -5443,11 +5965,12 @@ int ds4_dist_session_create( if (opt->layers.has_output) snprintf(local_end, sizeof(local_end), "output"); else snprintf(local_end, sizeof(local_end), "%u", opt->layers.end); DIST_COORD_DEBUG(&d->state, - "ds4: distributed coordinator API: listening on %s:%d model_id=%u layers=%u local=%u:%s activation_bits=%u\n", + "ds4: distributed coordinator API: listening on %s:%d model_id=%u layers=%u topology=%s local=%u:%s activation_bits=%u\n", opt->listen_host, opt->listen_port, d->state.model_id, d->state.n_layers, + dist_topology_name(d->state.topology), opt->layers.start, local_end, d->state.activation_bits); @@ -5694,6 +6217,12 @@ int ds4_dist_session_eval( static int dist_run_coordinator(ds4_engine *engine, const ds4_dist_options *opt, const ds4_dist_generation_options *gen) { char err[256]; + const uint32_t n_layers = (uint32_t)ds4_engine_layer_count(engine); + ds4_dist_topology topology = DS4_DIST_TOPOLOGY_FORWARD; + if (dist_infer_coordinator_topology(opt, n_layers, &topology, err, 
sizeof(err)) != 0) { + fprintf(stderr, "ds4: distributed coordinator: %s\n", err); + return 1; + } int listen_fd = dist_open_listener(opt->listen_host, opt->listen_port, err, sizeof(err)); if (listen_fd < 0) { fprintf(stderr, "ds4: distributed coordinator: %s\n", err); @@ -5704,10 +6233,11 @@ static int dist_run_coordinator(ds4_engine *engine, const ds4_dist_options *opt, memset(&state, 0, sizeof(state)); state.engine = engine; state.model_id = (uint32_t)ds4_engine_model_id(engine); - state.n_layers = (uint32_t)ds4_engine_layer_count(engine); + state.n_layers = n_layers; state.local_start = opt->layers.start; state.local_end = dist_resolved_layer_end(opt, state.n_layers); state.ctx_size = gen && gen->ctx_size > 0 ? (uint32_t)gen->ctx_size : 0u; + state.topology = topology; state.local_has_output = opt->layers.has_output; state.local_can_output_head = ds4_engine_has_output_head(engine); state.replay_check = opt->replay_check; @@ -5722,11 +6252,12 @@ static int dist_run_coordinator(ds4_engine *engine, const ds4_dist_options *opt, if (opt->layers.has_output) snprintf(local_end, sizeof(local_end), "output"); else snprintf(local_end, sizeof(local_end), "%u", opt->layers.end); DIST_COORD_DEBUG(&state, - "ds4: distributed coordinator: listening on %s:%d model_id=%u layers=%u local=%u:%s activation_bits=%u\n", + "ds4: distributed coordinator: listening on %s:%d model_id=%u layers=%u topology=%s local=%u:%s activation_bits=%u\n", opt->listen_host, opt->listen_port, state.model_id, state.n_layers, + dist_topology_name(state.topology), opt->layers.start, local_end, state.activation_bits); @@ -7329,6 +7860,7 @@ static int dist_worker_process_work_payload( ds4_dist_route_entry current_route; ds4_dist_route_entry next_route; + ds4_dist_route_entry final_route; const bool has_route = work.route_count != 0; const bool has_next = has_route && work.route_index + 1u < work.route_count; if (has_route) { @@ -7364,6 +7896,20 @@ static int dist_worker_process_work_payload( free(tokens); 
return dist_worker_upstream_send_work_error(upstream, request_id, err); } + if (!dist_route_get_entry(route_blob, work.route_bytes, work.route_count, + work.route_count - 1u, &final_route, err, sizeof(err))) { + free(route_blob); + free(tokens); + return dist_worker_upstream_send_work_error(upstream, request_id, err); + } + if (!input_hc_present && + (final_route.flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) != 0) { + free(route_blob); + free(tokens); + return dist_worker_upstream_send_work_error(upstream, + request_id, + "remote-first distributed route must return hidden-state upstream"); + } } if (has_next && output_logits) { free(route_blob); @@ -8381,13 +8927,61 @@ static int dist_validate_layers_for_model(const ds4_dist_options *opt, uint32_t if (errlen) snprintf(err, errlen, "layer range ends past final model layer %u", last); return 1; } - if (opt->role == DS4_DISTRIBUTED_COORDINATOR && opt->layers.start != 0) { - if (errlen) snprintf(err, errlen, "coordinator layer range must start at layer 0"); - return 1; + if (opt->role == DS4_DISTRIBUTED_COORDINATOR) { + return dist_infer_coordinator_topology(opt, n_layers, NULL, err, errlen); } return 0; } +int ds4_dist_test_validate_coordinator_layers( + const ds4_dist_options *opt, + uint32_t n_layers, + char *err, + size_t errlen) { + return dist_validate_layers_for_model(opt, n_layers, err, errlen); +} + +int ds4_dist_test_infer_coordinator_topology( + const ds4_dist_options *opt, + uint32_t n_layers, + int *topology_out, + char *err, + size_t errlen) { + ds4_dist_topology topology = DS4_DIST_TOPOLOGY_FORWARD; + int rc = dist_infer_coordinator_topology(opt, n_layers, &topology, err, errlen); + if (rc == 0 && topology_out) *topology_out = (int)topology; + return rc; +} + +bool ds4_dist_test_build_route_plan( + ds4_dist_coordinator_state *state, + ds4_dist_route_plan *plan, + char *err, + size_t errlen) { + if (!state || !plan) { + if (errlen) snprintf(err, errlen, "invalid distributed test route request"); + return false; 
+ } + return dist_coordinator_build_route_plan(state, plan, NULL, err, errlen); +} + +void ds4_dist_test_route_plan_free(ds4_dist_route_plan *plan) { + dist_route_plan_free(plan); +} + +uint32_t ds4_dist_test_kv_route_owner_count(const ds4_dist_session *d) { + return dist_kv_route_owner_count(d); +} + +int ds4_dist_test_kv_route_build_owners( + const ds4_dist_session *d, + ds4_dist_kv_route_owner *owners, + uint32_t owner_count, + char *err, + size_t errlen) { + return dist_kv_route_build_owners(d, owners, owner_count, err, errlen); +} + int ds4_dist_run(ds4_engine *engine, const ds4_dist_options *opt, const ds4_dist_generation_options *gen) { if (!engine || !opt) { fprintf(stderr, "ds4: distributed runtime requires an open engine and options\n"); diff --git a/ds4_help.c b/ds4_help.c index d32e088cf..c61393129 100644 --- a/ds4_help.c +++ b/ds4_help.c @@ -214,10 +214,10 @@ static void print_steering(FILE *fp, const help_colors *c) { static void print_distributed(FILE *fp, const help_colors *c) { title(fp, c, "Distributed Inference"); fputc('\n', fp); - para(fp, c, "Distributed mode runs one logical session across several machines by assigning contiguous model layer ranges to workers. Workers own their layer slice and KV-cache shard; the coordinator owns the prompt, sampling loop, and client/API flow. Start workers first, then start the coordinator. The coordinator waits for a complete route and streams hidden states through the workers."); + para(fp, c, "Distributed mode runs one logical session across several machines by assigning contiguous model layer ranges to workers. The coordinator may own either a leading prefix (0:K) or a final suffix through output (K:output). Workers own their layer slice and KV-cache shard; the coordinator owns the prompt, sampling loop, and client/API flow. Start workers first, then start the coordinator. 
Reverse K:42 is unsupported."); fputc('\n', fp); opt(fp, c, "--role ROLE", "Distributed role: coordinator or worker."); - opt(fp, c, "--layers A:B", "Inclusive layer slice, e.g. 0:20 or 21:output."); + opt(fp, c, "--layers A:B", "Inclusive layer slice, e.g. 0:20, 21:42, or 21:output."); opt(fp, c, "--listen HOST PORT", "Coordinator listen address; workers may use it for their data listener."); opt(fp, c, "--coordinator HOST PORT", "Coordinator address for --role worker."); opt(fp, c, "--dist-prefill-chunk N", "Coordinator prefill pipeline chunk size. Default: session cap."); diff --git a/tests/ds4_distributed_test_internal.h b/tests/ds4_distributed_test_internal.h new file mode 100644 index 000000000..fc69c4097 --- /dev/null +++ b/tests/ds4_distributed_test_internal.h @@ -0,0 +1,145 @@ +#ifndef DS4_DISTRIBUTED_TEST_INTERNAL_H +#define DS4_DISTRIBUTED_TEST_INTERNAL_H + +/* Test-only view of selected distributed internals. + * + * Why this exists: + * - The distributed topology/route/KV-owner tests need to exercise the real + * planner and validator logic from ds4_distributed.c. + * - Those code paths operate on private runtime structs such as + * ds4_dist_coordinator_state, ds4_dist_worker_entry, ds4_dist_route_plan, + * and ds4_dist_session. + * - We intentionally do not expose those internals, or richer test-adapter + * APIs, through ds4_distributed.h because that would turn test scaffolding + * into public production surface area. + * + * This header therefore provides a narrow test-only bridge: + * - enough struct definitions for tests/ds4_test.c to assemble synthetic + * coordinator/worker/session state + * - thin declarations for internal logic entry points that remain implemented + * in production code + * + * It should only be included by unit tests. Production code should include + * ds4_distributed.h instead. 
+ */ + +#include "../ds4_distributed.h" + +#include +#include +#include +#include +#include + +typedef struct ds4_dist_worker_entry { + int fd; + char peer_host[NI_MAXHOST]; + char peer_port[NI_MAXSERV]; + char model_name[128]; + uint32_t model_id; + uint32_t quant_bits; + uint32_t layer_start; + uint32_t layer_end; + uint32_t has_output; + uint32_t has_hidden; + uint32_t ctx_size; + uint32_t n_layers; + uint32_t listen_port; + struct ds4_dist_worker_entry *next; +} ds4_dist_worker_entry; + +typedef enum { + DS4_DIST_TOPOLOGY_FORWARD = 0, + DS4_DIST_TOPOLOGY_REVERSE = 1, +} ds4_dist_topology; + +typedef struct { + ds4_engine *engine; + uint32_t model_id; + uint32_t n_layers; + uint32_t local_start; + uint32_t local_end; + uint32_t ctx_size; + ds4_dist_topology topology; + bool local_has_output; + bool local_can_output_head; + bool replay_check; + bool debug; + bool use_control_for_work; + uint32_t prefill_chunk; + uint32_t prefill_window; + uint32_t activation_bits; + uint64_t generation; + pthread_mutex_t mu; + ds4_dist_worker_entry *workers; + bool shutting_down; +} ds4_dist_coordinator_state; + +typedef struct { + char host[NI_MAXHOST]; + uint32_t port; + uint32_t layer_start; + uint32_t layer_end; + uint32_t flags; + int fd; +} ds4_dist_route_entry; + +typedef struct { + ds4_dist_route_entry *entry; + uint32_t count; + void *blob; + uint32_t blob_bytes; +} ds4_dist_route_plan; + +typedef struct { + uint32_t layer_start; + uint32_t layer_end; + const ds4_dist_route_entry *entry; + bool is_local; +} ds4_dist_kv_route_owner; + +struct ds4_dist_session { + ds4_dist_coordinator_state state; + int listen_fd; + pthread_t accept_tid; + bool accept_started; + struct { + ds4_dist_coordinator_state *state; + int listen_fd; + } accept_ctx; + ds4_dist_route_plan plan; + bool plan_ready; + uint64_t plan_generation; + uint64_t session_id; + uint64_t request_id; + uint64_t snapshot_request_id; +}; + +#define DS4_DIST_ROUTE_F_OUTPUT_LOGITS 0x00000001u + +int 
ds4_dist_test_validate_coordinator_layers( + const ds4_dist_options *opt, + uint32_t n_layers, + char *err, + size_t errlen); +int ds4_dist_test_infer_coordinator_topology( + const ds4_dist_options *opt, + uint32_t n_layers, + int *topology_out, + char *err, + size_t errlen); +bool ds4_dist_test_build_route_plan( + ds4_dist_coordinator_state *state, + ds4_dist_route_plan *plan, + char *err, + size_t errlen); +void ds4_dist_test_route_plan_free(ds4_dist_route_plan *plan); +uint32_t ds4_dist_test_kv_route_owner_count(const ds4_dist_session *d); +int ds4_dist_test_kv_route_build_owners( + const ds4_dist_session *d, + ds4_dist_kv_route_owner *owners, + uint32_t owner_count, + char *err, + size_t errlen); + +#endif diff --git a/tests/ds4_test.c b/tests/ds4_test.c index ea1e52487..19f3eb78b 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -1,6 +1,7 @@ #define DS4_SERVER_TEST #define DS4_SERVER_TEST_NO_MAIN #include "../ds4_server.c" +#include "ds4_distributed_test_internal.h" #ifndef DS4_NO_GPU #include "../ds4_gpu.h" #include @@ -2176,6 +2177,200 @@ static void test_mtp_verify_depth(void) { } #endif +static ds4_dist_options test_dist_options( + ds4_distributed_role role, + uint32_t start, + uint32_t end, + bool has_output) { + ds4_dist_options opt; + memset(&opt, 0, sizeof(opt)); + opt.role = role; + opt.layers.start = start; + opt.layers.end = end; + opt.layers.has_output = has_output; + opt.layers.set = true; + return opt; +} + +static void test_dist_free_workers(ds4_dist_worker_entry *workers) { + while (workers) { + ds4_dist_worker_entry *next = workers->next; + free(workers); + workers = next; + } +} + +static ds4_dist_worker_entry *test_dist_worker( + const char *host, + uint32_t port, + uint32_t quant_bits, + uint32_t layer_start, + uint32_t layer_end, + bool has_output, + bool has_hidden, + uint32_t ctx_size) { + ds4_dist_worker_entry *w = calloc(1, sizeof(*w)); + TEST_ASSERT(w != NULL); + if (!w) return NULL; + w->fd = -1; + snprintf(w->peer_host, 
sizeof(w->peer_host), "%s", + host && host[0] ? host : "127.0.0.1"); + snprintf(w->peer_port, sizeof(w->peer_port), "%u", port); + w->model_id = 1; + w->quant_bits = quant_bits; + w->layer_start = layer_start; + w->layer_end = layer_end; + w->has_output = has_output; + w->has_hidden = has_hidden; + w->ctx_size = ctx_size; + w->n_layers = 43; + w->listen_port = port; + return w; +} +/* Unit test for distributed coordinator logic on a 43-layer model: layer-option validation and topology inference, route planning from a coordinator state plus registered workers, and KV route-owner construction. It drives the ds4_dist_test_* hooks and reports through TEST_ASSERT. The validate/infer/build_owners hooks are asserted against 0 for success, while build_route_plan is asserted as truthy for success. */ +static void test_distributed_topology_logic_group(void) { + char err[256]; + int topology = -1; +/* Coordinator options (0, 19, false): must validate and infer topology 0. NOTE(review): the numeric arguments look like layer_start, layer_end, has_output - confirm against test_dist_options. */ + ds4_dist_options forward = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 0, 19, false); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&forward, 43, err, sizeof(err)) == 0); + TEST_ASSERT(ds4_dist_test_infer_coordinator_topology(&forward, 43, &topology, err, sizeof(err)) == 0); + TEST_ASSERT(topology == 0); +/* Coordinator options (20, 0, true): must validate and infer topology 1. Presumably 0 and 1 are DS4_DIST_TOPOLOGY_FORWARD and DS4_DIST_TOPOLOGY_REVERSE - TODO confirm. */ + ds4_dist_options reverse = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 20, 0, true); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&reverse, 43, err, sizeof(err)) == 0); + TEST_ASSERT(ds4_dist_test_infer_coordinator_topology(&reverse, 43, &topology, err, sizeof(err)) == 0); + TEST_ASSERT(topology == 1); +/* Coordinator options (20, 42, false): validation must fail. Presumably this is the reverse K:42 layout that the README describes as unsupported - confirm. */ + ds4_dist_options reverse_partial = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 20, 42, false); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&reverse_partial, 43, err, sizeof(err)) != 0); +/* Coordinator options (5, 10, false): validation must fail as well. */ + ds4_dist_options middle = test_dist_options(DS4_DISTRIBUTED_COORDINATOR, 5, 10, false); + TEST_ASSERT(ds4_dist_test_validate_coordinator_layers(&middle, 43, err, sizeof(err)) != 0); +/* Forward route plan: coordinator state holds local layers 0..19 and one worker is created with layer arguments 20, 42. Expect a single route entry 20..42 with DS4_DIST_ROUTE_F_OUTPUT_LOGITS set. */ + ds4_dist_coordinator_state state; + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 0; + state.local_end = 19; + state.topology = DS4_DIST_TOPOLOGY_FORWARD; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20001, 2, 20, 42, true, true, 1024); + ds4_dist_route_plan route = {0}; + TEST_ASSERT(ds4_dist_test_build_route_plan(&state, 
&route, err, sizeof(err))); + TEST_ASSERT(route.count == 1); + TEST_ASSERT(route.entry[0].layer_start == 20); + TEST_ASSERT(route.entry[0].layer_end == 42); + TEST_ASSERT((route.entry[0].flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) != 0); + ds4_dist_test_route_plan_free(&route); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); +/* Reverse route plan: coordinator state holds local layers 20..42 with local_has_output set and one worker is created with layer arguments 0, 19. Expect a single route entry 0..19 with DS4_DIST_ROUTE_F_OUTPUT_LOGITS clear. */ + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 20; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20011, 2, 0, 19, false, true, 1024); + memset(&route, 0, sizeof(route)); + TEST_ASSERT(ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(route.count == 1); + TEST_ASSERT(route.entry[0].layer_start == 0); + TEST_ASSERT(route.entry[0].layer_end == 19); + TEST_ASSERT((route.entry[0].flags & DS4_DIST_ROUTE_F_OUTPUT_LOGITS) == 0); +/* KV route owners for the reverse plan above: expect two owners, a non-local one for 0..19 followed by a local one for 20..42. The route entry is then shortened to end at layer 18 and owner construction must fail; the entry is restored to 19 before the plan is freed. NOTE(review): session.plan is a struct copy of route, so the write through route.entry[0] only reaches session.plan if entry is shared storage - confirm. */ + ds4_dist_session session; + memset(&session, 0, sizeof(session)); + session.state.n_layers = 43; + session.state.local_start = 20; + session.state.local_end = 42; + session.plan = route; + ds4_dist_kv_route_owner owners[4]; + uint32_t owner_count = ds4_dist_test_kv_route_owner_count(&session); + TEST_ASSERT(owner_count == 2); + TEST_ASSERT(ds4_dist_test_kv_route_build_owners(&session, + owners, + owner_count, + err, + sizeof(err)) == 0); + TEST_ASSERT(owner_count == 2); + TEST_ASSERT(!owners[0].is_local); + TEST_ASSERT(owners[0].layer_start == 0); + TEST_ASSERT(owners[0].layer_end == 19); + TEST_ASSERT(owners[1].is_local); + TEST_ASSERT(owners[1].layer_start == 20); + TEST_ASSERT(owners[1].layer_end == 42); + route.entry[0].layer_end = 18; + TEST_ASSERT(ds4_dist_test_kv_route_build_owners(&session, + owners, + owner_count, + err, + sizeof(err)) != 0); + route.entry[0].layer_end = 19; + 
ds4_dist_test_route_plan_free(&route); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); +/* Reverse route whose only worker is created with layer arguments 1, 19: planning must fail and err must contain the text: missing layer 0. NOTE(review): err is not cleared before this call, unlike the two cases that follow, so it still holds whatever an earlier call wrote - confirm this is intended. */ + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 20; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20012, 2, 1, 19, false, true, 1024); + memset(&route, 0, sizeof(route)); + TEST_ASSERT(!ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(strstr(err, "missing layer 0") != NULL); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); +/* Worker created with layer arguments 0, 21 against a coordinator local range starting at 21: planning must fail, and err must name the coordinator overlap and the range 21:21. err is cleared first so only this call can satisfy the strstr checks. */ + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 21; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20013, 2, 0, 21, false, true, 1024); + memset(&route, 0, sizeof(route)); + err[0] = '\0'; + TEST_ASSERT(!ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(strstr(err, "overlap coordinator local range 21:output") != NULL); + TEST_ASSERT(strstr(err, "layers 21:21") != NULL); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); +/* Two workers linked through next, created with layer arguments 0, 10 and 10, 20: planning must fail, and err must name a worker overlap and the range 10:10. */ + memset(&state, 0, sizeof(state)); + state.n_layers = 43; + state.local_start = 21; + state.local_end = 42; + state.topology = DS4_DIST_TOPOLOGY_REVERSE; + state.local_has_output = true; + state.local_can_output_head = true; + state.use_control_for_work = false; + pthread_mutex_init(&state.mu, NULL); + state.workers = test_dist_worker("127.0.0.1", 20014, 2, 0, 10, false, true, 1024); + state.workers->next = test_dist_worker("127.0.0.1", 20015, 2, 10, 20, false, true, 1024); + memset(&route, 0, 
sizeof(route)); + err[0] = '\0'; + TEST_ASSERT(!ds4_dist_test_build_route_plan(&state, &route, err, sizeof(err))); + TEST_ASSERT(strstr(err, "overlap worker") != NULL); + TEST_ASSERT(strstr(err, "layers 10:10") != NULL); + test_dist_free_workers(state.workers); + pthread_mutex_destroy(&state.mu); +} +/* Test-table entry point that forwards to ds4_server_unit_tests_run(). */ static void test_server_unit_group(void) { ds4_server_unit_tests_run(); } @@ -2203,6 +2398,7 @@ static const ds4_test_entry test_entries[] = { {"--streaming-decode-prefill-correctness", "streaming-decode-prefill-correctness", "streaming decode-style cold prefill drift and repeatability", test_streaming_decode_prefill_correctness}, {"--mtp-verify-depth", "mtp-verify-depth", "MTP speculative verify commits autoregressive-identical tokens at draft depth > 2", test_mtp_verify_depth}, #endif + {"--distributed-topology-logic", "distributed-topology-logic", "distributed coordinator topology, route-planning, and KV-owner logic", test_distributed_topology_logic_group}, {"--server", "server", "server parser/rendering/cache unit tests", test_server_unit_group}, };