
Commit 57fd13a

[Bugfix] Fix profiling dummy data for Pixtral (#18677)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 3a886bd commit 57fd13a

5 files changed: +153 -170 lines changed

tests/models/multimodal/processing/test_common.py

Lines changed: 105 additions & 156 deletions
@@ -9,15 +9,15 @@
                                                         UserMessage)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.inputs import MultiModalInputs
 from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
-from vllm.transformers_utils.tokenizer import (MistralTokenizer,
-                                               cached_tokenizer_from_config)
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
+                                               cached_tokenizer_from_config,
+                                               encode_tokens)

 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -28,7 +28,6 @@ def _test_processing_correctness(
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
-    ignore_mm_keys: Optional[set[str]] = None,
 ):
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_available_online(on_fail="skip")
@@ -99,10 +98,23 @@ def _test_processing_correctness(
         }

         mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-        prompt = dummy_inputs.get_dummy_processor_inputs(
-            model_config.max_model_len,
-            mm_counts,
-        ).prompt_text
+
+        # Mistral chat outputs tokens directly, rather than text prompts
+        if isinstance(tokenizer, MistralTokenizer):
+            images = mm_data.get("image", [])
+            request = ChatCompletionRequest(messages=[
+                UserMessage(content=[
+                    TextChunk(text=""),
+                    *(ImageChunk(image=image) for image in images),
+                ]),
+            ])
+            res = tokenizer.mistral.encode_chat_completion(request)
+            prompt = res.tokens
+        else:
+            prompt = dummy_inputs.get_dummy_processor_inputs(
+                model_config.max_model_len,
+                mm_counts,
+            ).prompt

         # Drop unnecessary keys and test single -> multi conversion
         if rng.rand() < simplify_rate:
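
Note on the hunk above: a Mistral-format tokenizer emits token IDs for a whole chat request, so the Pixtral profiling prompt has to be built as tokens rather than as a text prompt. A rough standalone sketch of that path, assuming `tokenizer` is vLLM's MistralTokenizer wrapper and `images` is a list of PIL images; the helper name is illustrative only, the calls mirror the test code above:

from PIL import Image

from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest

from vllm.transformers_utils.tokenizer import MistralTokenizer


def dummy_mistral_token_prompt(tokenizer: MistralTokenizer,
                               images: list[Image.Image]) -> list[int]:
    # One empty text chunk plus one ImageChunk per dummy image.
    request = ChatCompletionRequest(messages=[
        UserMessage(content=[
            TextChunk(text=""),
            *(ImageChunk(image=image) for image in images),
        ]),
    ])
    # mistral_common tokenizes the request as a whole; the result is already a
    # list of token ids, so there is no intermediate text prompt to hand over.
    return tokenizer.mistral.encode_chat_completion(request).tokens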
@@ -112,124 +124,66 @@
             elif len(mm_data[k]) == 1:
                 mm_data[k] = mm_data[k][0]

-        if isinstance(tokenizer, MistralTokenizer):
-            _test_processing_correctness_mistral(
-                model_config,
-                tokenizer,
-                prompt,
-                mm_data,
-                baseline_processor,
-                cached_processor,
-                batch_idx,
-                ignore_mm_keys=ignore_mm_keys,
-            )
-        else:
-            _test_processing_correctness_hf(
-                model_config,
-                tokenizer,
-                prompt,
-                mm_data,
-                baseline_processor,
-                cached_processor,
-                batch_idx,
-                ignore_mm_keys=ignore_mm_keys,
-            )
-
-
-def _test_processing_correctness_hf(
+        _test_processing_correctness_one(
+            model_config,
+            tokenizer,
+            prompt,
+            mm_data,
+            baseline_processor,
+            cached_processor,
+            batch_idx,
+        )
+
+
+# For some multimodal models, tokenizer will always add bos_token
+# at the beginning of prompt by default, causing hf_processor outputs
+# incorrect token ids. So we need use `add_special_tokens=False` here
+# to leave bos_token to be added by the processor.
+_ADD_SPECIAL_TOKENS_OVERRIDES = {
+    "mllama": False,
+    "ovis": False,
+    "ultravox": False,
+    "whisper": False,
+}
+
+_IGNORE_MM_KEYS = {
+    # In Ultravox, the audio_features can be different depending on padding
+    # The slight difference should not be a problem though, since
+    # attention_mask lets us ignore the difference.
+    "ultravox": {"audio_features"},
+}
+
+
+def _test_processing_correctness_one(
     model_config: ModelConfig,
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prompt: str,
+    tokenizer: AnyTokenizer,
+    prompt: Union[str, list[int]],
     mm_data: MultiModalDataDict,
     baseline_processor: BaseMultiModalProcessor,
     cached_processor: BaseMultiModalProcessor,
     batch_idx: int,
-    ignore_mm_keys: Optional[set[str]] = None,
 ):
-    if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
-                                             "whisper"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    model_type = model_config.hf_config.model_type
+    ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
+
+    if isinstance(prompt, str):
+        text_prompt = prompt
+        token_prompt = encode_tokens(
+            tokenizer,
+            prompt,
+            add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
+        )
     else:
-        token_prompt = tokenizer.encode(prompt)
-
-    baseline_result = baseline_processor.apply(
-        prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-    cached_result = cached_processor.apply(
-        prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
-    _assert_inputs_equal(
-        baseline_result,
-        cached_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
+        # Mistral does not support decode_tokens with skip_special_tokens=False
+        text_prompt = None
+        token_prompt = prompt

     baseline_tokenized_result = baseline_processor.apply(
         token_prompt,
         mm_data=mm_data,
         hf_processor_mm_kwargs={},
     )

-    _assert_inputs_equal(
-        baseline_result,
-        baseline_tokenized_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
-
-    cached_tokenized_result = cached_processor.apply(
-        token_prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
-    _assert_inputs_equal(
-        cached_result,
-        cached_tokenized_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
-
-
-def _test_processing_correctness_mistral(
-    model_config: ModelConfig,
-    tokenizer: MistralTokenizer,
-    prompt: str,
-    mm_data: MultiModalDataDict,
-    baseline_processor: BaseMultiModalProcessor,
-    cached_processor: BaseMultiModalProcessor,
-    batch_idx: int,
-    ignore_mm_keys: Optional[set[str]] = None,
-):
-    images = mm_data.get("image", [])
-    if not isinstance(images, list):
-        images = [images]
-
-    request = ChatCompletionRequest(messages=[
-        UserMessage(content=[
-            TextChunk(text=prompt),
-            *(ImageChunk(image=image) for image in images),
-        ]),
-    ])
-    res = tokenizer.mistral.encode_chat_completion(request)
-    token_prompt = res.tokens
-
-    # Mistral chat outputs tokens directly, rather than text prompts
-    baseline_tokenized_result = baseline_processor.apply(
-        token_prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
     cached_tokenized_result = cached_processor.apply(
         token_prompt,
         mm_data=mm_data,
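
Note on _ADD_SPECIAL_TOKENS_OVERRIDES above: for the listed model types the tokenizer prepends bos_token on every encode() call while the model's HF processor inserts it again, so encoding the text prompt with the defaults would duplicate BOS. A small illustration with a stock transformers tokenizer; the checkpoint below is just one public example of a BOS-prepending tokenizer, not something this test depends on:

from transformers import AutoTokenizer

# "hf-internal-testing/llama-tokenizer" is a small public Llama tokenizer that
# adds bos_token by default; any tokenizer with that behaviour works the same.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

with_special = tok.encode("a photo of a cat")
without_special = tok.encode("a photo of a cat", add_special_tokens=False)

assert with_special[0] == tok.bos_token_id   # BOS inserted by the tokenizer
assert without_special == with_special[1:]   # suppressed, left for the processor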
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
         baseline_tokenized_result,
         cached_tokenized_result,
         ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+        msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
     )

+    if text_prompt is not None:
+        baseline_text_result = baseline_processor.apply(
+            text_prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+        cached_text_result = cached_processor.apply(
+            text_prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        _assert_inputs_equal(
+            baseline_text_result,
+            cached_text_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
+        )
+
+        _assert_inputs_equal(
+            baseline_text_result,
+            baseline_tokenized_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
+                f"{token_prompt=}, {mm_data=})",
+        )
+
+        _assert_inputs_equal(
+            cached_text_result,
+            cached_tokenized_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
+                f"{token_prompt=}, {mm_data=})",
+        )
+

 # yapf: disable
 @pytest.mark.parametrize("model_id", [
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
     "AIDC-AI/Ovis2-1B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
+    "microsoft/Phi-3.5-vision-instruct",
     "microsoft/Phi-4-multimodal-instruct",
     "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
@@ -303,41 +293,6 @@ def test_processing_correctness(
     num_batches: int,
     simplify_rate: float,
 ):
-    ignore_mm_keys = None
-    if 'ultravox' in model_id:
-        # In Ultravox, the audio_features can be different depending on padding
-        # The slight difference should not be a problem though, since
-        # attention_mask lets us ignore the difference.
-        ignore_mm_keys = {"audio_features"}
-
-    _test_processing_correctness(
-        model_id,
-        hit_rate=hit_rate,
-        num_batches=num_batches,
-        simplify_rate=simplify_rate,
-        ignore_mm_keys=ignore_mm_keys,
-    )
-
-
-# yapf: disable
-@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
-@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
-@pytest.mark.parametrize("num_batches", [32])
-@pytest.mark.parametrize("simplify_rate", [1.0])
-# yapf: enable
-def test_processing_correctness_phi3v(
-    model_id: str,
-    hit_rate: float,
-    num_batches: int,
-    simplify_rate: float,
-):
-    # HACK - this is an attempted workaround for the following bug
-    # https://github.com/huggingface/transformers/issues/34307
-    from transformers import AutoImageProcessor  # noqa: F401
-    from transformers import AutoProcessor  # noqa: F401
-
-    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
-
     _test_processing_correctness(
         model_id,
         hit_rate=hit_rate,
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
     if ignore_mm_keys is None:
         ignore_mm_keys = set()

-    if msg is None:
-        assert "mm_kwargs" in a and "mm_kwargs" in b
-    else:
-        assert "mm_kwargs" in a and "mm_kwargs" in b, msg
+    assert "mm_kwargs" in a and "mm_kwargs" in b, msg

     for key in ignore_mm_keys:
         a["mm_kwargs"].pop(key, None)
         b["mm_kwargs"].pop(key, None)

-    if msg is None:
-        assert a == b
-    else:
-        assert a == b, msg
+    assert a == b, msg

tests/models/multimodal/processing/test_mllama.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def test_profiling(
     ] * max_num_seqs

     mm_kwargs = processor.apply(
-        prompt=dummy_mm_data.prompt_text,
+        prompt=dummy_mm_data.prompt,
         mm_data=dummy_mm_data.mm_data,
         hf_processor_mm_kwargs=dict(),
     )["mm_kwargs"]

tests/models/registry.py

Lines changed: 5 additions & 4 deletions
@@ -8,6 +8,8 @@
 from packaging.version import Version
 from transformers import __version__ as TRANSFORMERS_VERSION

+from vllm.config import TokenizerMode
+

 @dataclass(frozen=True)
 class _HfExamplesInfo:
@@ -20,7 +22,7 @@ class _HfExamplesInfo:
     tokenizer: Optional[str] = None
     """Set the tokenizer to load for this architecture."""

-    tokenizer_mode: str = "auto"
+    tokenizer_mode: TokenizerMode = "auto"
     """Set the tokenizer type for this architecture."""

     speculative_model: Optional[str] = None
@@ -388,8 +390,7 @@ def check_available_online(
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
-                                                       tokenizer_mode="mistral",
-                                                       v0_only=True),
+                                                       tokenizer_mode="mistral"),
     "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
                                                       extras={"chat": "Qwen/Qwen-VL-Chat"},  # noqa: E501
                                                       trust_remote_code=True,
@@ -400,7 +401,7 @@ def check_available_online(
     "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
                                         min_transformers_version="4.52"),
     "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ",  # noqa: E501
-                                        min_transformers_version="4.52"),
+                                                           min_transformers_version="4.52"),
     "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
     "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"),  # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
