This PR adds DeepSeek V4 Flash and Pro implementations, tested for both the single request and batch request path in Exo.
It uses the new Compressor and Indexer, reducing memory consumption drastically and improving prefill speeds compared to the open PRs. It also includes a few kernels that might help improve decode performance.
This is basically just throwing one more implementation into the ring; I can clean up this PR further if that helps, or @machiabeli, @austinbv or @Blaizzy can cherry-pick changes :)
To test this branch, please use a correctly formatted prompt, e.g.:
uv run mlx_lm.generate --model ~/.exo/models/mlx-community--DeepSeek-V4-Flash --prompt '<|begin▁of▁sentence|>You are a helpful AI assistant. Respond directly and concisely. <|User|>Hi!<|Assistant|><think>' --max-tokens 512 --temp 1.0
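For scripting the test, the prompt string can be assembled programmatically. This is a minimal sketch that only reproduces the layout of the single example prompt above; `build_prompt` is a hypothetical helper, not part of mlx_lm, and the real chat template may cover cases (multi-turn, no system message) that this does not.

```python
# Special tokens copied verbatim from the example prompt in this PR.
BOS = "<|begin▁of▁sentence|>"
USER = "<|User|>"
ASSISTANT = "<|Assistant|>"
THINK = "<think>"


def build_prompt(system: str, user: str) -> str:
    """Hypothetical helper: single-turn prompt in the layout of the example."""
    return f"{BOS}{system}{USER}{user}{ASSISTANT}{THINK}"


prompt = build_prompt(
    "You are a helpful AI assistant. Respond directly and concisely. ",
    "Hi!",
)
print(prompt)
```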
Three sanitize gaps when loading mlx-community 4-bit on multi-node TP
Tested this branch against mlx-community/deepseek-ai-DeepSeek-V4-Flash-4bit running 2-rank tensor-parallel across 2× Mac Studio M4 Max (128 GB / node). Hit three weight-naming mismatches that all manifest as silent load_weights(strict=False) drops, then either OOM-by-random-init or KeyError / quantized_matmul shape errors at first forward. All fixes are isolated to Model.sanitize().
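Since load_weights(strict=False) does not report unmatched keys, a pre-load comparison of names can surface these drops early. The following is a rough sketch, not code from this branch: `find_unmatched_keys` is a hypothetical helper, and it assumes you can obtain the sanitized checkpoint keys and the model's flattened parameter names as plain strings (in MLX, presumably via `mlx.utils.tree_flatten(model.parameters())`).

```python
def find_unmatched_keys(checkpoint_keys, model_param_names):
    """Hypothetical helper: report name mismatches before a non-strict load.

    Returns (dropped, uninitialized):
      dropped        checkpoint keys with no matching model parameter
      uninitialized  model parameters the checkpoint never fills
    """
    ckpt = set(checkpoint_keys)
    params = set(model_param_names)
    return sorted(ckpt - params), sorted(params - ckpt)


# Illustrative names only; the real key layout is described in the gaps below.
dropped, uninitialized = find_unmatched_keys(
    ["layers.0.ffn.experts.w1.weight", "lm_head.weight"],
    ["layers.0.ffn.switch_mlp.gate_proj.weight", "lm_head.weight"],
)
print("dropped:", dropped)
print("left at random init:", uninitialized)
```

Any non-empty `uninitialized` list for a large module is a hint that the random-init memory cost will be paid at the first eval.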
1. Pre-stacked routed experts (silent drop → OOM)
The 4-bit quant ships routed experts already stacked as a single tensor under layers.{i}.ffn.experts.{w1,w2,w3}.{weight,scales,biases} rather than per-expert unstacked (experts.{0..255}.w1.weight). The stacking loop only fires for the unstacked layout, so the keys pass straight through and load_weights(strict=False) silently drops them. The SwitchLinear is then left at its mx.random.uniform(shape=(256, 2048, 4096)) fp32 init — ~8.6 GB per projection × 3 × 43 layers ≈ 1.1 TB, which gets materialized by per-layer mx.eval during sharding and OOMs on 128 GB nodes.
for old, new in w_remap.items():
    # shared_experts first so its substring isn't partially rewritten
    nk = nk.replace(f".shared_experts.{old}.", f".shared_experts.{new}.")
    # NEW:
    nk = nk.replace(f".ffn.experts.{old}.", f".ffn.switch_mlp.{new}.")
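The memory estimate for the randomly initialized experts can be reproduced with a few lines of arithmetic. The shape, the three projections and the 43 layers are taken from the description above; the function name is only illustrative.

```python
def switch_linear_init_bytes(num_experts=256, out_dim=2048, in_dim=4096,
                             bytes_per_elem=4):
    """Size in bytes of one fp32 SwitchLinear weight at its random init."""
    return num_experts * out_dim * in_dim * bytes_per_elem


per_projection = switch_linear_init_bytes()
total = per_projection * 3 * 43  # 3 projections per layer, 43 layers

print(f"{per_projection / 1e9:.1f} GB per projection")  # 8.6 GB per projection
print(f"{total / 1e12:.2f} TB in total")                # 1.11 TB in total
```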
2. Per-group wo_a (KeyError: 0)
The 4-bit checkpoint stores wo_a as o_groups separate sub-tensors (wo_a.{0..7}.{weight,scales,biases}), but V4Attention.__init__ builds a single nn.Linear of shape (o_groups * o_lora_rank, in). tree_unflatten interprets the numeric .0./.1./… suffixes as list indices, so wo_a arrives at Module.update as a list[dict] where the model expects a leaf with weight/scales/biases, and crashes with KeyError: 0.
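One way to handle this in sanitize() is to fold the per-group entries back into a single leaf before the weights reach the module tree. The sketch below is an assumption about the shape of such a fix, not the code from the fork: it assumes the groups stack along the output axis in index order (consistent with the (o_groups * o_lora_rank, in) shape), the key prefix is hypothetical, and nested lists stand in for arrays so it runs without MLX. With real tensors, `concat` would be something like `lambda parts: mx.concatenate(parts, axis=0)`.

```python
import re

# Matches e.g. "<prefix>.wo_a.5.weight" -> (prefix.wo_a, group index, leaf).
_WO_A_GROUP = re.compile(r"^(.*\.wo_a)\.(\d+)\.(weight|scales|biases)$")


def _concat_rows(parts):
    """Stand-in for an axis-0 concatenate, operating on lists of rows."""
    return [row for part in parts for row in part]


def merge_wo_a_groups(weights, concat=_concat_rows):
    """Hypothetical sketch: merge wo_a.{i}.{leaf} into one wo_a.{leaf}."""
    merged, grouped = {}, {}
    for key, value in weights.items():
        m = _WO_A_GROUP.match(key)
        if m is None:
            merged[key] = value  # unrelated keys pass through unchanged
            continue
        prefix, idx, leaf = m.group(1), int(m.group(2)), m.group(3)
        grouped.setdefault(f"{prefix}.{leaf}", {})[idx] = value
    for key, parts in grouped.items():
        # Sort numerically so group 10 does not land before group 2.
        merged[key] = concat([parts[i] for i in sorted(parts)])
    return merged


example = merge_wo_a_groups({
    "layers.0.attn.wo_a.1.weight": [[3, 4]],
    "layers.0.attn.wo_a.0.weight": [[1, 2]],
    "layers.0.attn.wq_b.weight": [[9]],
})
print(example)
```

After merging, no key contains a numeric component under wo_a, so tree_unflatten produces a plain leaf dict rather than a list.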
3. Quantized embed / head scales and biases (shape/dtype mismatch)
The top-level remap only renames embed.weight → model.embed_tokens.weight and head.weight → lm_head.weight. The 4-bit quant also ships embed.{scales,biases} and head.{scales,biases}. Without renaming those, nn.quantize's class_predicate (f"{p}.scales" in weights) misses and embed_tokens / lm_head stay as plain nn.Embedding / nn.Linear while the .weight slot ends up holding the int4-packed tensor, which causes a shape/dtype mismatch at first forward.
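A remap that treats weight, scales and biases uniformly avoids this. The following is a minimal sketch under the naming described above; `remap_top_level` and the two constants are hypothetical names, and the actual sanitize() in the fork may be structured differently.

```python
# Source prefix -> target prefix, as described in the text above.
TOP_LEVEL_RENAMES = {"embed": "model.embed_tokens", "head": "lm_head"}
# Leaves present on a quantized layer; an unquantized one has only "weight".
LEAVES = ("weight", "scales", "biases")


def remap_top_level(key: str) -> str:
    """Hypothetical sketch: rename embed.* / head.* for every quantized leaf."""
    for old, new in TOP_LEVEL_RENAMES.items():
        for leaf in LEAVES:
            if key == f"{old}.{leaf}":
                return f"{new}.{leaf}"
    return key  # everything else is left for the other sanitize steps


for k in ("embed.weight", "embed.scales", "head.biases", "layers.0.norm.weight"):
    print(k, "->", remap_top_level(k))
```

Exact matching on the whole key (rather than substring replacement) keeps the rename from touching nested keys that merely contain "head" or "embed".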
Fork carrying all three fixes: adurham/mlx-lm commits 223604e, 6ee9898, 15de79d. Generates coherent tokens at temperature=1.0, top_p=1.0 on 2-rank TP after these patches.
Side note: _grouped_output_projection reaches into wo_a.weight directly, which bypasses any shard_linear wrapper — sharding heads via wq_b all-to-sharded then makes the manual reshape see half the per-group input dim and crashes inside mx.quantized_matmul. That's downstream of this PR (a host-runtime concern), but if you'd accept it, calling self.wo_a(...) instead of accessing .weight directly would let multi-node TP shard heads cleanly. Happy to open a separate PR with these as one diff if useful.