Skip to content

Native MTP speculative decoding (Qwen3.5/3.6 reference implementation)#990

Open
AirRunner wants to merge 29 commits into
ml-explore:mainfrom
AirRunner:feat/mtp-native
Open

Native MTP speculative decoding (Qwen3.5/3.6 reference implementation)#990
AirRunner wants to merge 29 commits into
ml-explore:mainfrom
AirRunner:feat/mtp-native

Conversation

@AirRunner

@AirRunner AirRunner commented Mar 13, 2026

Copy link
Copy Markdown

Summary

Qwen3.5 checkpoints ship with a built-in Multi-Token Prediction head (mtp_num_hidden_layers: 1 in config) that predicts token t+2 from the backbone hidden state at t and the embedding of token t+1. This PR adds support for using it as a native speculative decoding mechanism. No separate draft model needed, at minimal extra compute (1 extra transformer layer).

Changes

  • mlx_lm/generate.py: MTP generation loop with draft/verify and probabilistic acceptance, --mtp CLI flag
  • mlx_lm/models/cache.py: rollback_state slot for conv/SSM snapshot on draft rejection
  • mlx_lm/sample_utils.py: p_draw parameter added to apply_xtc to share the XTC draw across draft and verify
  • mlx_lm/models/qwen3_5.py: MTP head module, self.norm moved to TextModel to expose pre-norm hidden states for MTP, n_confirmed parameter for SSM rollback, sanitize: norm +1 shift now triggered only on raw HF checkpoints (unsanitized conv1d), not on presence of MTP weights, and MoE gate weights at 8-bit group=64
  • mlx_lm/models/qwen3_5_moe.py: MTP checkpoint sanitization for MoE variants + handling both Qwen3.5 and Qwen3.6 (fused gate_up_proj)
  • mlx_lm/server.py: --mtp flag, dynamic MTP/batch switching + fix xtc_special_tokens construction
  • tests/test_mtp.py: 10 unit tests

How it works

Each backbone forward pass returns both logits and pre-norm hidden states. The MTP head fuses pre_fc_norm_hidden(h_t) and pre_fc_norm_embedding(embed(t+1)) via a linear projection, runs one full-attention transformer layer, and produces draft logits through the shared lm_head.

The generation loop verifies drafts by feeding [confirmed_tok, draft_tok] to the backbone with n_confirmed=1. This causes GatedDeltaNet to snapshot its conv/SSM state after the confirmed token. On acceptance, both tokens are emitted. On rejection, the SSM state is rolled back to the snapshot and KV caches are trimmed.

Results (Qwen3.6-27B 4-bit, M4 Pro)

Pooled tok/s, 3 runs × 8 prompts, conditions interleaved. See bench_mtp.py.

Acceptance is reported as A/V (drafts accepted / drafts proposed), the standard speculative decoding metric. The benchmark script also reports A/(V+A) (~46% for Qwen3.6-27B at temp=0.6), which equals A/V ÷ (1 + A/V) and is used in some implementations.

Condition Tok/s Speedup Accept (A/V)
Baseline 15.7 1.00x
MTP temp=0 24.6 1.57x 88.3%
MTP temp=0.6 24.0 1.53x 84.8%
MTP temp=1.0 23.5 1.50x 80.5%

Identity check: greedy MTP output == standard generate_step output.

Usage

mlx_lm.generate --model <path> --mtp
mlx_lm.server   --model <path> --mtp

Checkpoint conversion

This requires a checkpoint converted with MTP weights (the default sanitize() previously stripped them). Re-convert from HF with this branch to preserve mtp.* weights.

Note on M1/M2: M1 and M2 lack native BF16 GPU support (MTLDataType.bfloat requires Apple8+). If you choose not to quantise mtp.fc on M1/M2, you need to add the flag the flag --dtype float16 to the convert command. Without it, MTP may drastically slow down on M1 despite positive acceptance rates.

BatchGenerator

Dynamic MTP/batch switching: the server now auto-switches based on self.requests.empty(): MTP for solo requests and BatchGenerator for concurrent ones. Is a best-effort queue check the right approach, or is there a preferred pattern in the server architecture?

Addressed in feat/mtp-batched where GenerationBatch supports MTP natively for B > 1

Future work

DRY refactor + SamplerConfig

A follow-up PR independent of MTP would address:

  1. Code duplication across the now three generator functions:
  • _prefill logic: 3 variants across generate_step, speculative_generate_step, and mtp_generate_step
  • _process_and_sample: almost same pattern in speculative_generate_step and mtp_generate_step
  • quantize_cache_fn = functools.partial(...): same pattern in all three
  1. SamplerConfig: currently mtp_generate_step cannot accept a pre-built sampler= callable and produce correct acceptance logprobs simultaneously. A sampler today returns only a token, but for MTP the acceptance criterion also needs the log-probability distribution the token was drawn from. The fix is a richer sampler interface that returns (token, lp_distribution), allowing both generate_step and mtp_generate_step to share the same interface without passing a dozen individual parameters.

Beyond DRY, SamplerConfig unlocks a potential performance gain: sparse residual sampling.
On rejection at temp > 0, the current implementation samples from max(p_target - p_draft, 0) / Z over the full vocabulary (151K-token for Qwen3.5, 580 µs/call). With top_k > 0, the sampler already computes a top-k partition over the vocabulary, so exposing those indices lets the rejection path work on a K-token support instead.
Without a SamplerConfig, re-running argpartition specifically for the rejection path is slower or equal to the full-vocab path.

Batched MTP

This PR brings MTP for the solo request path only.
However, per-sequence selective rollback (restore SSM state + trim KV only for rejected sequences) is already implemented in AirRunner/mlx-lm · feat/mtp-batched, left out of this PR to keep the diff reviewable.

Test plan

  • Unit tests (10/10 passing) — module existence, cache creation, shapes, pre-norm hidden states, quant predicate, generation identity, end-to-end
  • Manual validation on Qwen3.5-27B, Qwen3.5-0.8B and Qwen3.5-35B-A3B (all 4-bit)

Relates to #872 — cc @janhilgard


Update - probabilistic acceptance and MoE benchmarks

Integrated probabilistic draft acceptance with two cases:

  1. Greedy (temp == 0): exact-match acceptance, mathematically correct for deterministic argmax sampling
  2. Stochastic (temp > 0): min(1, p_target / p_draft): recovers greedy acceptance level at any temperature

Benchmarks on M4 Pro, with 8 diverse prompts:

A reproducible benchmark script is available: bench_mtp.py

Qwen3.5-27B 4-bit

Tok/s Speedup Acceptance (A/V)
No MTP 15.3 1.00x
MTP, temp=0 24.0 1.57x 85.2%
MTP, temp=0.6, exact match 22.7 1.49x 75.4%
MTP, temp=0.6, probabilistic 22.9 1.51x 85.2%

Qwen3.5-35B-A3B 4-bit

Tok/s Speedup Acceptance (A/V)
No MTP 85.3 1.00x
MTP, temp=0 87.9 1.04x 85.2%
MTP, temp=0.6, exact match 84.5 0.98x 78.6%
MTP, temp=0.6, probabilistic 86.5 1.03x 85.2%

On M4 Pro MoE speedup is marginal regardless of acceptance rate. MTP benefit scales with baseline decode time, so at 85 tok/s (3B active params) the MTP overhead is proportionally too large to yield meaningful speedup. With probabilistic acceptance, acceptance rates are consistent with the dense model (~85%).

Bandwidth model

The cross-hardware speedup variation is explained by speedup = (1+p) / (β+δ), where p is the per-round acceptance probability, β = T_verify_backbone / T_baseline, and δ = T_mtp_head / T_baseline. Full derivation and per-component bandwidth estimates in this comment.

For reference:

  • @Thump604's MoE results (M2 Ultra, 8-bit, temp=0, exact match): 35B-A3B 1.11x, 122B-A10B 1.09x.
  • @sammcj results (M5 Max, 4-bit, temp=0): 9B +11.3%, 27B +35.5%, 122B +12.4%.
  • @Anionex results (M5 Pro, 4-bit, temp=0): 27B +31.4%, 79.5% acceptance (A/V).

@vlbosch

vlbosch commented Mar 15, 2026

Copy link
Copy Markdown

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step-funtion for other models? Thanks for your work so far!

@AirRunner

Copy link
Copy Markdown
Author

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step-funtion for other models? Thanks for your work so far!

Thanks!

Yes mtp_generate_step() is fully reusable, but each model still needs its own model-side interface.

The Qwen3.5-specific part is MTPDecoderLayer, mtp_forward (produce draft logits), make_mtp_cache and the backbone's __call__ (with n_confirmed for SSM state rollback on hybrid models).

So the speculative-decoding logic lives in one place, and adding a new model is just a matter of exposing the right interface.

For GLM5 specifically, it would certainly be feasible yeah. But I don't think there is even a glm5.py currently.

@Thump604

Copy link
Copy Markdown

Great work on this! We've been using it on M2 Ultra (128GB) with all three Qwen3.5 sizes and it works well.

MoE fix needed

The PR works out of the box for the dense 27B, but MoE models (35B-A3B, 122B-A10B) fail conversion with "768 parameters not in model". The MTP layer's expert weights use unfused per-expert format (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) unlike the backbone which uses pre-fused gate_up_proj. The existing sanitize() in qwen3_5_moe.py only handles backbone expert stacking.

Fix (add to qwen3_5_moe.py sanitize(), after the backbone expert stacking loop):

# Stack per-expert MTP weights into switch_mlp format.
mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0)
num_experts = self.language_model.args.num_experts
for l in range(mtp_num):
    prefix = f"language_model.mtp.layers.{l}.mlp"
    test_key = f"{prefix}.experts.0.gate_proj.weight"
    if test_key in new_weights:
        for n in ["gate_proj", "up_proj", "down_proj"]:
            to_join = [
                new_weights.pop(f"{prefix}.experts.{e}.{n}.weight")
                for e in range(num_experts)
            ]
            new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)

Also needs import mlx.core as mx at the top of the file.

Full fix on our fork: Thump604/mlx-lm@04a4383

Benchmark results (M2 Ultra, greedy)

Model Baseline MTP Speedup
27B-8bit (dense) 20.6 tok/s 27.1 tok/s 1.32x
35B-A3B-8bit (MoE) 74.4 tok/s 82.3 tok/s 1.11x
122B-A10B-5bit (MoE) 43.0 tok/s 46.7 tok/s 1.09x

Pre-converted models with MTP weights: Thump604/Qwen3.5-27B-MLX-8bit, 35B, 122B

@AirRunner

Copy link
Copy Markdown
Author

@Thump604 Thanks for the report and the fix! I've integrated it in AirRunner/mlx-lm@8d06796 with a credit.

Also, what acceptance rates did you get with MoE? I'm curious if it's somehow correlated to the speedup.

@Thump604

Copy link
Copy Markdown

Thanks for the quick integration!

Here are the acceptance rates derived from our benchmarks (M2 Ultra 128GB, greedy/temp=0):

Model Baseline tok/s MTP tok/s Speedup Implied Accept Rate
27B dense 8-bit 20.6 27.1 1.32x ~32%
35B-A3B MoE 8-bit 74.4 82.3 1.11x ~11%
122B-A10B MoE 5-bit 43.0 46.7 1.09x ~9%

At temp=0.6 (production sampling), 122B drops to 1.05x (~5% acceptance).

So yes — it does correlate with architecture. MoE acceptance rates are significantly lower than dense. My hypothesis: the MTP layer contains a full 256-expert MoE routing step (same expert count as the backbone), but with only a single layer of context depth it struggles to predict the correct expert routing. The dense 27B's MTP layer is a standard transformer layer — much simpler prediction task, much higher acceptance.

The fp16 27B was actually 0.61x (slower) — bandwidth-saturated, the MTP overhead exceeds the savings. 8-bit quantization is the sweet spot where MTP helps most.

@Thump604

Copy link
Copy Markdown

Hey @AirRunner — thanks for integrating the MoE sanitize fix! The PR has merge conflicts with main now though. Would you be able to rebase? Happy to help if needed.

Also, any thoughts on tagging a maintainer for review? This has been open since March 13 with zero maintainer engagement. The implementation is solid (8 tests, code review feedback addressed, MoE fix integrated), just needs someone to look at it.

@AirRunner

AirRunner commented Mar 21, 2026

Copy link
Copy Markdown
Author

Hey @Goekdeniz-Guelmez, would you be able to take a look when you get a chance?

Quick summary: 8 unit tests, code review feedback from @janhilgard and @Thump604, rebased on main.
Results: 1.52x token generation on Qwen3.5-27B dense on M4 Pro, validated independently on M2 Ultra across three Qwen3.5 sizes (MoE and dense).

@layer4down

Copy link
Copy Markdown

Subject: Successfully running Qwen3.5-27B locally with workaround

Transparency Note: This comment was drafted with the assistance of an AI assistant to help document the troubleshooting process. All technical details and findings are from actual testing.


Thanks for this PR! I was able to get Qwen3.5-27B working locally with MLX, but encountered an issue that might help others.

The Bug I Was Addressing

When trying to use the model with a client that passes short model IDs, I encountered:

401 Client Error. (Request ID: Root=1-69bfb0a8...)
Repository Not Found for url: https://huggingface.co/api/models/qwen3_5-27b_4bit/revision/main.
Please make sure you specified the correct `repo_id` and `repo_type`.
User Access Token "Claude-flow-ro" is expired

The error message was misleading - it suggested an expired token, but the real issue was a config/weight mismatch described below.

Issue Encountered

The model failed to load with:

ValueError: Missing 15 parameters: 
language_model.mtp.fc.weight,
language_model.mtp.layers.0.input_layernorm.weight,
...

Root Cause

The model's config.json (from mlx-community/Qwen3.5-27B-4bit on HuggingFace) has:

{
  "text_config": {
    "mtp_num_hidden_layers": 1
  }
}

However, the actual .safetensors weights do not contain any MTP parameters. The PR code correctly expects MTP weights when mtp_num_hidden_layers > 0, but this particular model's config claims MTP support that isn't present in the weights.

Workaround

Set mtp_num_hidden_layers to 0 in the model's config:

cat config.json | jq '.text_config.mtp_num_hidden_layers = 0' > config_fixed.json
mv config_fixed.json config.json

Other Configuration Notes

For anyone trying this setup:

  • Context length: Model supports 98K+ context; works with --max-tokens 98304
  • KV cache quantization: Works with MLX_KV_CACHE_QUANT=true environment variable
  • Model path as ID: The server uses the full local path as the model ID in API calls. For example:
    // Request to /v1/chat/completions
    {
      "model": "/path/to/local/models/mlx-community/Qwen3.5-27B-4bit",
      "messages": [...]
    }
    Short names like "Qwen3.5-27B" will trigger a HuggingFace lookup (and fail if the repo doesn't exist or auth is expired).

Suggestion

It might be helpful to add a check/warning when:

  1. mtp_num_hidden_layers > 0 in config
  2. But MTP weights are missing from the loaded model

This would help users identify config/weight mismatches more quickly and avoid confusing auth error messages.

@AirRunner

Copy link
Copy Markdown
Author

@layer4down thanks for the write-up!

You're right, mlx-community/Qwen3.5-27B-4bit was quantized without the MTP head weights, the mtp_num_hidden_layers: 1 in the config is inherited from the original Qwen3.5 config but the MTP parameters were not included when quantizing.

To actually use MTP acceleration, the model needs to be re-quantized including the MTP layers using this branch.

As you suggested I just pushed a fix that raises a clear ValueError instead of the cryptic "Missing N parameters" crash.

@Thump604

Copy link
Copy Markdown

@angeloskath -- this PR has been open 11 days with no maintainer review. AirRunner rebased on 2026-03-21, all conflicts resolved, 8 unit tests passing.

We've been running this in production on M2 Ultra 128GB since day one. Qwen3.5-122B-A10B-VLM-MTP-5bit, 24/7 inference serving coding agents. MTP acceptance rates:

  • 27B dense 8-bit: 1.32x (32% acceptance, best fit)
  • 35B MoE 8-bit: 1.11x (11% acceptance)
  • 122B MoE 5-bit: 1.09x (9% acceptance)

MoE acceptance rates are lower because a single MTP layer can't predict expert routing well. Still a net win for the latency-sensitive use case.

The MoE sanitize fix (commit 8d06796) is essential for Qwen3.5 MoE models -- without it, 768 MTP parameters are silently missing. We've also published pre-converted VLM+MTP models on HuggingFace that depend on this code path.

Would be great to get this reviewed and merged so the community models work out of the box.

@cresseelia

cresseelia commented Mar 29, 2026

Copy link
Copy Markdown

Can we at the reviewer again? it's an important update for qwen3.5

@Thump604

Copy link
Copy Markdown

@angeloskath @awni — this PR has been open 17 days with no maintainer review or feedback. Multiple community members have asked for review (AirRunner, ourselves, cresseelia).

Is there a concern with the approach, scope, or implementation that's blocking review? We're happy to help address any issues — split the PR, rework the API surface, add tests, whatever is needed.

We're running this in production on 122B and have validated it across three Qwen3.5 model sizes. The community is actively hitting the config/weight mismatch that AirRunner already fixed in this branch (layer4down's report above). Without this merged, users have to manually patch config.json to use MTP on Qwen3.5 models.

If the PR needs changes or a different direction, we'd rather know than wait. Let us know how we can help move this forward.

@Goekdeniz-Guelmez

Copy link
Copy Markdown
Contributor

as you can see 6 files have been changes/added alongside 700 lines of added code. This is a PR that has big changes int he codebase itself. Reviewing and (correctly) implementing it will take time. 17 days not long enough. My full weight fine-tuning PR took multiple weeks to be merged. Just keep it open, update it and please be patient. Adding completely new features will take long.

@janhilgard

Copy link
Copy Markdown

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@AirRunner

AirRunner commented Apr 1, 2026

Copy link
Copy Markdown
Author

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@janhilgard I'm not sure splitting would actually help the review here actually?

The PRs you suggest wouldn't be reviewable in isolation, because the architecture changes only make sense in the context of how mtp_generate_step uses them. Also the changes in generate would be dead code until the other PR lands, so one would need to review both PRs together anyways.

(Also, 183 of the 683 added lines are just unit tests).

That said, I'm open to whatever helps, happy to reorganize if it does :).

@Thump604

Thump604 commented Apr 1, 2026

Copy link
Copy Markdown

@angeloskath @awni — this PR has been open 20+ days with no maintainer review. It is the foundation for MTP speculative decoding on Qwen3.5 models, which several of us are using in production. My PR #1085 (probabilistic acceptance, 2.3x throughput on 122B) builds directly on top of it.

AirRunner's implementation is solid: 8 tests, 80.6% acceptance on M4 Pro. Is there a concern about scope or approach blocking review?

@gyzerok

gyzerok commented Apr 1, 2026

Copy link
Copy Markdown

@Thump604 can you stop pinging people? The more annoying you are the less likely anyone is going to respond.

@janhilgard

Copy link
Copy Markdown

Great work — I've been running MTP on Qwen3.5 MoE models in production (M3 Ultra, 256 GB) and wanted to share findings that might explain the low MoE acceptance rates.

BF16 MTP weights are critical for MoE acceptance

Your quant_predicate excludes only mtp.fc:

if path.endswith("mtp.fc"):
    return False

But the MTP transformer layer (attention, MLP, norms) still gets quantized. We found that quantized MTP weights give near-0% acceptance on MoE models — the quantization error compounds through the expert routing prediction.

Fix: exclude ALL MTP weights from quantization:

if "mtp." in path:
    return False

Our MoE results with BF16 MTP weights

Model Quantization MTP weights Acceptance Speedup
35B-A3B 4-bit BF16 79-85% 1.18x
122B-A10B 4-bit BF16 77-78% 1.12x
35B-A3B 4-bit dequantized 4→BF16 ~0%

vs your MoE benchmarks (quantized MTP weights):

Model MTP weights Implied acceptance Speedup
35B-A3B 8-bit quantized ~11% 1.11x
122B-A10B 5-bit quantized ~5% 1.09x

The difference is stark: BF16 MTP weights → 79-85% acceptance, quantized → 5-11%.

Batch auto-skip

Your PR sets is_batchable = False when MTP is active. In our vllm-mlx integration (#245 on waybarrios/vllm-mlx) we auto-skip MTP when batch_size > 1:

if len(active_batch) > 1:
    # Skip MTP, fall back to standard generation
    return _orig_step(input_tokens, cache)

This gives the best of both worlds:

  • 1 request: MTP active → 86 tok/s (1.18x)
  • 8 requests: MTP skipped → 307 tok/s (full batching throughput)

Instead of disabling batching entirely, you could dynamically switch.

Weight extraction

We extract BF16 MTP weights from the original HF model (not the quantized MLX model) with a dedicated script. See vllm-mlx PR #245 for the add_mtp_weights_qwen35.py script that:

  • Downloads only MTP-containing shards (not entire model)
  • Stacks per-expert weights into SwitchLinear format
  • Applies RMSNorm +1.0 shift
  • Outputs native BF16

Happy to collaborate on getting BF16 MTP weights into the standard conversion pipeline.

@Thump604

Thump604 commented Apr 2, 2026

Copy link
Copy Markdown

I tested your BF16 MTP finding on our models. Sharing the data since it tells a different story on 5-bit and 8-bit backbones.

I extracted fresh BF16 MTP weights from the original HF models (not dequantized from quantized), applied the RMSNorm +1.0 shift, stacked MoE experts into SwitchLinear format, and re-quantized to match the backbone (5-bit gs=64 for 122B, 4-bit gs=64 for 4B). This matches the process you describe in your extraction script.

Results (probabilistic acceptance, temp=0.6):

Model Backbone Original quantized MTP BF16-source re-quantized MTP
4B dense 4-bit gs=64 44.9%, 91.8 tok/s 43.8%, 86.5 tok/s
122B MoE 5-bit gs=64 47.3%, 21.5 tok/s 47.3%, 21.0 tok/s

No measurable difference. Re-quantizing MTP from the BF16 source produces the same acceptance as the original quantized weights on these models.

I also tested with fully unquantized BF16 MTP (no re-quantization, just raw BF16 + norm shift). This gave 0% acceptance across all models. The BF16 MTP forward pass produces a different logit distribution than the quantized backbone expects. Once I re-quantize to match the backbone, the acceptance rate converges to the same ~46%.

Your 79-85% acceptance at 4-bit is significantly higher than what I see. A few questions:

  • Are you running the MTP layer entirely in BF16 (unquantized), or does your script quantize it to match the backbone?
  • Which mlx-lm generate path are you using? Our probabilistic acceptance is from PR feat: probabilistic MTP acceptance (speculative sampling) #1085 (min(1, p_target/p_draft)). Exact match at temp=0.6 gives ~5%.
  • Are your numbers from greedy (temp=0) or sampled (temp=0.6)?

Our acceptance ceiling appears to be ~47% with probabilistic sampling regardless of how the MTP weights are prepared, as long as they match the backbone's quantization. If you are getting 79-85%, there may be a difference in the generation loop or sampling strategy that accounts for the gap.

AirRunner and others added 28 commits June 5, 2026 06:05
Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls.
After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state.
On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before.

Acceptance rate ~64% average / ~85% on 100-token run.
- Yield token.item() instead of raw mx.array to match generate_step convention (fixes detokenizer crash via stream_generate)
- Create MTP cache when prompt_cache lacks MTP entries (server creates backbone-only caches via make_prompt_cache)
- Disable batch generation for MTP models (draft/verify loop requires single-sequence processing)

Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator
…t_predicate)

- Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only (avoiding double normalization (model.norm + pre_fc_norm_hidden).
- Exclude mtp.fc from quantization via quant_predicate (the fusion projection (2H→H) stays in bf16 for accuracy).

27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x).
Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding.

MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional.
8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers.
- MTP module instantiation and cache creation
- return_hidden shape and pre-norm verification
- mtp_forward output shape
- quant_predicate excludes mtp.fc
- Token identity: mtp_generate_step == generate_step (greedy)
- End-to-end mtp_generate_step completion
Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect.
MTP layers in MoE models (35B-A3B, 122B-A10B) ship unfused per-expert weights (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) whereas the backbone uses pre-fused switch_mlp format. Conversion was failing with ~768 parameters not in model.

Add a stacking loop in qwen3_5_moe.py sanitize() after the backbone expert loop, mirroring the same pattern for MTP prefixes.

Co-authored-by: Thump604 <thump604@users.noreply.github.com>
When mtp_num_hidden_layers > 0 but the model weights contain no MTP parameters, the previous error was a cryptic 'Missing N parameters'.
Now raises a ValueError with an actionable message.
With sampler=None (greedy decoding): keep exact-match acceptance, this is the mathematically correct criterion for a deterministic point-mass
distribution.

For stochastic samplers (temp > 0), accept the draft token with probability min(1, p_target / p_draft), computed from the log-probability distributions already returned by _process_and_sample. No extra forward passes needed.

This recovers the greedy acceptance rate (~46%) at any temperature, vs ~43% with exact-match at temp=0.6 on Qwen3.5-27B 4-bit.

Suggested by @janhilgard; implementation reference in ml-explore#1085 by @Thump604.
The _prefill loop in mtp_generate_step previously stopped when y.size <= prefill_step_size (512), leaving up to 512 tokens for _step_backbone(..., return_hidden=True). Since return_hidden=True keeps the full hidden state [1, N, d_model] live, N > 1 caused unnecessary memory pressure on longer prompts.

The loop now stops at exactly 1 token (matching generate_step's design), ensuring the hidden state is always [1, 1, d_model]. Default prefill_step_size raised from 512 to 2048 accordingly.
Three bugs caused logits processors to receive stale token context in mtp_generate_step, breaking any processor that reads tokens[-1]:

1. _step_backbone used a fixed y_ctx slice on every loop iteration for n_predict=2, adding y[0] twice instead of y[0] then y[1]. The bonus token was therefore sampled with the wrong context.

2. _step_mtp passed prev_tokens directly to _process_and_sample without including main_tok, so tokens[-1] was the input token of the preceding backbone step rather than the just-sampled token. Fixed with a local tokens_for_proc that appends main_tok without mutating prev_tokens (mutating would double-count main_tok when _step_backbone adds y[0] at the next verify pass).

3. On draft rejection, prev_tokens retained the rejected draft token added by _step_backbone at i=1, corrupting context for the subsequent _step_mtp call.

All three changes are gated on 'if logits_processors:' and have no effect on the no-processor path (benchmarks unchanged).

Add two regression tests:
- test_mtp_generate_identity_with_logits_processor: verifies that mtp_generate_step and generate_step produce identical greedy output under a context-sensitive stateless processor (covers bugs 1 and 3).
- test_mtp_processor_prev_tokens_correct_at_draft_step: a forcing processor deterministically sets T0=4 and verifies the MTP head receives T0 as tokens[-1], not the preceding prompt token (bug 2).
Replace nonlocal mutation with explicit parameter/return value threading. _step_backbone now takes prev_tokens as an argument and returns it as a fourth value; _step_mtp takes prev_tokens as a third argument. The main loop unpacks and passes prev_tokens explicitly at every call site.

Also name the three hidden-state slice indices (hidden_at_main, hidden_at_confirmed, hidden_at_draft) so their roles are self-evident.
…s_processors dimensionality

Pass input_embeddings through _prefill for VLM prefill compatibility. Wrap/unwrap logits as 2D in _process_and_sample so logits_processors receive the expected [1, vocab] shape.
…e sanitize

Extract _unfuse_experts and _stack_per_expert helpers to eliminate duplicated logic between backbone and MTP conversion.

Detect MTP format once before the loop (Qwen3.6 fused gate_up_proj vs Qwen3.5 per-expert) instead of re-checking per iteration.
On rejection, _rollback_draft trimmed mtp_cache by 1, but the 2-token backbone verification pass never writes to mtp_cache. The trim was removing the valid KV entry from the previous _step_mtp call, causing accumulated context drift after repeated rejections (lower acceptance rate over long generations).
On rejection, emit a token sampled from max(p_target - p_draft, 0) / Z instead of the backbone argmax. This guarantees the output marginal equals the target distribution exactly (Leviathan et al. 2022; Chen et al. 2023).
- Remove z.item() sync: z stays in the MLX graph and is evaluated once alongside categorical(), reducing Metal round-trips from 2 to 1.
- Replace if z > 1e-8 guard with mx.where(z > 0, residual, p_target): when the residual mass is zero, sample from p_target instead of keeping verify_pred (argmax). Matches Leviathan et al. 2022 §2.3.
Replace sampler= callable with explicit sampling params (temp, top_p, top_k, min_p, xtc_*) so mtp_generate_step can compute temperature-adjusted lp_accept for correct probabilistic acceptance at temp > 0.

- Extract make_sampler_chain from make_sampler (DRY); mtp uses it directly to build the filter chain without a pre-assembled sampler.
- Compute lp_accept from the filtered+scaled distribution so it matches the distribution the token was drawn from.
- Share the XTC boolean draw across draft and verify steps via xtc_cell, so both steps apply the same XTC mask.
- Draw acceptance coin as mx.random.uniform(), evaluated in parallel with the verify forward pass (amortized Metal dispatch, consistent with mx.random.seed()).
- Fix _xtc_special_tokens: use tokenizer.eos_token_ids (plural) and concatenate properly instead of mixing int and list.
- Update tests: remove sampler= from MTP tests, add top_k variant, extract _collect_rejection_tokens/_assert_residual_varies helpers.
Previously _prefill only populated the backbone cache, leaving the MTP KVCache cold at the start of decode. The MTP head was trained with full prefix context, so starting from an empty cache is misaligned with training.

Now each prefill chunk passes return_hidden=True and immediately calls mtp_forward(hidden, y[1:n+1], mtp_cache). The hidden tensor is transient: consumed within the same iteration before mx.clear_cache().
generate_step calls mx.clear_cache() every 256 tokens to bound the Metal allocator's free list.

Introduce _CACHE_CLEAR_INTERVAL = 256 shared by both generate_step and mtp_generate_step to add the equivalent cache-clearing logic to the MTP decode loop.  The block-based counter (ntoks // _CACHE_CLEAR_INTERVAL) handles MTP iterations that could emit multiple tokens at once, where a '% interval == 0' check could skip a boundary.
Declare u = mx.random.uniform() immediately before its first use (mx.eval) rather than before the unrelated _step_backbone call.
Empirical benchmarks (Qwen3.6-27B 4-bit, M4 Pro, temp=0/0.6/1.0) show no measurable impact on MTP acceptance rate when mtp.fc is quantized to 4-bit: acceptance delta is within noise (−0.2 to +0.3 pp), speedup delta within noise (−0.003 to +0.026x).

Additionally, keeping mtp.fc in BF16 penalizes M1 users where BF16 has no native GPU support.
The test was checking mtp.fc exclusion, which was removed in c47c1cb after empirical benchmarks.
The accepted draft token was never processed by the MTP head, causing the cache to drift behind the backbone cache by one entry per accept. After k accepts the MTP head operates on k tokens of missing context. Empirically the impact was negligible though (backbone hidden dominates MTP head conditioning).

Fix: extend _step_mtp with an optional cache_commit=(hidden, tok) parameter. When set, the alignment position and the draft position are processed in a single 2-token batched mtp_forward, committing the accepted token to mtp_cache at no extra forward-pass cost.
machiabeli pushed a commit to machiabeli/mlx-lm-1 that referenced this pull request Jun 8, 2026
Implement native MTP support following the HF reference architecture
and ml-explore/mlx-lm PR ml-explore#990 patterns:

Model (deepseek_v4.py):
- MTPBlock wrapping DeepseekV4Block with e_proj, h_proj, enorm, hnorm,
  norm, and per-block HyperHead
- return_hidden support in Model.__call__ for exposing raw 4D hidden state
- mtp_forward() and make_mtp_cache() on Model
- Weight sanitization: keep and remap MTP weights, stack MTP experts

Generation (generate.py):
- mtp_generate_step() speculative decoding loop with draft/verify cycle
- Greedy exact-match and probabilistic acceptance modes
- --mtp CLI flag with graceful fallback warning

Server (server.py):
- --mtp CLI flag and stream_generate integration
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.