
fix: prevent KV cache corruption on SWA/ISWA models + hot-path perf#2180

Open
avion23 wants to merge 2 commits into abetlen:main from avion23:fix/perf-and-iswa

Conversation

@avion23 avion23 commented Apr 12, 2026

Problem

Gemma-4 and any model with interleaved sliding window attention (ISWA) crashes on the second call to create_chat_completion:

```python
llm = Llama(model_path="gemma4-q4.gguf", n_gpu_layers=-1)

llm.create_chat_completion(messages=[{"role": "user", "content": "What is 2+2?"}])
# → OK

llm.create_chat_completion(messages=[{"role": "user", "content": "Write a hello world"}])
# → RuntimeError: error during generation: [end of text]
```

Root cause

ISWA KV caches store position tracking in global maps (g_iswa_pos_max / g_iswa_pos_min) that are cleared by llama_memory_clear() but not by llama_memory_seq_rm(). The generate() method detects a prefix match between consecutive prompts (shared BOS token), calls kv_cache_seq_rm() to remove the divergent tail, sees it return True, and skips the full reset. Stale position maps cause batch allocator inconsistency → llama_decode returns -1.

Additionally, reset() was a no-op on KV cache state (only reset n_tokens), so calling llm.reset() between prompts had no effect.

A separate tokens[:-1] off-by-one in prefix matching silently broke prefix detection for all models.
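For illustration, the intended prefix comparison looks roughly like this (hypothetical helper name, not the actual llama-cpp-python source). Slicing one side with `[:-1]` excludes the final token from the comparison, so the computed common prefix is wrong and prefix reuse silently misbehaves:

```python
def longest_common_prefix(cached_tokens, prompt_tokens):
    """Count how many leading tokens two token sequences share.

    The buggy variant compared against prompt_tokens[:-1], dropping the
    last prompt token from the comparison; the fix compares full lists.
    """
    n = 0
    for a, b in zip(cached_tokens, prompt_tokens):  # full sequences, no [:-1]
        if a != b:
            break
        n += 1
    return n
```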

Fix

  1. reset() calls llama_memory_clear() unconditionally — proper KV cache state reset
  2. generate() sets longest_prefix = 0 for recurrent/SWA models when the prefix doesn't cover all cached tokens, falling through to full reset
  3. Fixed the tokens[:-1] → tokens off-by-one in prefix matching
  4. eval() simplified: removed dead code paths, direct 2D logits assignment

Performance

| Location | Before | After |
| --- | --- | --- |
| `set_batch()` | Python loop: 512 tokens × 5 assignments | 5 numpy bulk writes |
| `_create_completion` detokenization | O(n²) per-token `detokenize(completion_tokens)` | O(1) incremental `token_to_piece()` |
| `_create_completion` logit_bias | `np.copy(scores)` every token | in-place modification |
| `_create_completion` top-k logprobs | `sorted()` on 128K vocab: O(V log V) | `np.argpartition`: O(V) |
| `eval()` logits | flattened copy + reshape | direct 2D view assignment |
| `token_to_piece()` | `create_string_buffer(32)` with trailing garbage | returns exact byte length |
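The `set_batch()` change is representative of the vectorization pattern. A minimal sketch (array names are hypothetical, not the actual llama.cpp batch struct fields): each per-token Python loop iteration becomes one bulk numpy write into preallocated buffers:

```python
import numpy as np

def set_batch(batch_tokens, batch_pos, batch_seq, tokens, n_past, seq_id=0):
    """Fill preallocated batch arrays with one bulk write per field,
    replacing a per-token Python loop with ~5 assignments per token."""
    n = len(tokens)
    batch_tokens[:n] = tokens                      # token ids, one copy
    batch_pos[:n] = np.arange(n_past, n_past + n)  # positions, one write
    batch_seq[:n] = seq_id                         # broadcast assignment
    return n
```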

Benchmarks (TinyLlama-1.1B, M1 Pro, 10-50 runs)

| Test | Result |
| --- | --- |
| Generation throughput | ~930 tokens/sec |
| `set_batch` (200-token prompt) | ~61 ms |
| `logit_bias` processing | ~54 ms |
| top-k `argpartition` vs `sorted` (128K vocab) | 32× faster (888 ms vs 28.3 s for 1000 runs) |
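The argpartition speedup comes from selecting the k largest logprobs in O(V) and then sorting only those k, rather than sorting the full vocabulary. A self-contained sketch of the technique (not the PR's exact code):

```python
import numpy as np

def top_k_indices(scores, k):
    """Indices of the k largest scores, ordered descending by score.

    np.argpartition finds an unordered top-k in O(V); only the k winners
    are then fully sorted, vs O(V log V) for sorting the whole vocab.
    """
    part = np.argpartition(-scores, k)[:k]   # unordered top-k, linear time
    return part[np.argsort(-scores[part])]   # order just those k entries
```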

Testing

  • ISWA: Gemma-4-E2B-it (Q4_K_M, Q8_0) — multi-turn 6 rounds ✅
  • Ministral-3: 3B — multi-turn 3+ rounds ✅; 8B, 14B — multi-turn 3+ rounds ✅ (requires n_ctx=2048)
  • Other: Mistral-7B, Llama-3-8B, Phi-4, Gemma-2-9B, Qwen3-VL-8B, Mistral-Small-24B, Nomos-1, TinyLlama-1.1B ✅ (13 models total)

Functional verification: ISWA crash fix, reset(), prefix matching, logit_bias, set_batch, eval(), token_to_piece, argpartition correctness — all pass.

@avion23 avion23 force-pushed the fix/perf-and-iswa branch 2 times, most recently from 939fa72 to 9609c82 Compare April 12, 2026 15:58
@avion23 avion23 changed the title perf: vectorize hot-path operations + fix SWA/ISWA KV cache corruption (Gemma-4) fix: prevent KV cache corruption on SWA/ISWA models + hot-path perf Apr 12, 2026
@avion23 avion23 marked this pull request as ready for review April 12, 2026 16:04
Ralf Waldukat added 2 commits April 13, 2026 12:53
… cache corruption

- set_batch(): numpy bulk writes replace per-token Python loop
- _create_completion: incremental token_to_piece() accumulation
  replaces O(n²) re-detokenization per generated token
- _create_completion: in-place logit_bias instead of full vocab copy
- _create_completion: np.argpartition for top-k logprobs (O(V) vs O(V log V))
- reset(): call llama_memory_clear() for proper KV cache state reset
- generate(): bypass prefix-match for recurrent/SWA models
- generate(): fix tokens[:-1] off-by-one in prefix matching
- eval(): remove unconditional kv_cache_seq_rm, simplify logits assignment
- token_to_piece(): return correct byte length via actual write count
@avion23 avion23 force-pushed the fix/perf-and-iswa branch from c9bbd6d to 3538232 Compare April 13, 2026 05:53
