
feat(nemotron_h): add Multi-Token Prediction (MTP) module#1161

Open
Thump604 wants to merge 1 commit into ml-explore:main from Thump604:feat/nemotron-mtp-model

Conversation

@Thump604

Summary

Nemotron-3-Super-120B ships MTP prediction heads (1,040 weight keys covering attention + 512-expert MoE) in its HuggingFace checkpoint, but neither HF transformers nor mlx-lm currently use them — sanitize() explicitly strips mtp.* keys.

This PR adds native MTP support to the Nemotron-H model definition:

  • NemotronHMTPModule: dual-norm embedding/hidden fusion via eh_proj, followed by attention + MoE layers matching mtp_hybrid_override_pattern (*E = 1 attention layer + 1 MoE layer); a minimal sketch of the fusion step follows after this list
  • NemotronHMTPBlock: supports attention (*), MoE (E), and MLP (-) block types
  • Model interface: mtp_forward(), make_mtp_cache(), and return_hidden parameter on __call__ — the standard MTP model contract already used by Qwen3.5 models
  • Weight remapping: sanitize() remaps HF mtp.layers.0.* → flat mtp.* keys and stacks 512-expert weight shards
  • Removes the mtp.* weight stripping so MTP weights are loaded when present
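For orientation, here is a minimal sketch (not the PR's exact code) of the dual-norm fusion step the MTP module performs: the token embedding and the backbone's final hidden state are each normalized, concatenated, and projected back to the model width via eh_proj before flowing through the MTP attention + MoE layers. Only the names NemotronHMTPModule and eh_proj come from this PR; the norm types and layer details are assumptions.

```python
import mlx.core as mx
import mlx.nn as nn

class NemotronHMTPModuleSketch(nn.Module):
    """Illustrative sketch of the embedding/hidden fusion in the MTP module."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size)   # normalizes the token embeddings
        self.hnorm = nn.RMSNorm(hidden_size)   # normalizes the backbone hidden state
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        # self.layers = [...]  # attention + MoE blocks per mtp_hybrid_override_pattern

    def __call__(self, embeds: mx.array, hidden: mx.array, cache=None) -> mx.array:
        # Concatenate the two normalized streams and project back down.
        fused = self.eh_proj(
            mx.concatenate([self.enorm(embeds), self.hnorm(hidden)], axis=-1)
        )
        # The fused states then pass through the MTP attention + MoE layers (omitted).
        return fused
```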

Test results

38% MTP acceptance rate on Nemotron-3-Super-120B-A12B-5bit with extracted FP16 MTP weights (5.5 GB from BF16 shards 49-50). Coherent generation confirmed.

Scope

This is model-only — it adds the MTP architecture and weight loading to nemotron_h.py. The generate-level mtp_generate_step() integration (which calls mtp_forward() during decoding) is a separate concern for a follow-up PR.

The model-level interface (mtp_forward, make_mtp_cache, return_hidden) matches the existing Qwen3.5 MTP contract, so existing MTP generate infrastructure can use it without modification.
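A hypothetical decode-step sketch of how generate-level code could consume that contract. Only the names mtp_forward, make_mtp_cache, and return_hidden come from this PR; the argument shapes and the surrounding driver logic are assumptions for illustration.

```python
import mlx.core as mx

def draft_next_token(model, tokens: mx.array, cache):
    # Regular backbone pass, also returning the final hidden states for the MTP head.
    logits, hidden = model(tokens, cache=cache, return_hidden=True)
    next_token = mx.argmax(logits[:, -1, :], axis=-1)

    # Run the MTP head on the accepted token plus the last hidden state to
    # propose one extra token that a speculative decoder could then verify.
    mtp_cache = model.make_mtp_cache()
    mtp_logits = model.mtp_forward(next_token[None], hidden[:, -1:, :], cache=mtp_cache)
    draft_token = mx.argmax(mtp_logits[:, -1, :], axis=-1)
    return next_token, draft_token
```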

MTP weight availability

The MTP weights are present in the original NVIDIA checkpoint but need to be extracted separately since they span the last 2 of 50 BF16 safetensors shards. A conversion script for extracting and stacking the MTP weights is available separately.
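As a rough illustration of the sanitize() remapping described above (assumed key names, not the PR's exact code): the HF mtp.layers.0.* prefix is flattened to mtp.*, and per-expert shards are grouped and stacked into a single tensor per weight.

```python
import mlx.core as mx

def sanitize_mtp(weights: dict) -> dict:
    out, experts = {}, {}
    for k, v in weights.items():
        if k.startswith("mtp.layers.0."):
            # mtp.layers.0.<rest> -> mtp.<rest>
            k = "mtp." + k[len("mtp.layers.0."):]
        if ".experts." in k:
            # Collect per-expert shards, e.g. ...experts.<idx>.<leaf>
            base, _, rest = k.partition(".experts.")
            idx, _, leaf = rest.partition(".")
            experts.setdefault((base, leaf), {})[int(idx)] = v
        else:
            out[k] = v
    # Stack the 512 per-expert shards along a new leading axis, in index order.
    for (base, leaf), shards in experts.items():
        out[f"{base}.experts.{leaf}"] = mx.stack([shards[i] for i in sorted(shards)])
    return out
```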

