Skip to content

Commit 219fea7

Browse files
committed
fix qwen3.5 moe reader
1 parent 7d6ae4f commit 219fea7

3 files changed

Lines changed: 104 additions & 81 deletions

File tree

lmdeploy/turbomind/deploy/source_model/qwen.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import json
33
import os.path as osp
4+
import re
45

56
import torch
67

78
from ..config import RopeParam
9+
from ..loader import create_loader
810
from .base import INPUT_MODELS
911
from .llama import LlamaModel, LlamaReader
1012

@@ -383,13 +385,58 @@ def model_info(self):
383385

384386

385387
class Qwen3_5MoeReader(Qwen3_5ReaderMixin, Qwen3MoeReader):
    """Weight reader for Qwen3.5 MoE checkpoints.

    Supports the "packed" checkpoint layout, where the weights of all
    experts of a layer are stacked along the first dimension of a single
    tensor, in addition to the per-expert layout handled by the base class.
    """

    def _unpacked_moe_expert(self, e: int, i: int, kind: str):
        """Slice expert ``e`` of layer ``i`` out of a packed checkpoint.

        Returns a ``(gate, down, up)`` tuple, or ``None`` when the packed
        tensors are not present (i.e. the checkpoint uses the per-expert
        layout and the caller should fall back to the base class).
        """
        base = f'{self.attn_layer_prefix}.{i}.mlp.experts'
        packed_gate_up = self.params.get(f'{base}.gate_up_proj.{kind}')
        packed_down = self.params.get(f'{base}.down_proj.{kind}')
        if packed_gate_up is None or packed_down is None:
            return None

        # Packed Qwen3.5 MoE checkpoints stack all experts along dim 0.
        # Slice a single expert out *before* applying the transform so that
        # quantized policies still operate on a 2-D tensor.
        expert_gate_up = self.transform(packed_gate_up[e], kind)
        expert_down = self.transform(packed_down[e], kind)
        gate, up = expert_gate_up.chunk(2, dim=0)
        return gate, expert_down, up

    def moe_ffn_expert(self, e=None, i=None, kind=None):
        """Return the FFN weights of expert ``e`` in layer ``i``.

        Prefers the packed layout; delegates to the base class when the
        packed tensors are absent.
        """
        if not kind:
            return self.filter(r'experts', i)
        result = self._unpacked_moe_expert(e, i, kind)
        if result is None:
            result = super().moe_ffn_expert(e, i, kind)
        return result
387412

388413

389414
@INPUT_MODELS.register_module(name='qwen3_5-moe')
390415
class Qwen3_5MoeModel(Qwen3MoeModel):
391416
Reader = Qwen3_5MoeReader
392417

418+
@staticmethod
419+
def map_packed_qwen35_experts(name: str):
420+
"""Map packed expert names to weight names, i.e.,
421+
"mlp.experts.gate_up_proj" -> "mlp.experts.gate_up_proj.weight" so that
422+
class Weight in parameter.py can classify them."""
423+
s = re.sub(r'(mlp\.experts\.(?:gate_up|down)_proj)$', r'\1.weight', name)
424+
return s
425+
426+
def readers(self):
    """Yield ``(shard_index, reader)`` pairs over the checkpoint shards.

    When the checkpoint stores experts in the packed layout (all experts
    stacked inside ``mlp.experts.{gate_up,down}_proj``), a name mapping is
    installed on the loader so those entries gain a ``.weight`` suffix and
    can be classified downstream.
    """
    # Some Reader classes still expose the misspelled 'attn_layer_patten';
    # prefer the corrected attribute name when available.
    pattern = getattr(self.Reader, 'attn_layer_pattern', self.Reader.attn_layer_patten)
    loader = create_loader(self.model_path, pattern, [])

    index_keys = loader.index.keys()
    packed_gate_up = any('mlp.experts.gate_up_proj' in key for key in index_keys)
    packed_down = any('mlp.experts.down_proj' in key for key in index_keys)
    if packed_gate_up and packed_down:
        loader.mappings = [self.map_packed_qwen35_experts]

    for shard_idx, param in loader.items():
        yield shard_idx, self.Reader(param, {}, False, self.model_config, policy=self.policy, fp8_quant=self.fp8_quant)
    torch.cuda.empty_cache()
393440
def model_info(self):
394441
if 'text_config' in self.model_config:
395442
self.model_config = self.model_config['text_config']

lmdeploy/utils.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from contextlib import contextmanager
99
from dataclasses import dataclass
1010
from logging import Logger, LogRecord
11-
from typing import List, Optional, Tuple, Union
1211

1312
import torch
1413
from transformers import PretrainedConfig
@@ -26,7 +25,7 @@ class _ASNI_COLOR:
2625

2726
# copy from: https://github.com/termcolor/termcolor
2827
@functools.cache
29-
def can_colorize(*, no_color: Optional[bool] = None, force_color: Optional[bool] = None) -> bool:
28+
def can_colorize(*, no_color: bool | None = None, force_color: bool | None = None) -> bool:
3029
"""Check env vars and for tty/dumb terminal."""
3130
import io
3231
if no_color is not None and no_color:
@@ -110,8 +109,8 @@ def filter(self, record: LogRecord) -> bool:
110109
' - %(message)s'
111110

112111

113-
def get_logger(name: Optional[str] = None,
114-
log_file: Optional[str] = None,
112+
def get_logger(name: str | None = None,
113+
log_file: str | None = None,
115114
log_level: int = logging.INFO,
116115
file_mode: str = 'a',
117116
log_formatter: str = _FORMAT) -> Logger:
@@ -178,7 +177,7 @@ def get_logger(name: Optional[str] = None,
178177
return logger
179178

180179

181-
def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
180+
def filter_suffix(response: str, suffixes: list[str] | None = None) -> str:
182181
"""Filter response with suffixes.
183182
184183
Args:
@@ -197,12 +196,12 @@ def filter_suffix(response: str, suffixes: Optional[List[str]] = None) -> str:
197196

198197

199198
# TODO remove stop_word_offsets stuff and make it clean
200-
def _stop_words(stop_words: List[Union[int, str]], tokenizer: object):
199+
def _stop_words(stop_words: list[int | str], tokenizer: object):
201200
"""Return list of stop-words to numpy.ndarray."""
202201
import numpy as np
203202
if stop_words is None:
204203
return None
205-
assert isinstance(stop_words, List) and \
204+
assert isinstance(stop_words, list) and \
206205
all(isinstance(elem, (str, int)) for elem in stop_words), \
207206
f'stop_words must be a list but got {type(stop_words)}'
208207
stop_indexes = []
@@ -211,7 +210,7 @@ def _stop_words(stop_words: List[Union[int, str]], tokenizer: object):
211210
stop_indexes += tokenizer.indexes_containing_token(stop_word)
212211
elif isinstance(stop_word, int):
213212
stop_indexes.append(stop_word)
214-
assert isinstance(stop_indexes, List) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words'
213+
assert isinstance(stop_indexes, list) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words'
215214
# each id in stop_indexes represents a stop word
216215
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
217216
# detailed explanation about fastertransformer's stop_indexes
@@ -297,7 +296,7 @@ async def __tmp():
297296
# modified from https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/config.py#L1082-L1150 # noqa
298297
def _get_and_verify_max_len(
299298
hf_config: PretrainedConfig,
300-
max_model_len: Optional[int],
299+
max_model_len: int | None,
301300
) -> int:
302301
"""Get and verify the model's maximum length."""
303302

@@ -326,7 +325,11 @@ def _get_and_verify_max_len(
326325
]
327326
max_len_key = None
328327
for key in possible_keys:
329-
max_len = getattr(hf_config, key, None)
328+
max_len = None
329+
if hasattr(hf_config, key):
330+
max_len = getattr(hf_config, key)
331+
elif key in hf_config:
332+
max_len = hf_config[key]
330333
if max_len is not None:
331334
max_len_key = key if max_len < derived_max_model_len \
332335
else max_len_key
@@ -503,9 +506,9 @@ class FlattenedTensorBucket:
503506

504507
def __init__(
505508
self,
506-
named_tensors: List[Tuple[str, torch.Tensor]] = None,
509+
named_tensors: list[tuple[str, torch.Tensor]] | None = None,
507510
flattened_tensor: torch.Tensor = None,
508-
metadata: List[FlattenedTensorMetadata] = None,
511+
metadata: list[FlattenedTensorMetadata] | None = None,
509512
):
510513
"""Initialize a tensor bucket from a list of named tensors or from pre-
511514
flattened data.
@@ -548,11 +551,11 @@ def get_flattened_tensor(self) -> torch.Tensor:
548551
"""Get the flattened tensor containing multiple tensors."""
549552
return self.flattened_tensor
550553

551-
def get_metadata(self) -> List[FlattenedTensorMetadata]:
554+
def get_metadata(self) -> list[FlattenedTensorMetadata]:
552555
"""Get all metadatas for all tensors in the bucket."""
553556
return self.metadata
554557

555-
def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
558+
def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]:
556559
"""Reconstruct original tensors."""
557560
# preallocate the result list
558561
reconstructed = [None] * len(self.metadata)

tests/test_lmdeploy/test_qwen3coder_parser.py

Lines changed: 39 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,10 @@
77
import shortuuid
88

99
from lmdeploy.serve.openai.api_server import VariableInterface
10-
from lmdeploy.serve.openai.protocol import (
11-
ChatCompletionRequest, ChatCompletionResponse,
12-
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
13-
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, DeltaToolCall,
14-
UsageInfo)
15-
from lmdeploy.serve.openai.tool_parser.qwen3coder_parser import (
16-
Qwen3CoderToolParser)
10+
from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
11+
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
12+
ChatMessage, DeltaMessage, DeltaToolCall, UsageInfo)
13+
from lmdeploy.serve.openai.tool_parser.qwen3coder_parser import Qwen3CoderToolParser
1714

1815
TestExpects = collections.namedtuple('TestExpects', 'func_name kwargs')
1916

@@ -57,30 +54,26 @@ def encode(self, text: str) -> List[int]:
5754

5855

5956
def _chat_completion_v1(
60-
request: ChatCompletionRequest, text_sequence: List[str]
61-
) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse,
62-
None, None]]:
57+
request: ChatCompletionRequest,
58+
text_sequence: List[str]) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse, None, None]]:
6359
request_id = f'chat-{shortuuid.random()}'
6460
created_time = int(time.time())
6561
model_name = request.model
6662
if request.stream:
6763

68-
def completion_stream_generator(
69-
) -> Generator[ChatCompletionStreamResponse, None, None]:
64+
def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:
7065
previous_text = ''
7166
current_text = ''
7267
finish_reason = 'stop'
73-
has_parser = (VariableInterface.tool_parser is not None
74-
or VariableInterface.reasoning_parser is not None)
68+
has_parser = (VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None)
7569
for text in text_sequence:
7670
logprobs, usage = None, None
7771
delta_message = DeltaMessage(role='assistant', content=text)
7872
if has_parser:
7973
current_text = current_text + text
8074
has_tool = VariableInterface.tool_parser is not None
8175
if request.tool_choice != 'none' and has_tool:
82-
tool_delta = VariableInterface.tool_parser\
83-
.extract_tool_calls_streaming(
76+
tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
8477
previous_text=previous_text,
8578
current_text=current_text,
8679
delta_text=delta_message.content,
@@ -93,25 +86,22 @@ def completion_stream_generator(
9386
delta_message.content = tool_delta.content or ''
9487
if VariableInterface.reasoning_parser is not None:
9588
parser = VariableInterface.reasoning_parser
96-
reasoning_delta = parser.extract_reasoning_content_streaming(
97-
previous_text=previous_text,
98-
current_text=current_text,
99-
delta_text=delta_message.content,
100-
previous_token_ids=[],
101-
current_token_ids=[],
102-
delta_token_ids=[])
89+
reasoning_delta = parser.extract_reasoning_content_streaming(previous_text=previous_text,
90+
current_text=current_text,
91+
delta_text=delta_message.content,
92+
previous_token_ids=[],
93+
current_token_ids=[],
94+
delta_token_ids=[])
10395
if reasoning_delta is not None:
104-
delta_message.reasoning_content = (
105-
reasoning_delta.reasoning_content)
96+
delta_message.reasoning_content = (reasoning_delta.reasoning_content)
10697
delta_message.content = reasoning_delta.content or ''
10798
if has_parser:
10899
previous_text = current_text
109100

110-
choice_data = ChatCompletionResponseStreamChoice(
111-
index=0,
112-
delta=delta_message,
113-
finish_reason=finish_reason,
114-
logprobs=logprobs)
101+
choice_data = ChatCompletionResponseStreamChoice(index=0,
102+
delta=delta_message,
103+
finish_reason=finish_reason,
104+
logprobs=logprobs)
115105
response = ChatCompletionStreamResponse(
116106
id=request_id,
117107
created=created_time,
@@ -129,25 +119,20 @@ def completion_stream_generator(
129119
finish_reason = 'stop'
130120
has_tool = VariableInterface.tool_parser is not None
131121
if request.tool_choice != 'none' and has_tool:
132-
tool_call_info = VariableInterface.tool_parser.extract_tool_calls(
133-
text, request=request)
122+
tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
134123
text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
135124
if isinstance(tool_calls, List) and len(tool_calls):
136125
if finish_reason == 'stop':
137126
finish_reason = 'tool_calls'
138127

139128
if VariableInterface.reasoning_parser is not None:
140129
parser = VariableInterface.reasoning_parser
141-
reasoning_content, text = parser.extract_reasoning_content(
142-
text, request)
130+
reasoning_content, text = parser.extract_reasoning_content(text, request)
143131

144132
choices = []
145133
choice_data = ChatCompletionResponseChoice(
146134
index=0,
147-
message=ChatMessage(role='assistant',
148-
content=text,
149-
tool_calls=tool_calls,
150-
reasoning_content=reasoning_content),
135+
message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),
151136
finish_reason=finish_reason,
152137
)
153138
choices.append(choice_data)
@@ -161,9 +146,7 @@ def completion_stream_generator(
161146
)
162147

163148

164-
def _stream_parse(
165-
request: ChatCompletionRequest,
166-
text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:
149+
def _stream_parse(request: ChatCompletionRequest, text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:
167150
content = ''
168151
reasoning_content = ''
169152
tool_calls = {}
@@ -184,19 +167,16 @@ def _stream_parse(
184167
if c.function.name:
185168
existing_call.function.name = c.function.name
186169
if c.function.arguments:
187-
existing_call.function.arguments = (
188-
existing_call.function.arguments or '')
170+
existing_call.function.arguments = (existing_call.function.arguments or '')
189171
existing_call.function.arguments += c.function.arguments
190-
return content, reasoning_content, list(
191-
sorted(tool_calls.values(), key=lambda x: x.index))
172+
return content, reasoning_content, list(sorted(tool_calls.values(), key=lambda x: x.index))
192173

193174

194175
@pytest.mark.parametrize(('text_sequence', 'expects'), [
195-
(DELTA_TEXT_SEQUENCE,
196-
[TestExpects('get_weather', {
197-
'location': '北京',
198-
'unit': 'celsius'
199-
})]),
176+
(DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
177+
'location': '北京',
178+
'unit': 'celsius'
179+
})]),
200180
(DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
201181
TestExpects('get_weather', {
202182
'location': '北京',
@@ -209,11 +189,8 @@ def test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):
209189
tokenizer = DummyTokenizer()
210190
VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
211191
VariableInterface.reasoning_parser = None
212-
request = ChatCompletionRequest(model='qwen3coder',
213-
messages=[],
214-
stream=True)
215-
content, reasoning_content, tool_calls = _stream_parse(
216-
request, text_sequence)
192+
request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=True)
193+
content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)
217194
assert len(tool_calls) == len(expects)
218195
for parsed_call, expected_call in zip(tool_calls, expects):
219196
assert parsed_call.function.name == expected_call.func_name
@@ -223,11 +200,10 @@ def test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):
223200

224201

225202
@pytest.mark.parametrize(('text_sequence', 'expects'), [
226-
(DELTA_TEXT_SEQUENCE,
227-
[TestExpects('get_weather', {
228-
'location': '北京',
229-
'unit': 'celsius'
230-
})]),
203+
(DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {
204+
'location': '北京',
205+
'unit': 'celsius'
206+
})]),
231207
(DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [
232208
TestExpects('get_weather', {
233209
'location': '北京',
@@ -236,14 +212,12 @@ def test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):
236212
TestExpects('get_weather', {'location': '上海'})
237213
]),
238214
])
239-
def test_parser_nonstream(text_sequence: List[str],
240-
expects: List[TestExpects]):
215+
def test_parser_nonstream(text_sequence: List[str], expects: List[TestExpects]):
241216
tokenizer = DummyTokenizer()
242217
VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
243218
VariableInterface.reasoning_parser = None
244219
resp: ChatCompletionResponse = _chat_completion_v1(
245-
ChatCompletionRequest(model='qwen3coder', messages=[], stream=False),
246-
text_sequence)
220+
ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)
247221

248222
assert len(resp.choices) == 1
249223
first_message = resp.choices[0].message
@@ -273,8 +247,7 @@ def test_no_think_nonstream():
273247
VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)
274248
VariableInterface.reasoning_parser = None
275249
resp: ChatCompletionResponse = _chat_completion_v1(
276-
ChatCompletionRequest(model='qwen3coder', messages=[], stream=False),
277-
text_sequence)
250+
ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)
278251

279252
assert len(resp.choices) == 1
280253
first_message = resp.choices[0].message

0 commit comments

Comments
 (0)