diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 33de2e4d9ca..a2e0b326aa7 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -2090,8 +2090,20 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok if (it_best != states.end()) { SRV_INF(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + // Restore the cached state into the live context WITHOUT consuming the + // cache entry. The previous behaviour moved the entry out of `states` + // and freed its serialized bytes on every match. That is wrong for an + // alternating multi-conversation workload: when the incoming prompt + // only partially matches the entry (a different conversation that + // shares a long prefix), update_slots() may still decide it cannot + // reuse the restored KV and reset to a full re-prefill — but by then + // the entry has already been erased, so when the original conversation + // comes back its state is gone and it pays a full re-prefill too. + // Keeping the entry lets both conversations hit the cache on their + // respective turns. Entries are still bounded by update()'s size/token + // eviction; this only removes the on-match deletion. { - auto & data = it_best->data.main; + const auto & data = it_best->data.main; const size_t size = data.size(); const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0); @@ -2100,13 +2112,10 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok return false; } - - data.clear(); - data.shrink_to_fit(); } { - auto & data = it_best->data.drft; + const auto & data = it_best->data.drft; if (!data.empty()) { GGML_ASSERT(ctx_dft); @@ -2118,15 +2127,17 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok return false; } - - data.clear(); - data.shrink_to_fit(); } } - prompt = std::move(*it_best); - - states.erase(it_best); + // Copy tokens + checkpoints into the slot's prompt, leaving the cached + // entry intact. Do NOT copy data.main / data.drft into prompt.data: the + // KV bytes are now live in ctx_tgt/ctx_dft, and prompt_save() asserts + // prompt.data.size() == 0 before it serializes a fresh save from the + // live context. server_tokens has a deleted copy ctor, hence clone(). + prompt.tokens = it_best->tokens.clone(); + prompt.data = {}; + prompt.checkpoints = it_best->checkpoints; } return true; diff --git a/tools/server/tests/unit/test_prompt_cache_reuse.py b/tools/server/tests/unit/test_prompt_cache_reuse.py new file mode 100644 index 00000000000..d3890fc628c --- /dev/null +++ b/tools/server/tests/unit/test_prompt_cache_reuse.py @@ -0,0 +1,131 @@ +import os +import tempfile +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +class LogReader: + def __init__(self, path): + self.path = path + self.pos = 0 + + def drain(self): + with open(self.path, errors="ignore") as f: + f.seek(self.pos) + content = f.read() + self.pos = f.tell() + return content + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + # single slot: a conversation that is not currently loaded lives ONLY + # in the prompt cache, which is what exposes the consume-on-load bug. + server.n_slots = 1 + server.n_predict = 4 + server.temperature = 0.0 + server.server_slots = True + server.cache_ram = 100 + # force every slot reuse through the prompt cache (save + load) instead + # of the in-place LCP-similarity reuse, so load() is actually exercised. + server.slot_prompt_similarity = 0 + # isolate the cache save/load to get_available_slot (the idle-slot + # clearing path would add another save and muddy the scenario). + server.no_cache_idle_slots = True + server.debug = True + fd, server.log_path = tempfile.mkstemp(suffix='.log') + os.close(fd) + yield + + +# A and C share a long common prefix (so a request for C matches A's cached +# entry, f_keep >= 0.25), then diverge. B is unrelated to both. +COMMON_AC = ( + "Once upon a time in a quiet village by the sea there lived an old " + "fisherman who every morning rowed his small wooden boat out past the " + "harbour wall to cast his nets beneath the pale light of the rising sun." +) +CONV_A = COMMON_AC + ( + " On this particular day he caught a silver fish that spoke to him and " + "promised three wishes in exchange for its freedom and a safe return home." +) +CONV_C = COMMON_AC + ( + " But the storm clouds gathered quickly that afternoon and the waves grew " + "tall and angry as the wind tore the sails and scattered the frightened gulls." +) +CONV_B = ( + "In a bustling city far inland a young clockmaker tinkered late into the " + "night with brass gears and tiny springs trying to build a machine that " + "could measure not the hours but the quiet weight of a person's memories." +) +# A continuation of A (strict superset). Used for A's return so that +# n_past < task tokens and we avoid the identical-prompt path that, on +# SWA / hybrid / recurrent models, cannot partially remove the final +# token and would reset regardless of caching. +CONV_A_CONT = CONV_A + ( + " The fisherman closed his eyes and made his first wish very carefully." +) + + +def _total_prompt_tokens(res): + t = res.body["timings"] + return t["prompt_n"] + t["cache_n"] + + +# A prompt-cache entry must survive being matched by a DIFFERENT conversation. +# Regression test for load() consuming (erasing) the matched entry: with one +# slot and three conversations, conversation A lives only in the cache while +# conversation C — which shares A's prefix — is loaded. C's load must not +# destroy A's entry, otherwise A pays a full re-prefill when it returns. +def test_cache_entry_survives_cross_conversation_load(): + global server + server.start() + log = LogReader(server.log_path) + + # 1) Conversation A, cold. Capture its full token length. + res_a1 = server.make_request("POST", "/completion", data={ + "prompt": CONV_A, + "cache_prompt": True, + }) + assert res_a1.status_code == 200 + assert res_a1.body["timings"]["cache_n"] == 0 # nothing cached yet + n_tokens_a = _total_prompt_tokens(res_a1) + + # 2) Conversation B (unrelated). Selecting the slot saves A into the + # cache; B does not match A, so A is parked in the cache untouched. + res_b = server.make_request("POST", "/completion", data={ + "prompt": CONV_B, + "cache_prompt": True, + }) + assert res_b.status_code == 200 + assert "updating prompt cache" in log.drain() + + # 3) Conversation C, which shares A's long prefix. Selecting the slot + # saves B; loading C matches A's cached entry (f_keep >= 0.25). The + # buggy behaviour erased A here; the fix keeps it. + res_c = server.make_request("POST", "/completion", data={ + "prompt": CONV_C, + "cache_prompt": True, + }) + assert res_c.status_code == 200 + assert res_c.body["timings"]["cache_n"] > 0 # C reused A's shared prefix + + # 4) Conversation A returns (as a strict superset, so n_past < task + # tokens). It was only in the cache. With the fix its entry survived + # step 3, so all of A is reused and only the new continuation is + # processed. Without the fix A's entry was consumed in step 3 and + # only the prefix A shares with C can be reused. + res_a2 = server.make_request("POST", "/completion", data={ + "prompt": CONV_A_CONT, + "cache_prompt": True, + }) + assert res_a2.status_code == 200 + cache_n_a2 = res_a2.body["timings"]["cache_n"] + + # The full original A prompt is reused from cache. Under the bug this + # would only be the A/C shared prefix, which is well below n_tokens_a. + assert cache_n_a2 >= n_tokens_a - 2 diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index c5dba1c139f..2c4a2416a43 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -106,6 +106,7 @@ class ServerProcess: sleep_idle_seconds: int | None = None cache_ram: int | None = None no_cache_idle_slots: bool = False + slot_prompt_similarity: float | None = None log_path: str | None = None webui_mcp_proxy: bool = False backend_sampling: bool = False @@ -251,6 +252,8 @@ def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--cache-ram", self.cache_ram]) if self.no_cache_idle_slots: server_args.append("--no-cache-idle-slots") + if self.slot_prompt_similarity is not None: + server_args.extend(["--slot-prompt-similarity", self.slot_prompt_similarity]) if self.webui_mcp_proxy: server_args.append("--webui-mcp-proxy") if self.backend_sampling: