
Add speculative decoding #173

Merged

davidkoski merged 4 commits into ml-explore:main from petrukha-ivan:speculative-decoding on Apr 3, 2026

Conversation

@petrukha-ivan
Contributor

Proposed changes

Add speculative decoding support, based on the approach in mlx-lm.

A new SpeculativeTokenIterator uses a smaller draft model to propose tokens that are verified in batch by the main model, yielding accepted tokens one at a time. Both generate() and generateTokens() get overloads accepting a draft model. This also extracts a TokenIteratorProtocol so the shared generation loop works with both iterator types.
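The draft-then-verify loop at the core of speculative decoding can be sketched in plain Swift. This is a toy with deterministic greedy "models" over `Int` tokens; the name `speculativeStep` and the `Model` closure type are illustrative, not the PR's actual API:

```swift
// A "model" here is just a deterministic next-token function over Int tokens.
typealias Model = ([Int]) -> Int

/// One speculative step: the draft model proposes `k` tokens, the main model
/// verifies them, and we accept the longest matching prefix plus one token
/// from the main model, so every step yields at least one token.
func speculativeStep(prompt: [Int], draft: Model, main: Model, k: Int) -> [Int] {
    // 1. Draft k tokens autoregressively with the cheap model.
    var proposed: [Int] = []
    var context = prompt
    for _ in 0..<k {
        let t = draft(context)
        proposed.append(t)
        context.append(t)
    }
    // 2. Verify: here we call the main model once per position; the real
    //    implementation batches all positions into a single forward pass.
    var accepted: [Int] = []
    var verifyContext = prompt
    for t in proposed {
        let mainChoice = main(verifyContext)
        if mainChoice == t {
            accepted.append(t)          // draft token confirmed
            verifyContext.append(t)
        } else {
            accepted.append(mainChoice) // reject; keep the main model's token
            return accepted             // remaining drafts are discarded
        }
    }
    // All k drafts accepted: one bonus token from the main model for free.
    accepted.append(main(verifyContext))
    return accepted
}
```

The key property this sketch shows is why the technique is lossless under greedy sampling: every emitted token is one the main model would have produced, and the draft model only determines how many of them come out of a single verification pass.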

Bug fix: trimPromptCache previously only trimmed the first cache layer. This PR fixes it to trim all layers, which is required for speculative decoding to correctly rewind rejected tokens.
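The shape of that fix can be illustrated with a toy cache type (`LayerCache` here is hypothetical; the real `KVCache` protocol lives in MLXLMCommon):

```swift
// Toy stand-in for one layer's KV cache (hypothetical type, for illustration).
final class LayerCache {
    var offset: Int
    init(offset: Int) { self.offset = offset }
    func trim(_ n: Int) { offset = max(0, offset - n) }
}

// Before the fix, only the first layer was rewound, so after rejecting draft
// tokens the remaining layers stayed ahead of the main model's position.
func trimAllLayers(_ cache: [LayerCache], numTokens: Int) {
    cache.forEach { $0.trim(numTokens) }  // rewind every layer by the same amount
}
```

Without the fix, a single rejected draft token would permanently desynchronize layers 1..N from layer 0, corrupting all subsequent generation.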

Limitations

Speculative decoding requires trimmable KV caches to rewind rejected tokens. Models that use MambaCache are not supported and will throw an error. A possible solution would be to snapshot the cache state before drafting and restore it on rejection, but this adds memory overhead and complexity that felt out of scope for this PR.

Benchmarks

Benchmarked on an M3 Max using a short text-translation prompt (~150 tokens), generating ~130 tokens, with 2 draft tokens for speculative generation. Using more than 2 draft tokens made results worse in most cases: the further ahead the draft model generates, the more it diverges from the main model, which lowers the acceptance rate until the drafting overhead makes generation slower overall.

| Main model | Draft model | Prompt tokens/s | Generation tokens/s | Memory peak |
| --- | --- | --- | --- | --- |
| Qwen3-4B-4bit | - | 788 | 99 | 2564M |
| Qwen3-4B-4bit | Qwen3-0.6B-4bit | 708 | 128 (+29%) | 2915M |
| Qwen3-14B-4bit | - | 258 | 32 | 8237M |
| Qwen3-14B-4bit | Qwen3-0.6B-4bit | 250 | 53 (+66%) | 8351M |
| Qwen3-32B-4bit | - | 71 | 14 | 17916M |
| Qwen3-32B-4bit | Qwen3-0.6B-4bit | 73 | 25 (+79%) | 18267M |

The larger the gap between the main and draft model, the greater the benefit. The 32B model sees the biggest speedup (+79%) since its baseline generation is slow and there is more room for the draft model to help. Prompt processing speed stays largely unaffected, and the memory overhead from the 0.6B draft model is minimal in all cases (under 400MB).

Note: The benchmarks above use a translation task, which is relatively deterministic and tends to produce high acceptance rates. Results will vary depending on the draft model selection and the prompt. More creative tasks may see lower acceptance rates, and in some cases speculative decoding can even slow down generation.
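These observations are consistent with a simple back-of-envelope model (my own rough sketch, not something from the PR): with `k` draft tokens and a per-token acceptance probability `p` (treated as independent), each verification step yields `(1 - p^(k+1)) / (1 - p)` tokens in expectation, the geometric sum counting the main model's correction/bonus token, at the cost of one main forward plus `k` draft forwards:

```swift
import Foundation

// Rough speedup estimate for speculative decoding (illustrative, not measured).
// - k: number of draft tokens per step
// - p: probability each draft token is accepted (assumed independent)
// - r: cost of one draft forward relative to one main forward
func estimatedSpeedup(k: Int, acceptProb p: Double, draftCostRatio r: Double) -> Double {
    // Expected tokens per step: 1 correction/bonus token + accepted drafts.
    let expectedTokens = (1 - pow(p, Double(k + 1))) / (1 - p)
    // Plain decoding pays one main forward per token; speculative decoding
    // pays one main forward plus k draft forwards per step.
    return expectedTokens / (1 + Double(k) * r)
}
```

For example, `k = 2` with `p ≈ 0.8` and a draft model roughly 10x cheaper estimates about a 2x speedup, the same ballpark as the larger rows above. The formula also shows why increasing `k` backfires once `p` drops: the draft-forward cost grows linearly in `k` while the extra accepted tokens shrink geometrically.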

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

```swift
processor?.didSample(token: token)
y = .init(tokens: token)
mainState = result.state
asyncEval(y.tokens)
```
Collaborator

I think the Python code doesn't do this. It isn't incorrect, but it might collide with the prefill of the draft (below).

Contributor Author

Good catch! We do batch processing for this token along with the drafted tokens in the verify step, so I think there's no reason to force evaluation here. Removed.

```swift
parameters: GenerateParameters,
context: ModelContext,
draftModel: any LanguageModel,
draftCache: [KVCache]? = nil,
```
Collaborator

Does this need the main KVCache? or is there no need to keep that around?

Contributor Author

Oh, it's actually something I missed. Added parameters for the main cache 👍

```swift
@discardableResult
public func trimPromptCache(_ cache: [KVCache], numTokens: Int) -> Int {
    guard canTrimPromptCache(cache), !cache.isEmpty else { return 0 }
    cache.dropFirst().forEach { $0.trim(numTokens) }
```
Collaborator

This is curious -- perhaps we were not using cache trimming in the past but this looks important.

Contributor Author

Yeah, I think it went unnoticed because this function was never used internally, but since it's public it's definitely important.

Collaborator

@davidkoski left a comment

This looks good. I filed #181 to track adding support in ChatSession.

@davidkoski davidkoski merged commit 8c9dd63 into ml-explore:main Apr 3, 2026
2 checks passed
jjang-ai pushed a commit to osaurus-ai/mlx-swift-lm that referenced this pull request Apr 3, 2026