 
 from .activations import swiglu
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .mla import MultiLinear
 from .pipeline import PipelineMixin
 from .rope_utils import initialize_rope
 from .switch_layers import SwitchGLU
@@ -85,11 +86,11 @@ def __init__(self, config: ModelArgs):
             bias=config.attention_bias,
         )
         self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6)
-        self.kv_b_proj = nn.Linear(
-            self.kv_lora_rank,
-            self.num_heads
-            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
-            bias=False,
+        self.embed_q = MultiLinear(
+            self.qk_nope_head_dim, self.kv_lora_rank, self.num_heads
+        )
+        self.unembed_out = MultiLinear(
+            self.kv_lora_rank, self.v_head_dim, self.num_heads
         )
 
         self.o_proj = nn.Linear(
@@ -132,29 +133,38 @@ def __call__(
         compressed_kv = self.kv_a_proj_with_mqa(x)
         compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
         k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        kv_latent = self.kv_a_layernorm(compressed_kv)
+
+        offset = cache.offset if cache is not None else 0
+        q_pe = self.rope(q_pe, offset)
+        k_pe = self.rope(k_pe, offset)
 
-        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+        kv_latent = mx.expand_dims(kv_latent, axis=1)
 
         if cache is not None:
-            q_pe = self.rope(q_pe, cache.offset)
-            k_pe = self.rope(k_pe, cache.offset)
-            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
-            keys, values = cache.update_and_fetch(
-                mx.concatenate([k_nope, k_pe], axis=-1), values
+            kv_latent, k_pe = cache.update_and_fetch(kv_latent, k_pe)
+
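+        # Scores from the decoupled RoPE dimensions; the boolean mask is folded in
+        # here and the result is added to the latent-space scores by
+        # scaled_dot_product_attention below.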
+        pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
+        if mask is not None:
+            pe_scores = mx.where(
+                mask,
+                pe_scores,
+                mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype),
             )
-        else:
-            q_pe = self.rope(q_pe)
-            k_pe = self.rope(k_pe)
-            k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
-            keys = mx.concatenate([k_nope, k_pe], axis=-1)
 
-        queries = mx.concatenate([q_nope, q_pe], axis=-1)
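+        # Decode (L == 1): absorb the key projection into the query and attend in the
+        # compressed KV latent space; prefill: expand the latent into per-head keys
+        # and values instead.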
+        if L == 1:
+            q_nope = self.embed_q(q_nope)
+            k = v = kv_latent
+        else:
+            k = self.embed_q(kv_latent, transpose=False)
+            v = self.unembed_out(kv_latent)
 
         output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+            q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
         )
+        if L == 1:
+            output = self.unembed_out(output)
+
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
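
A minimal sketch of why attending directly in the KV latent space matches the original kv_b_proj path, assuming embed_q applies the transposed nope-key slice of kv_b_proj per head (shapes chosen only for illustration):

import mlx.core as mx

d_nope, r, T = 128, 512, 16          # nope head dim, KV latent rank, sequence length
Wk = mx.random.normal((d_nope, r))   # per-head nope-key slice of kv_b_proj.weight
q = mx.random.normal((d_nope,))      # one query head (nope part only)
C = mx.random.normal((T, r))         # cached KV latents for T positions

scores_expand = (C @ Wk.T) @ q       # old path: expand latents to per-head keys, dot with q
scores_absorb = C @ (Wk.T @ q)       # new path: project q into the latent space once
print(mx.allclose(scores_expand, scores_absorb, atol=1e-3))  # True
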
@@ -329,7 +339,7 @@ def __call__(
 
         if cache is None:
             cache = [None] * len(self.pipeline_layers)
-        mask = create_attention_mask(h, cache[0])
+        mask = create_attention_mask(h, cache[0], return_array=True)
 
         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
@@ -423,6 +433,42 @@ def dequant(weight, scale_inv):
                             for e in range(self.args.n_routed_experts)
                         ]
                         weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+            prefix = f"model.layers.{l}.self_attn"
+            if f"{prefix}.kv_b_proj.weight" in weights:
+                layer = self.model.layers[l].self_attn.embed_q
+                quantized = f"{prefix}.kv_b_proj.scales" in weights
+                v = weights.pop(f"{prefix}.kv_b_proj.weight")
+                head_dim = self.args.qk_nope_head_dim + self.args.v_head_dim
+
+                if quantized:
+                    dims = self.args.kv_lora_rank
+                    scales = weights.pop(f"{prefix}.kv_b_proj.scales")
+                    biases = weights.pop(f"{prefix}.kv_b_proj.biases")
+                    # Try to infer bits and group size
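+                    # e.g. kv_lora_rank = 512 with 4-bit, group-size-64 quantization:
+                    # the packed weight has 512 * 4 / 32 = 64 columns -> bits = 64 * 32 // 512 = 4,
+                    # and the scales have 512 / 64 = 8 columns -> group_size = 512 // 8 = 64.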
+                    bits = (v.shape[-1] * 32) // dims
+                    group_size = dims // scales.shape[-1]
+                    v = mx.dequantize(
+                        v, scales, biases, bits=bits, group_size=group_size
+                    )
+                num_heads = self.args.num_attention_heads
+                v = v.reshape(num_heads, head_dim, -1)
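+                # Split kv_b_proj per head into a key half (transposed so it can project
+                # queries into the latent space) and a value half.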
+                wk = mx.contiguous(
+                    v[:, : self.args.qk_nope_head_dim, :].swapaxes(-1, -2)
+                )
+                wv = mx.contiguous(v[:, self.args.qk_nope_head_dim :, :])
+                if quantized:
+                    wk, wk_scales, wk_biases = mx.quantize(
+                        wk, bits=bits, group_size=group_size
+                    )
+                    wv, wv_scales, wv_biases = mx.quantize(
+                        wv, bits=bits, group_size=group_size
+                    )
+                    weights[f"{prefix}.embed_q.scales"] = wk_scales
+                    weights[f"{prefix}.unembed_out.scales"] = wv_scales
+                    weights[f"{prefix}.embed_q.biases"] = wk_biases
+                    weights[f"{prefix}.unembed_out.biases"] = wv_biases
+                weights[f"{prefix}.embed_q.weight"] = wk
+                weights[f"{prefix}.unembed_out.weight"] = wv
 
         # Remove multi-token prediction layer and any unused precomputed rotary freqs
         return {
@@ -434,6 +480,7 @@ def dequant(weight, scale_inv):
     def shard(self, group: Optional[mx.distributed.Group] = None):
         group = group or mx.distributed.init()
         N = group.size()
+        rank = group.rank()
         for layer in self.model.layers:
             # Shard the self attention
             if layer.self_attn.q_lora_rank is None:
@@ -444,13 +491,20 @@ def shard(self, group: Optional[mx.distributed.Group] = None):
                 layer.self_attn.q_b_proj = shard_linear(
                     layer.self_attn.q_b_proj, "all-to-sharded", group=group
                 )
-            layer.self_attn.kv_b_proj = shard_linear(
-                layer.self_attn.kv_b_proj, "all-to-sharded", group=group
-            )
+            layer.self_attn.num_heads //= N
+            num_heads = layer.self_attn.num_heads
+            sh = rank * num_heads
+            eh = sh + num_heads
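+            # sh:eh is this rank's contiguous slice of heads in the per-head
+            # embed_q / unembed_out weights.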
+
+            def shard_heads(w):
+                return w[sh:eh]
+
+            layer.self_attn.embed_q.apply(shard_heads)
+            layer.self_attn.unembed_out.apply(shard_heads)
+
             layer.self_attn.o_proj = shard_linear(
                 layer.self_attn.o_proj, "sharded-to-all", group=group
             )
-            layer.self_attn.num_heads //= N
 
             # Shard the MLP
             if isinstance(layer.mlp, DeepseekV3MLP):