Skip to content

Commit fb12ea6

Browse files
committed
style: apply black and isort formatting
1 parent 29e1ee2 commit fb12ea6

3 files changed

Lines changed: 57 additions & 22 deletions

File tree

mlx_lm/generate.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import argparse
44
import contextlib
55
import functools
6-
import warnings
76
import json
87
import sys
98
import time
9+
import warnings
1010
from dataclasses import dataclass
1111
from functools import partial
1212
from typing import (
@@ -710,7 +710,9 @@ def _process_and_sample(tokens, logits):
710710
def _step_backbone(y, n_predict=1, n_confirmed=0):
711711
"""Run the backbone on ``y`` and return (tokens, logprobs, hidden)."""
712712
with mx.stream(generation_stream):
713-
logits, hidden = model(y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed)
713+
logits, hidden = model(
714+
y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed
715+
)
714716
logits = logits[:, -n_predict:, :]
715717
quantize_cache_fn(model_cache)
716718
nonlocal prev_tokens
@@ -778,11 +780,13 @@ def _prefill(y):
778780
y_with_draft = mx.concatenate(
779781
[y, mx.array([draft_tok.item()], mx.uint32)]
780782
)
781-
toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1)
783+
toks, lps, hidden = _step_backbone(
784+
y_with_draft, n_predict=2, n_confirmed=1
785+
)
782786
mx.eval(toks, draft_tok)
783787

784-
verify_pred = toks[0] # backbone prediction after y → verify draft
785-
bonus_tok = toks[1] # backbone prediction after draft_tok
788+
verify_pred = toks[0] # backbone prediction after y → verify draft
789+
bonus_tok = toks[1] # backbone prediction after draft_tok
786790
verify_lp = lps[0]
787791
bonus_lp = lps[1]
788792

@@ -812,7 +816,10 @@ def _prefill(y):
812816
# by GatedDeltaNet after the confirmed token.
813817
# Attention layers (KVCache): trim the draft-token entry.
814818
for c in model_cache:
815-
if hasattr(c, "rollback_state") and c.rollback_state is not None:
819+
if (
820+
hasattr(c, "rollback_state")
821+
and c.rollback_state is not None
822+
):
816823
conv_snap, ssm_snap = c.rollback_state
817824
c[0] = conv_snap
818825
c[1] = ssm_snap

mlx_lm/models/qwen3_5.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,15 @@ def _process_chunk(
159159
k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)
160160

161161
out, new_ssm_state = gated_delta_update(
162-
q, k, v, a_chunk, b_chunk,
163-
self.A_log, self.dt_bias, ssm_state, ssm_mask,
162+
q,
163+
k,
164+
v,
165+
a_chunk,
166+
b_chunk,
167+
self.A_log,
168+
self.dt_bias,
169+
ssm_state,
170+
ssm_mask,
164171
use_kernel=not self.training,
165172
)
166173
return out, new_conv_state, new_ssm_state
@@ -185,7 +192,9 @@ def __call__(
185192
conv_state = (
186193
cache[0]
187194
if cache is not None and cache[0] is not None
188-
else mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype)
195+
else mx.zeros(
196+
(B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype
197+
)
189198
)
190199
ssm_state = cache[1] if cache else None
191200

@@ -198,18 +207,28 @@ def __call__(
198207
mask_c = mask[:, :n_confirmed] if mask is not None else None
199208
mask_d = mask[:, n_confirmed:] if mask is not None else None
200209
out_c, conv_c, ssm_c = self._process_chunk(
201-
qkv[:, :n_confirmed], a[:, :n_confirmed], b[:, :n_confirmed],
202-
conv_state, ssm_state, mask_c,
210+
qkv[:, :n_confirmed],
211+
a[:, :n_confirmed],
212+
b[:, :n_confirmed],
213+
conv_state,
214+
ssm_state,
215+
mask_c,
203216
)
204217
if cache is not None:
205218
cache.rollback_state = (conv_c, ssm_c)
206219
out_d, conv_f, ssm_f = self._process_chunk(
207-
qkv[:, n_confirmed:], a[:, n_confirmed:], b[:, n_confirmed:],
208-
conv_c, ssm_c, mask_d,
220+
qkv[:, n_confirmed:],
221+
a[:, n_confirmed:],
222+
b[:, n_confirmed:],
223+
conv_c,
224+
ssm_c,
225+
mask_d,
209226
)
210227
out = mx.concatenate([out_c, out_d], axis=1)
211228
else:
212-
out, conv_f, ssm_f = self._process_chunk(qkv, a, b, conv_state, ssm_state, mask)
229+
out, conv_f, ssm_f = self._process_chunk(
230+
qkv, a, b, conv_state, ssm_state, mask
231+
)
213232

214233
if cache is not None:
215234
cache[0] = conv_f
@@ -251,7 +270,9 @@ def __call__(
251270
n_confirmed: int = 0,
252271
) -> mx.array:
253272
if self.is_linear:
254-
r = self.linear_attn(self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed)
273+
r = self.linear_attn(
274+
self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed
275+
)
255276
else:
256277
r = self.self_attn(self.input_layernorm(x), mask, cache)
257278
h = x + r
@@ -266,7 +287,9 @@ def __init__(self, args: TextModelArgs):
266287
super().__init__()
267288
self.self_attn = Attention(args)
268289
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
269-
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
290+
self.post_attention_layernorm = nn.RMSNorm(
291+
args.hidden_size, eps=args.rms_norm_eps
292+
)
270293
if args.num_experts > 0:
271294
self.mlp = SparseMoeBlock(args)
272295
else:
@@ -295,9 +318,7 @@ def __init__(self, args: TextModelArgs):
295318
self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
296319
self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
297320
self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False)
298-
self.layers = [
299-
MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers)
300-
]
321+
self.layers = [MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers)]
301322
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
302323

303324
def __call__(
@@ -355,7 +376,11 @@ def __call__(
355376

356377
for layer, c in zip(self.layers, cache):
357378
mask = ssm_mask if layer.is_linear else fa_mask
358-
kw = {"n_confirmed": n_confirmed} if layer.is_linear and n_confirmed > 0 else {}
379+
kw = (
380+
{"n_confirmed": n_confirmed}
381+
if layer.is_linear and n_confirmed > 0
382+
else {}
383+
)
359384
hidden_states = layer(hidden_states, mask=mask, cache=c, **kw)
360385

361386
return hidden_states
@@ -380,7 +405,9 @@ def __call__(
380405
return_hidden: bool = False,
381406
n_confirmed: int = 0,
382407
) -> mx.array:
383-
hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed)
408+
hidden = self.model(
409+
inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed
410+
)
384411
normed = self.model.norm(hidden)
385412
if self.args.tie_word_embeddings:
386413
out = self.model.embed_tokens.as_linear(normed)

tests/test_mtp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import unittest
33

44
import mlx.core as mx
5-
from mlx_lm.models.cache import make_prompt_cache
5+
66
from mlx_lm.generate import generate_step, mtp_generate_step
7+
from mlx_lm.models.cache import make_prompt_cache
78

89

910
def _make_qwen3_5_mtp_model():

0 commit comments

Comments
 (0)