
Add MTP support for Step 3.5 Flash#901

Open
janhilgard wants to merge 1 commit into ml-explore:main from janhilgard:feat/step3p5-model

Conversation

@janhilgard

Summary

  • Add MTP (Multi-Token Prediction) speculative decoding support to the existing Step 3.5 Flash model
  • 3 MTP prediction layers with per-layer shared_head (dense MLP, not MoE)
  • Backward compatible — MTP only instantiated when num_nextn_predict_layers > 0

Changes

| Component | Description |
| --- | --- |
| `Step3p5SharedHead` | Per-layer prediction head: norm + output projection |
| `Step3p5MTPLayer` | hnorm + enorm → eh_proj → sliding attention + dense MLP → shared_head |
| `Step3p5MTP` | Container for multiple prediction layers |
| `Model.__call__` | Added `return_hidden` parameter for MTP integration |
| `Model.mtp_forward` | Single-step MTP prediction (hidden + next token → logits) |
| `Model.make_mtp_cache` | KV cache factory for MTP layers |
| `sanitize` | Handle MTP weight loading (instead of filtering) |
| `quant_predicate` | Exclude MTP norm layers from quantization |
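The quantization-exclusion idea can be sketched as a simple path predicate. This is an illustrative stand-in, not the PR's actual implementation: the function name and exact path matching are assumptions, but it shows the intent of leaving the small, numerically sensitive MTP norm layers in full precision.

```python
def mtp_quant_predicate(path: str) -> bool:
    """Illustrative predicate: quantize everything except MTP norm layers.

    Quantization predicates in this style receive a dotted weight path and
    return False to leave that module unquantized. Norm layers hold few
    parameters and are numerically sensitive, so they stay in full precision.
    (Hypothetical sketch; the PR's real predicate may differ in detail.)
    """
    # Skip any MTP norm sub-module (enorm, hnorm, shared_head.norm, ...)
    if path.startswith("mtp.") and path.split(".")[-1].endswith("norm"):
        return False
    return True
```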

Also fixes two pre-existing ruff E741 lint warnings (the ambiguous variable name `l` renamed to `layer` in generator expressions).

MTP Architecture

hidden_states ──→ hnorm ──→ ┐
                             ├──→ concat [e, h] ──→ eh_proj ──→ attn + MLP ──→ shared_head ──→ logits
next_token_ids ──→ embed ──→ enorm ──→ ┘

Each MTP layer uses sliding attention with dense SwiGLU MLP (not MoE), matching the original StepFun architecture.
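The dataflow in the diagram can be sketched with plain Python lists standing in for MLX arrays. This is a toy sketch of the concat-then-project step, not the MLX implementation: `rms_norm` and `eh_proj` here are hand-rolled stand-ins for the learned RMSNorm and linear-projection layers, and the attention/MLP/shared_head stages are elided.

```python
def rms_norm(x, eps=1e-6):
    # Stand-in for RMSNorm: scale x to roughly unit root-mean-square.
    ms = sum(v * v for v in x) / len(x)
    return [v / (ms + eps) ** 0.5 for v in x]

def eh_proj(concat, dim):
    # Stand-in for the learned [2*dim -> dim] projection; a real eh_proj
    # is a weight matrix, here we just average the two halves element-wise.
    e, h = concat[:dim], concat[dim:]
    return [(a + b) / 2 for a, b in zip(e, h)]

def mtp_layer(hidden, token_embedding):
    dim = len(hidden)
    h = rms_norm(hidden)           # hnorm on the incoming hidden state
    e = rms_norm(token_embedding)  # enorm on the next-token embedding
    x = e + h                      # concat [e, h] along the feature axis
    x = eh_proj(x, dim)            # project 2*dim back down to dim
    # ... sliding attention + dense MLP + shared_head would follow here
    return x

out = mtp_layer([1.0, 2.0, 3.0], [0.5, 0.5, 0.5])
```

The key structural point the sketch captures: the MTP layer consumes both the previous hidden state and the embedding of the already-sampled next token, which is what makes it usable for speculative decoding.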

Test plan

  • Basic inference without MTP (backward compatibility)
  • return_hidden=True returns prenorm hidden states
  • mtp_forward() produces correct-shaped logits
  • make_mtp_cache() returns KVCache list matching MTP layer count
  • ruff format and ruff check pass cleanly
  • End-to-end test with mlx-community/Step-3.5-Flash-4bit (111 GB, requires Apple Silicon)
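The `make_mtp_cache()` check in the test plan amounts to a per-layer cache factory. The sketch below uses a minimal placeholder cache class (the real mlx-lm `KVCache` carries key/value tensors and trimming logic); the point is one independent cache per MTP layer.

```python
class KVCache:
    # Minimal placeholder for an MLX key/value cache; the real class
    # stores per-layer key and value tensors plus an offset.
    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0

def make_mtp_cache(num_mtp_layers: int):
    # One independent cache per MTP prediction layer, mirroring the
    # per-layer caches the main model uses. (Illustrative sketch.)
    return [KVCache() for _ in range(num_mtp_layers)]

caches = make_mtp_cache(3)  # Step 3.5 Flash ships 3 MTP layers
```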

🤖 Generated with Claude Code

Add MTP speculative decoding support to the existing Step 3.5 Flash model:
- Step3p5SharedHead: per-layer prediction head (norm + output projection)
- Step3p5MTPLayer: hnorm + enorm → eh_proj → attention + dense MLP → shared_head
- Step3p5MTP: container for multiple prediction layers
- Model: return_hidden, mtp_forward, make_mtp_cache
- sanitize: handle MTP weight loading
- quant_predicate: exclude MTP norm layers from quantization

Tested with mlx-community/Step-3.5-Flash-4bit (196B MoE, 3 MTP layers).
Backward compatible — MTP is only instantiated when num_nextn_predict_layers > 0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
janhilgard added a commit to janhilgard/vllm-mlx that referenced this pull request Feb 16, 2026
Step 3.5 Flash is a 196B MoE model (288 experts, top-8 routing, ~11B
active params) with 3 MTP prediction layers. The MLX community 4-bit
conversion strips MTP weights and lacks MTP-aware modeling code.

This adds:
- scripts/add_mtp_weights_step3p5.py: Downloads BF16 MTP shards from
  the original model, extracts layers 45-47, remaps to mtp.layers.*,
  quantizes to 4-bit, and installs the MTP modeling file
- scripts/modeling_step3p5_mtp.py: Full MLX-native model implementation
  with MTP support (Step3p5MTP, Step3p5MTPLayer, Step3p5SharedHead)
- Reasoning parser alias "step3p5" (reuses deepseek_r1 <think> parser)
- Documentation updates in README.md and docs/reference/models.md
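The layer-remapping step the script performs (original layers 45-47 become `mtp.layers.0`-`mtp.layers.2`) can be sketched as a key-rewriting helper. The function name and exact key scheme here are assumptions for illustration; the actual script may name things differently.

```python
def remap_mtp_key(key: str, first_mtp_layer: int = 45) -> str:
    """Hypothetical sketch: rewrite model.layers.45..47.* to mtp.layers.0..2.*.

    Keys below the first MTP layer index pass through unchanged.
    """
    prefix = "model.layers."
    if key.startswith(prefix):
        idx_str, _, tail = key[len(prefix):].partition(".")
        idx = int(idx_str)
        if idx >= first_mtp_layer:
            return f"mtp.layers.{idx - first_mtp_layer}.{tail}"
    return key
```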

Note: The custom modeling file is a workaround until
ml-explore/mlx-lm#901 is merged upstream.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>