
Commit f18526f

Author: Awni Hannun
DSV3 MLA (#839)
* mla
* try to speed up prefill
* update dsv32 as well
1 parent 25a4c83 commit f18526f

5 files changed

Lines changed: 291 additions & 135 deletions


mlx_lm/models/deepseek_v3.py

Lines changed: 79 additions & 25 deletions
@@ -11,6 +11,7 @@

 from .activations import swiglu
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .mla import MultiLinear
 from .pipeline import PipelineMixin
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
@@ -85,11 +86,11 @@ def __init__(self, config: ModelArgs):
             bias=config.attention_bias,
         )
         self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6)
-        self.kv_b_proj = nn.Linear(
-            self.kv_lora_rank,
-            self.num_heads
-            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
-            bias=False,
+        self.embed_q = MultiLinear(
+            self.qk_nope_head_dim, self.kv_lora_rank, self.num_heads
+        )
+        self.unembed_out = MultiLinear(
+            self.kv_lora_rank, self.v_head_dim, self.num_heads
         )

         self.o_proj = nn.Linear(
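Note: MultiLinear is defined in mlx_lm/models/mla.py, which is not part of this diff. Judging from the call sites here and the weight conversion below, it acts as num_heads independent linear maps held in one stacked weight of shape [num_heads, out_dims, in_dims]. A minimal sketch under that assumption (the real module also carries quantization state):

import mlx.core as mx
import mlx.nn as nn


class MultiLinear(nn.Module):
    # Hypothetical sketch, not the actual implementation in mla.py.
    def __init__(self, in_dims: int, out_dims: int, num_heads: int):
        super().__init__()
        scale = in_dims**-0.5
        self.weight = mx.random.uniform(
            -scale, scale, (num_heads, out_dims, in_dims)
        )

    def __call__(self, x, transpose: bool = True):
        # x: [B, num_heads (or 1), L, D]; matmul broadcasts over the head axis.
        w = self.weight.swapaxes(-1, -2) if transpose else self.weight
        return x @ w

With transpose=True this maps in_dims to out_dims per head, like a stack of standard linears; with transpose=False it applies the transposed map, which is how embed_q(kv_latent, transpose=False) expands the shared latent into per-head keys during prefill.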
@@ -132,29 +133,38 @@ def __call__(
         compressed_kv = self.kv_a_proj_with_mqa(x)
         compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
         k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        kv_latent = self.kv_a_layernorm(compressed_kv)
+
+        offset = cache.offset if cache is not None else 0
+        q_pe = self.rope(q_pe, offset)
+        k_pe = self.rope(k_pe, offset)

-        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+        kv_latent = mx.expand_dims(kv_latent, axis=1)

         if cache is not None:
-            q_pe = self.rope(q_pe, cache.offset)
-            k_pe = self.rope(k_pe, cache.offset)
-            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
-            keys, values = cache.update_and_fetch(
-                mx.concatenate([k_nope, k_pe], axis=-1), values
+            kv_latent, k_pe = cache.update_and_fetch(kv_latent, k_pe)
+
+        pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
+        if mask is not None:
+            pe_scores = mx.where(
+                mask,
+                pe_scores,
+                mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype),
             )
-        else:
-            q_pe = self.rope(q_pe)
-            k_pe = self.rope(k_pe)
-            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
-            keys = mx.concatenate([k_nope, k_pe], axis=-1)

-        queries = mx.concatenate([q_nope, q_pe], axis=-1)
+        if L == 1:
+            q_nope = self.embed_q(q_nope)
+            k = v = kv_latent
+        else:
+            k = self.embed_q(kv_latent, transpose=False)
+            v = self.unembed_out(kv_latent)

         output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+            q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
         )
+        if L == 1:
+            output = self.unembed_out(output)
+
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
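This hunk is the core of the MLA change: the cache now stores only the shared latent kv_latent plus the single RoPE key head k_pe, instead of per-head keys and values. At decode time (L == 1) the per-head key projection is absorbed into the query, so attention runs directly over the cached latent and the value projection is applied after attention; at prefill (L > 1) it is cheaper to expand the latent into per-head keys and values once. The RoPE contribution is computed separately as pe_scores and folded into the logits through the additive mask argument. A small sketch of why the absorption is exact, for one head and one cached position (w_uk is a hypothetical name for the per-head key up-projection):

import mlx.core as mx

d_nope, d_latent = 128, 512
q = mx.random.normal((d_nope,))              # "nope" part of one query head
c = mx.random.normal((d_latent,))            # cached KV latent for one position
w_uk = mx.random.normal((d_nope, d_latent))  # per-head key up-projection

naive = q @ (w_uk @ c)       # materialize the key, then take the dot product
absorbed = (w_uk.T @ q) @ c  # embed the query into the latent space instead
assert mx.allclose(naive, absorbed, rtol=1e-4, atol=1e-2)

The payoff is at decode: with DeepSeek-V3's configuration (kv_lora_rank 512, qk_rope_head_dim 64, 128 heads), the cache holds one 512-wide latent and one 64-wide RoPE key per token, and all heads attend over them MQA-style.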

@@ -329,7 +339,7 @@ def __call__(

         if cache is None:
             cache = [None] * len(self.pipeline_layers)
-        mask = create_attention_mask(h, cache[0])
+        mask = create_attention_mask(h, cache[0], return_array=True)

         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
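return_array=True is needed because the mask is no longer handed to the attention kernel as-is (where create_attention_mask may otherwise return the string "causal"): it is now applied element-wise to pe_scores with mx.where, which requires an actual boolean array. A toy illustration of that masking pattern, with arbitrary shapes:

import mlx.core as mx

scores = mx.random.normal((1, 1, 4, 4))  # [B, H, L, S] RoPE logits
idx = mx.arange(4)
causal = idx[:, None] >= idx[None, :]    # True where attending is allowed
neg = mx.array(mx.finfo(scores.dtype).min, scores.dtype)
masked = mx.where(causal, scores, neg)   # disallowed entries -> large negative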
@@ -423,6 +433,42 @@ def dequant(weight, scale_inv):
                    for e in range(self.args.n_routed_experts)
                ]
                weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+            prefix = f"model.layers.{l}.self_attn"
+            if f"{prefix}.kv_b_proj.weight" in weights:
+                layer = self.model.layers[l].self_attn.embed_q
+                quantized = f"{prefix}.kv_b_proj.scales" in weights
+                v = weights.pop(f"{prefix}.kv_b_proj.weight")
+                head_dim = self.args.qk_nope_head_dim + self.args.v_head_dim
+
+                if quantized:
+                    dims = self.args.kv_lora_rank
+                    scales = weights.pop(f"{prefix}.kv_b_proj.scales")
+                    biases = weights.pop(f"{prefix}.kv_b_proj.biases")
+                    # Try to infer bits and group size
+                    bits = (v.shape[-1] * 32) // dims
+                    group_size = dims // scales.shape[-1]
+                    v = mx.dequantize(
+                        v, scales, biases, bits=bits, group_size=group_size
+                    )
+                num_heads = self.args.num_attention_heads
+                v = v.reshape(num_heads, head_dim, -1)
+                wk = mx.contiguous(
+                    v[:, : self.args.qk_nope_head_dim, :].swapaxes(-1, -2)
+                )
+                wv = mx.contiguous(v[:, self.args.qk_nope_head_dim :, :])
+                if quantized:
+                    wk, wk_scales, wk_biases = mx.quantize(
+                        wk, bits=bits, group_size=group_size
+                    )
+                    wv, wv_scales, wv_biases = mx.quantize(
+                        wv, bits=bits, group_size=group_size
+                    )
+                    weights[f"{prefix}.embed_q.scales"] = wk_scales
+                    weights[f"{prefix}.unembed_out.scales"] = wv_scales
+                    weights[f"{prefix}.embed_q.biases"] = wk_biases
+                    weights[f"{prefix}.unembed_out.biases"] = wv_biases
+                weights[f"{prefix}.embed_q.weight"] = wk
+                weights[f"{prefix}.unembed_out.weight"] = wv

        # Remove multi-token prediction layer and any unused precomputed rotary freqs
        return {
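This conversion splits the stacked kv_b_proj weight, of shape [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank], into the two per-head stacks that embed_q and unembed_out expect, dequantizing and requantizing when the checkpoint is quantized. A shape-only sketch using DeepSeek-V3's dimensions:

import mlx.core as mx

num_heads, qk_nope, v_head, rank = 128, 128, 128, 512
w = mx.zeros((num_heads * (qk_nope + v_head), rank))  # kv_b_proj.weight

v = w.reshape(num_heads, qk_nope + v_head, rank)
wk = v[:, :qk_nope, :].swapaxes(-1, -2)  # [128, 512, 128] -> embed_q.weight
wv = v[:, qk_nope:, :]                   # [128, 128, 512] -> unembed_out.weight

mx.contiguous in the real code materializes the transposed slice in row-major order before it is stored or requantized.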
@@ -434,6 +480,7 @@ def dequant(weight, scale_inv):
     def shard(self, group: Optional[mx.distributed.Group] = None):
         group = group or mx.distributed.init()
         N = group.size()
+        rank = group.rank()
         for layer in self.model.layers:
             # Shard the self attention
             if layer.self_attn.q_lora_rank is None:
@@ -444,13 +491,20 @@ def shard(self, group: Optional[mx.distributed.Group] = None):
                layer.self_attn.q_b_proj = shard_linear(
                    layer.self_attn.q_b_proj, "all-to-sharded", group=group
                )
-            layer.self_attn.kv_b_proj = shard_linear(
-                layer.self_attn.kv_b_proj, "all-to-sharded", group=group
-            )
+            layer.self_attn.num_heads //= N
+            num_heads = layer.self_attn.num_heads
+            sh = rank * num_heads
+            eh = sh + num_heads
+
+            def shard_heads(w):
+                return w[sh:eh]
+
+            layer.self_attn.embed_q.apply(shard_heads)
+            layer.self_attn.unembed_out.apply(shard_heads)
+
            layer.self_attn.o_proj = shard_linear(
                layer.self_attn.o_proj, "sharded-to-all", group=group
            )
-            layer.self_attn.num_heads //= N

            # Shard the MLP
            if isinstance(layer.mlp, DeepseekV3MLP):
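With the per-head weights stacked on axis 0, tensor-parallel sharding becomes a plain slice: each rank keeps its own contiguous block of heads, which replaces the old shard_linear call on kv_b_proj. Since Module.apply maps the function over every parameter, the same slice presumably also covers the quantized scales and biases, which share the leading head axis. A small sketch of the slicing (shard_heads_for is an illustrative helper):

import mlx.core as mx

N, total_heads = 4, 128  # world size, heads before sharding
H = total_heads // N     # heads per rank

def shard_heads_for(rank):
    sh, eh = rank * H, rank * H + H
    return lambda w: w[sh:eh]  # slices [num_heads, out, in] on axis 0

w = mx.zeros((total_heads, 512, 128))
print(shard_heads_for(2)(w).shape)  # (32, 512, 128)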

mlx_lm/models/deepseek_v32.py

Lines changed: 78 additions & 23 deletions
@@ -11,6 +11,7 @@
 from .activations import swiglu
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
 from .cache import CacheList, KVCache
+from .mla import MultiLinear
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU


@@ -147,11 +148,11 @@ def __init__(self, config: ModelArgs):
             bias=config.attention_bias,
         )
         self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6)
-        self.kv_b_proj = nn.Linear(
-            self.kv_lora_rank,
-            self.num_heads
-            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
-            bias=False,
+        self.embed_q = MultiLinear(
+            self.qk_nope_head_dim, self.kv_lora_rank, self.num_heads
+        )
+        self.unembed_out = MultiLinear(
+            self.kv_lora_rank, self.v_head_dim, self.num_heads
         )

         self.o_proj = nn.Linear(
@@ -193,26 +194,19 @@ def __call__(
         compressed_kv = self.kv_a_proj_with_mqa(x)
         compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
         k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        kv_latent = self.kv_a_layernorm(compressed_kv)
+
+        offset = cache[0].offset if cache is not None else 0
+        q_pe = self.rope(q_pe, offset)
+        k_pe = self.rope(k_pe, offset)

-        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+        kv_latent = mx.expand_dims(kv_latent, axis=1)

         if cache is not None:
-            q_pe = self.rope(q_pe, cache[0].offset)
-            k_pe = self.rope(k_pe, cache[0].offset)
-            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
-            keys, values = cache[0].update_and_fetch(
-                mx.concatenate([k_nope, k_pe], axis=-1), values
-            )
+            kv_latent, k_pe = cache[0].update_and_fetch(kv_latent, k_pe)
         else:
             cache = [None] * 2
-            q_pe = self.rope(q_pe)
-            k_pe = self.rope(k_pe)
-            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
-            keys = mx.concatenate([k_nope, k_pe], axis=-1)

-        queries = mx.concatenate([q_nope, q_pe], axis=-1)
         topk_indices = self.indexer(x, qr, mask, cache=cache[1])
         if topk_indices is not None:
             shape = list(topk_indices.shape)
@@ -229,9 +223,27 @@
         if cache is not None and cache[0] is not None:
             cache[0].keys = mx.depends(cache[0].keys, (cache[1].keys, cache[1].values))

+        pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
+        if mask is not None:
+            pe_scores = mx.where(
+                mask,
+                pe_scores,
+                mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype),
+            )
+
+        if L == 1:
+            q_nope = self.embed_q(q_nope)
+            k = v = kv_latent
+        else:
+            k = self.embed_q(kv_latent, transpose=False)
+            v = self.unembed_out(kv_latent)
+
         output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache[0], scale=self.scale, mask=mask
+            q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
         )
+        if L == 1:
+            output = self.unembed_out(output)
+
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
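The deepseek_v32 attention mirrors the changes above, with its cache split in two: cache[0] is the latent KV cache and cache[1] belongs to the sparse indexer. The mx.depends call ties them together: MLX is lazy, so without an explicit graph edge the indexer's cache updates could be skipped or recomputed when only the attention output is evaluated. A tiny sketch of the idiom, mirroring the call in the hunk above:

import mlx.core as mx

keys = mx.zeros((4,))
idx_k, idx_v = mx.ones((4,)), mx.ones((4,))
# keys is numerically unchanged, but evaluating it now also
# materializes idx_k and idx_v.
keys = mx.depends(keys, (idx_k, idx_v))
mx.eval(keys)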

@@ -509,6 +521,41 @@ def dequant(weight, scale_inv):
                    for e in range(self.args.n_routed_experts)
                ]
                weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+            if f"{prefix}.kv_b_proj.weight" in weights:
+                layer = self.model.layers[l].self_attn.embed_q
+                quantized = f"{prefix}.kv_b_proj.scales" in weights
+                v = weights.pop(f"{prefix}.kv_b_proj.weight")
+                head_dim = self.args.qk_nope_head_dim + self.args.v_head_dim
+
+                if quantized:
+                    dims = self.args.kv_lora_rank
+                    scales = weights.pop(f"{prefix}.kv_b_proj.scales")
+                    biases = weights.pop(f"{prefix}.kv_b_proj.biases")
+                    # Try to infer bits and group size
+                    bits = (v.shape[-1] * 32) // dims
+                    group_size = dims // scales.shape[-1]
+                    v = mx.dequantize(
+                        v, scales, biases, bits=bits, group_size=group_size
+                    )
+                num_heads = self.args.num_attention_heads
+                v = v.reshape(num_heads, head_dim, -1)
+                wk = mx.contiguous(
+                    v[:, : self.args.qk_nope_head_dim, :].swapaxes(-1, -2)
+                )
+                wv = mx.contiguous(v[:, self.args.qk_nope_head_dim :, :])
+                if quantized:
+                    wk, wk_scales, wk_biases = mx.quantize(
+                        wk, bits=bits, group_size=group_size
+                    )
+                    wv, wv_scales, wv_biases = mx.quantize(
+                        wv, bits=bits, group_size=group_size
+                    )
+                    weights[f"{prefix}.embed_q.scales"] = wk_scales
+                    weights[f"{prefix}.unembed_out.scales"] = wv_scales
+                    weights[f"{prefix}.embed_q.biases"] = wk_biases
+                    weights[f"{prefix}.unembed_out.biases"] = wv_biases
+                weights[f"{prefix}.embed_q.weight"] = wk
+                weights[f"{prefix}.unembed_out.weight"] = wv

        # Remove multi-token prediction layer and any unused precomputed rotary freqs
        return {
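As the "Try to infer bits and group size" comment notes, the quantization parameters are recovered from the packed shapes rather than read from a config: mx.quantize packs bits-wide values into 32-bit words, so a tensor with dims input features has a packed last axis of dims * bits / 32, and one scale per group means scales.shape[-1] == dims // group_size. A quick check of those two identities:

import mlx.core as mx

dims, bits, group_size = 512, 4, 64
w = mx.random.normal((1024, dims))
wq, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

assert (wq.shape[-1] * 32) // dims == bits     # packed width recovers bits
assert dims // scales.shape[-1] == group_size  # scale count recovers group size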
@@ -520,17 +567,25 @@ def dequant(weight, scale_inv):
     def shard(self, group: Optional[mx.distributed.Group] = None):
         group = group or mx.distributed.init()
         N = group.size()
+        rank = group.rank()
         for layer in self.model.layers:
             layer.self_attn.q_b_proj = shard_linear(
                 layer.self_attn.q_b_proj, "all-to-sharded", group=group
             )
-            layer.self_attn.kv_b_proj = shard_linear(
-                layer.self_attn.kv_b_proj, "all-to-sharded", group=group
-            )
+
             layer.self_attn.o_proj = shard_linear(
                 layer.self_attn.o_proj, "sharded-to-all", group=group
             )
             layer.self_attn.num_heads //= N
+            num_heads = layer.self_attn.num_heads
+            sh = rank * num_heads
+            eh = sh + num_heads
+
+            def shard_heads(w):
+                return w[sh:eh]
+
+            layer.self_attn.embed_q.apply(shard_heads)
+            layer.self_attn.unembed_out.apply(shard_heads)

             # Shard the MLP
             if isinstance(layer.mlp, DeepseekV32MLP):
