@@ -159,8 +159,15 @@ def _process_chunk(
159159 k = inv_scale * mx .fast .rms_norm (k , None , 1e-6 )
160160
161161 out , new_ssm_state = gated_delta_update (
162- q , k , v , a_chunk , b_chunk ,
163- self .A_log , self .dt_bias , ssm_state , ssm_mask ,
162+ q ,
163+ k ,
164+ v ,
165+ a_chunk ,
166+ b_chunk ,
167+ self .A_log ,
168+ self .dt_bias ,
169+ ssm_state ,
170+ ssm_mask ,
164171 use_kernel = not self .training ,
165172 )
166173 return out , new_conv_state , new_ssm_state
@@ -185,7 +192,9 @@ def __call__(
185192 conv_state = (
186193 cache [0 ]
187194 if cache is not None and cache [0 ] is not None
188- else mx .zeros ((B , self .conv_kernel_size - 1 , self .conv_dim ), dtype = inputs .dtype )
195+ else mx .zeros (
196+ (B , self .conv_kernel_size - 1 , self .conv_dim ), dtype = inputs .dtype
197+ )
189198 )
190199 ssm_state = cache [1 ] if cache else None
191200
@@ -198,18 +207,28 @@ def __call__(
198207 mask_c = mask [:, :n_confirmed ] if mask is not None else None
199208 mask_d = mask [:, n_confirmed :] if mask is not None else None
200209 out_c , conv_c , ssm_c = self ._process_chunk (
201- qkv [:, :n_confirmed ], a [:, :n_confirmed ], b [:, :n_confirmed ],
202- conv_state , ssm_state , mask_c ,
210+ qkv [:, :n_confirmed ],
211+ a [:, :n_confirmed ],
212+ b [:, :n_confirmed ],
213+ conv_state ,
214+ ssm_state ,
215+ mask_c ,
203216 )
204217 if cache is not None :
205218 cache .rollback_state = (conv_c , ssm_c )
206219 out_d , conv_f , ssm_f = self ._process_chunk (
207- qkv [:, n_confirmed :], a [:, n_confirmed :], b [:, n_confirmed :],
208- conv_c , ssm_c , mask_d ,
220+ qkv [:, n_confirmed :],
221+ a [:, n_confirmed :],
222+ b [:, n_confirmed :],
223+ conv_c ,
224+ ssm_c ,
225+ mask_d ,
209226 )
210227 out = mx .concatenate ([out_c , out_d ], axis = 1 )
211228 else :
212- out , conv_f , ssm_f = self ._process_chunk (qkv , a , b , conv_state , ssm_state , mask )
229+ out , conv_f , ssm_f = self ._process_chunk (
230+ qkv , a , b , conv_state , ssm_state , mask
231+ )
213232
214233 if cache is not None :
215234 cache [0 ] = conv_f
@@ -251,7 +270,9 @@ def __call__(
251270 n_confirmed : int = 0 ,
252271 ) -> mx .array :
253272 if self .is_linear :
254- r = self .linear_attn (self .input_layernorm (x ), mask , cache , n_confirmed = n_confirmed )
273+ r = self .linear_attn (
274+ self .input_layernorm (x ), mask , cache , n_confirmed = n_confirmed
275+ )
255276 else :
256277 r = self .self_attn (self .input_layernorm (x ), mask , cache )
257278 h = x + r
@@ -266,7 +287,9 @@ def __init__(self, args: TextModelArgs):
266287 super ().__init__ ()
267288 self .self_attn = Attention (args )
268289 self .input_layernorm = nn .RMSNorm (args .hidden_size , eps = args .rms_norm_eps )
269- self .post_attention_layernorm = nn .RMSNorm (args .hidden_size , eps = args .rms_norm_eps )
290+ self .post_attention_layernorm = nn .RMSNorm (
291+ args .hidden_size , eps = args .rms_norm_eps
292+ )
270293 if args .num_experts > 0 :
271294 self .mlp = SparseMoeBlock (args )
272295 else :
@@ -295,9 +318,7 @@ def __init__(self, args: TextModelArgs):
295318 self .pre_fc_norm_hidden = nn .RMSNorm (args .hidden_size , eps = args .rms_norm_eps )
296319 self .pre_fc_norm_embedding = nn .RMSNorm (args .hidden_size , eps = args .rms_norm_eps )
297320 self .fc = nn .Linear (args .hidden_size * 2 , args .hidden_size , bias = False )
298- self .layers = [
299- MTPDecoderLayer (args ) for _ in range (args .mtp_num_hidden_layers )
300- ]
321+ self .layers = [MTPDecoderLayer (args ) for _ in range (args .mtp_num_hidden_layers )]
301322 self .norm = nn .RMSNorm (args .hidden_size , eps = args .rms_norm_eps )
302323
303324 def __call__ (
@@ -355,7 +376,11 @@ def __call__(
355376
356377 for layer , c in zip (self .layers , cache ):
357378 mask = ssm_mask if layer .is_linear else fa_mask
358- kw = {"n_confirmed" : n_confirmed } if layer .is_linear and n_confirmed > 0 else {}
379+ kw = (
380+ {"n_confirmed" : n_confirmed }
381+ if layer .is_linear and n_confirmed > 0
382+ else {}
383+ )
359384 hidden_states = layer (hidden_states , mask = mask , cache = c , ** kw )
360385
361386 return hidden_states
@@ -380,7 +405,9 @@ def __call__(
380405 return_hidden : bool = False ,
381406 n_confirmed : int = 0 ,
382407 ) -> mx .array :
383- hidden = self .model (inputs , cache , input_embeddings = input_embeddings , n_confirmed = n_confirmed )
408+ hidden = self .model (
409+ inputs , cache , input_embeddings = input_embeddings , n_confirmed = n_confirmed
410+ )
384411 normed = self .model .norm (hidden )
385412 if self .args .tie_word_embeddings :
386413 out = self .model .embed_tokens .as_linear (normed )
0 commit comments