Fix hybrid cache checkpoints for short conversations#929
Draft
omui-ai wants to merge 6 commits intoml-explore:mainfrom
Draft
Fix hybrid cache checkpoints for short conversations#929omui-ai wants to merge 6 commits intoml-explore:mainfrom
omui-ai wants to merge 6 commits intoml-explore:mainfrom
Conversation
Modify server tokenization logic to distinguish between the prompt used for generation and the key used for cache lookup. - Add `_tokenize_for_cache_key` which applies chat templates without the generation suffix. This ensures KV cache lookups match on message content only, fixing prompt chaining issues where the suffix incorrectly altered the cache key. - Update batch and non-batch generation flows to fetch cache using the clean key, then append the generation suffix to the remaining tokens. - Fix `ArraysCache.nbytes` to check for `None` entries before summing bytes, preventing potential errors during size calculation.
Previously, extracting a cache entry removed it from the LRU, preventing multiple requests from reusing the same cached prefix. Update `LRUPromptCache._extract` to accept a `keep_original` flag. When enabled for shorter prefix matches, the method returns a deep copy of the cache without deleting the original entry. This ensures the cached prompt remains available for subsequent requests, supporting hybrid model prompt chaining. Add `test_cache_persistence` to verify that cached prefixes persist and are reused across multiple requests.
Update LRUPromptCache to store cache entries at message boundaries (e.g., after system or user messages) in addition to the full prompt sequence. This allows the cache to be shared when conversations branch, improving efficiency for multi-turn dialogs. - Modify `insert_cache` to accept optional `boundary_positions` list. - Add `_insert_boundary_cache` helper to store references to shared cache objects at specific token indices. - Add `_find_cache_boundaries` in `ResponseGenerator` to detect message delimiters like `<|im_end|>` across different tokenizers.
Modify the LRU cache strategy to return references instead of copies, reducing memory overhead for hybrid models and prompt chaining. - Remove `deepcopy` in `_extract` to allow cache objects to be mutated in place. - Update `fetch_nearest_cache` to return the matched token position, enabling cache migration. - Extend `insert_cache` with `old_position` to move cache entries rather than duplicating them. - Dynamically update `nbytes` when overwriting existing cache entries. - Add debug logging for cache operations.
Adds support for creating periodic snapshots of the prompt cache to facilitate branching conversation histories. - Introduced `is_snapshot` attribute to `CacheEntry` to distinguish mutable cache entries from immutable snapshots. - Added `checkpoint_interval` (default 8192) to `__init__` to specify snapshot frequency. - Implemented `_find_checkpoint_positions` to place snapshots at logical message boundaries near the interval. - Modified lookup logic to extract copies from snapshots (preventing shared state mutation) while preserving in-place updates for linear extensions.
Three issues found while testing on Qwen3.5-35B-A3B (hybrid DeltaNet + attention) served on M1 Max 64GB via multi-turn chat: 1. _find_checkpoint_positions: the checkpoint_interval loop (range(8192, tokens_len, 8192)) produces no checkpoints when tokens_len < 8192 — which is most early conversation turns. Without a checkpoint, the next turn can't reuse the cache for hybrid models (can_trim_prompt_cache returns False, so the "longer" match path is skipped). Fix: always checkpoint at the last message boundary regardless of length. 2. _compute_message_boundaries: range(1, len(messages)) is empty for single-message conversations, so the first turn never produces a boundary. Fix: range(1, len(messages) + 1). 3. _create_checkpoint_snapshot: can_trim_prompt_cache() is all-or-nothing — returns False if any layer is non-trimmable. For hybrid models this means KVCache layers don't get trimmed in snapshots either, leaving stale KV entries from the generation prompt. Fix: fall back to per-layer trim so KVCache layers get cleaned while ArraysCache layers stay as-is. Tested empirically: before these fixes, multi-turn cache hit rate was 0% across all turns. After: 0% → 37% → 57% scaling upward as the conversation grows. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Context
Builds on #923 by @dexloom — found three issues while testing that PR's checkpoint approach on real hardware (Qwen3.5-35B-A3B, M1 Max 64GB, multi-turn chat).
What's wrong
With #923 as-is, multi-turn cache hit rate is 0% across all turns for conversations under 8k tokens. Three issues combine to cause this:
1. No checkpoints for short conversations
_find_checkpoint_positionsiteratesrange(checkpoint_interval, tokens_len, checkpoint_interval). Whentokens_len < 8192(i.e. most early conversation turns), the range is empty — no checkpoints are created. Without a checkpoint, the next turn can't reuse the cache for hybrid models becausecan_trim_prompt_cachereturnsFalseand the "longer" match path infetch_nearest_cacheis skipped.Fix: Always checkpoint at the last message boundary, then add periodic checkpoints for long contexts.
2. No boundaries for single-message conversations
_compute_message_boundariesusesrange(1, len(messages))which is empty when there's only one message. The first turn of every conversation never gets a boundary, so no checkpoint can be created even if issue 1 were fixed.Fix:
range(1, len(messages) + 1).3. All-or-nothing trim in snapshots
_create_checkpoint_snapshotusescan_trim_prompt_cache(snapshot)which returnsFalseif any layer is non-trimmable. For hybrid models (KVCache + ArraysCache), this means KVCache layers don't get trimmed in snapshots either, leaving stale KV entries from the generation prompt.Fix: Fall back to per-layer trim — KVCache layers get cleaned while ArraysCache layers stay as-is.
Test results
Qwen3.5-35B-A3B (30 DeltaNet + 10 attention layers), 8-bit MLX, served via OpenAI-compatible API:
Savings scale upward as the conversation grows — only new messages need processing each turn.
🤖 Generated with Claude Code