Libraries/MLXLMCommon/Evaluate.swift
```swift
processor?.didSample(token: token)
y = .init(tokens: token)
mainState = result.state
asyncEval(y.tokens)
```
I think the python code doesn't do this. It isn't incorrect, but it might collide with the prefill of the draft (below).
Good catch! We do batch processing for this token along with the drafted tokens in the verify step, so I think there's no reason to force evaluation here. Removed.
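The verify step referred to here can be illustrated with a small stand-alone sketch: the last sampled token and the drafted tokens are scored by the main model in one batch, and the accepted output is the longest prefix of draft tokens the main model agrees with, plus one correction (or bonus) token. The function below is an illustrative greedy-sampling version, not the actual `SpeculativeTokenIterator` code:

```swift
/// Given the main model's sampled token at each position of the verify
/// batch (draft count + 1 outputs) and the draft tokens, return the
/// tokens that are actually emitted. Greedy (temperature-0) variant.
func acceptTokens(mainTokens: [Int], draftTokens: [Int]) -> [Int] {
    precondition(mainTokens.count == draftTokens.count + 1)
    var accepted: [Int] = []
    for (i, draft) in draftTokens.enumerated() {
        if mainTokens[i] == draft {
            // draft agreed with the main model at this position
            accepted.append(draft)
        } else {
            // disagreement: emit the main model's correction and stop
            accepted.append(mainTokens[i])
            return accepted
        }
    }
    // all drafts accepted; the main model's extra token comes for free
    accepted.append(mainTokens.last!)
    return accepted
}
```

With greedy sampling this equality check is exact; acceptance under stochastic sampling is more involved, but the batch structure is the same.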
```swift
parameters: GenerateParameters,
context: ModelContext,
draftModel: any LanguageModel,
draftCache: [KVCache]? = nil,
```
Does this need the main KVCache? or is there no need to keep that around?
Oh, it's actually something I missed. Added parameters for the main cache 👍
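For context, a hypothetical shape of the widened entry point after this change: both the main-model cache and the draft cache can be supplied by the caller, with fresh caches created when `nil` is passed. All names here (`ToyCache`, `speculativeGenerate`) are illustrative assumptions, not the actual MLXLMCommon API:

```swift
/// Illustrative stand-in for a per-layer KV cache.
struct ToyCache { var offset = 0 }

/// Sketch: callers can reuse both caches across turns; nil means
/// "create a fresh cache for me".
func speculativeGenerate(
    prompt: [Int],
    cache: [ToyCache]? = nil,      // main model KV cache (the added parameter)
    draftCache: [ToyCache]? = nil  // draft model KV cache
) -> (mainLayers: Int, draftLayers: Int) {
    let main = cache ?? [ToyCache(), ToyCache()]
    let draft = draftCache ?? [ToyCache()]
    // ... generation would run here, reusing both caches ...
    return (main.count, draft.count)
}
```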
```swift
@discardableResult
public func trimPromptCache(_ cache: [KVCache], numTokens: Int) -> Int {
    guard canTrimPromptCache(cache), !cache.isEmpty else { return 0 }
    cache.dropFirst().forEach { $0.trim(numTokens) }
```
This is curious -- perhaps we were not using cache trimming in the past but this looks important.
Yeah, I think it went unnoticed because this function was never used internally, but since it's public it's definitely important.
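A minimal sketch of the fixed behavior, using a toy cache type in place of the real `KVCache` protocol: every layer is trimmed, and the returned count is taken from the first layer, mirroring mlx-lm's `trim_prompt_cache`. The `ToyKVCache` type and the exact return convention are assumptions for illustration:

```swift
/// Simplified stand-in for a trimmable KV cache layer.
final class ToyKVCache {
    private(set) var offset: Int
    let isTrimmable: Bool

    init(offset: Int, isTrimmable: Bool = true) {
        self.offset = offset
        self.isTrimmable = isTrimmable
    }

    /// Trim up to `n` tokens; returns how many were actually trimmed.
    @discardableResult
    func trim(_ n: Int) -> Int {
        guard isTrimmable else { return 0 }
        let trimmed = min(n, offset)
        offset -= trimmed
        return trimmed
    }
}

/// A cache is trimmable only if every layer supports trimming
/// (a Mamba-style state cache, for example, would not).
func canTrimPromptCache(_ cache: [ToyKVCache]) -> Bool {
    cache.allSatisfy { $0.isTrimmable }
}

/// Trim ALL layers (the bug was trimming only some of them), returning
/// the number of tokens trimmed from the first layer.
@discardableResult
func trimPromptCache(_ cache: [ToyKVCache], numTokens: Int) -> Int {
    guard canTrimPromptCache(cache), !cache.isEmpty else { return 0 }
    let trimmed = cache[0].trim(numTokens)
    cache.dropFirst().forEach { $0.trim(numTokens) }
    return trimmed
}
```

Trimming every layer is what lets speculative decoding rewind rejected draft tokens: after a partial acceptance, all layers must roll back by the same amount or the cache layers fall out of sync.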
davidkoski left a comment
This looks good. I filed #181 to track adding support in ChatSession.
* Add speculative decoding
Proposed changes
Add speculative decoding support, based on the approach in mlx-lm.
A new `SpeculativeTokenIterator` uses a smaller draft model to propose tokens that are verified in batch by the main model, yielding accepted tokens one at a time. Both `generate()` and `generateTokens()` get overloads accepting a draft model. This also extracts a `TokenIteratorProtocol` so the shared generation loop works with both iterator types.

Bug fix: `trimPromptCache` previously only trimmed the first cache layer. This PR fixes it to trim all layers, which is required for speculative decoding to correctly rewind rejected tokens.

Limitations
Speculative decoding requires trimmable KV caches to rewind rejected tokens. Models that use `MambaCache` are not supported and will throw an error. A possible solution would be to snapshot the cache state before drafting and restore it on rejection, but this adds memory overhead and complexity that felt out of scope for this PR.

Benchmarks
Benchmarked on an M3 Max using a short text translation prompt (~150 tokens) generating ~130 tokens, with 2 draft tokens for speculative generation. Using more than 2 draft tokens makes results worse in most cases: the draft model generates further ahead, divergence from the main model lowers the acceptance rate, and the drafting overhead ends up making generation slower.
The larger the gap between the main and draft model, the greater the benefit. The 32B model sees the biggest speedup (+79%) since its baseline generation is slow and there is more room for the draft model to help. Prompt processing speed stays largely unaffected, and the memory overhead from the 0.6B draft model is minimal in all cases (under 400MB).
Checklist
Put an `x` in the boxes that apply.

- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes