Skip to content

Implement DSV4#1195

Open
rltakashige wants to merge 1 commit into
ml-explore:mainfrom
rltakashige:deepseek-v4
Open

Implement DSV4#1195
rltakashige wants to merge 1 commit into
ml-explore:mainfrom
rltakashige:deepseek-v4

Conversation

@rltakashige

@rltakashige rltakashige commented Apr 25, 2026

Copy link
Copy Markdown
Contributor

This PR adds DeepSeek V4 Flash and Pro implementations, tested for both the single request and batch request path in Exo.

It uses the new Compressor and Indexer, reducing memory consumption drastically and improving prefill speeds compared to the open PRs. A few kernels that might be helpful for improving decode performance, too.

This is basically just throwing in one more implementation into the ring; I can clean up this PR further if it'll help or @machiabeli , @austinbv or @Blaizzy can cherry-pick changes :)

To test this branch, please just use the correctly formatted prompt (e.g. uv run mlx_lm.generate --model ~/.exo/models/mlx-community--DeepSeek-V4-Flash --prompt '<|begin▁of▁sentence|>You are a helpful AI assistant. Respond directly and concisely. <|User|>Hi!<|Assistant|><think>' --max-tokens 512 --temp 1.0 )

@rltakashige

Copy link
Copy Markdown
Contributor Author
image

@adurham

adurham commented Apr 26, 2026

Copy link
Copy Markdown
Contributor

Three sanitize gaps when loading mlx-community 4-bit on multi-node TP

Tested this branch against mlx-community/deepseek-ai-DeepSeek-V4-Flash-4bit running 2-rank tensor-parallel across 2× Mac Studio M4 Max (128 GB / node). Hit three weight-naming mismatches that all manifest as silent load_weights(strict=False) drops, then either OOM-by-random-init or KeyError / quantized_matmul shape errors at first forward. All fixes are isolated to Model.sanitize().

1. Pre-stacked routed experts (silent drop → OOM)

The 4-bit quant ships routed experts already stacked as a single tensor under layers.{i}.ffn.experts.{w1,w2,w3}.{weight,scales,biases} rather than per-expert unstacked (experts.{0..255}.w1.weight). The stacking loop only fires for the unstacked layout, so the keys pass straight through and load_weights(strict=False) silently drops them. The SwitchLinear is then left at its mx.random.uniform(shape=(256, 2048, 4096)) fp32 init — ~8.6 GB per projection × 3 × 43 layers ≈ 1.1 TB, which gets materialized by per-layer mx.eval during sharding and OOMs on 128 GB nodes.

for old, new in w_remap.items():
    # shared_experts first so its substring isn't partially rewritten
    nk = nk.replace(f".shared_experts.{old}.", f".shared_experts.{new}.")
    # NEW:
    nk = nk.replace(f".ffn.experts.{old}.", f".ffn.switch_mlp.{new}.")

2. Per-group wo_a (KeyError: 0)

The 4-bit checkpoint stores wo_a as o_groups separate sub-tensors (wo_a.{0..7}.{weight,scales,biases}), but V4Attention.__init__ builds a single nn.Linear of shape (o_groups * o_lora_rank, in). tree_unflatten interprets the numeric .0./.1./… suffixes as list indices, so wo_a arrives at Module.update as a list[dict] where the model expects a leaf with weight/scales/biases, and crashes with KeyError: 0.

o_groups = self.args.o_groups
for layer_idx in range(n_layers):
    prefix = f"model.layers.{layer_idx}.attn.wo_a"
    for suffix in ("weight", "scales", "biases"):
        key0 = f"{prefix}.0.{suffix}"
        if key0 in weights:
            parts = [weights.pop(f"{prefix}.{g}.{suffix}")
                     for g in range(o_groups)]
            weights[f"{prefix}.{suffix}"] = mx.concatenate(parts, axis=0)

3. Quantized embed / head .scales and .biases

The top-level remap only renames embed.weightmodel.embed_tokens.weight and head.weightlm_head.weight. The 4-bit quant also ships embed.{scales,biases} and head.{scales,biases}. Without renaming those, nn.quantize's class_predicate (f"{p}.scales" in weights) misses and embed_tokens / lm_head stay as plain nn.Embedding / nn.Linear while the .weight slot ends up holding the int4-packed tensor — shape/dtype mismatch at first forward.

top_remap = {
    "embed.weight": "model.embed_tokens.weight",
    "embed.scales": "model.embed_tokens.scales",   # NEW
    "embed.biases": "model.embed_tokens.biases",   # NEW
    "norm.weight": "model.norm.weight",
    "head.weight": "lm_head.weight",
    "head.scales": "lm_head.scales",               # NEW
    "head.biases": "lm_head.biases",               # NEW
    ...
}

Reference

Fork carrying all three fixes: adurham/mlx-lm commits 223604e, 6ee9898, 15de79d. Generates coherent tokens at temperature=1.0, top_p=1.0 on 2-rank TP after these patches.

Side note: _grouped_output_projection reaches into wo_a.weight directly, which bypasses any shard_linear wrapper — sharding heads via wq_b all-to-sharded then makes the manual reshape see half the per-group input dim and crashes inside mx.quantized_matmul. That's downstream of this PR (a host-runtime concern), but if you'd accept it, calling self.wo_a(...) instead of accessing .weight directly would let multi-node TP shard heads cleanly. Happy to open a separate PR with these as one diff if useful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants