                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
-from vllm.transformers_utils.tokenizer import (MistralTokenizer,
-                                               cached_tokenizer_from_config)
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
+                                               cached_tokenizer_from_config,
+                                               encode_tokens)

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -28,7 +28,6 @@ def _test_processing_correctness(
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
-    ignore_mm_keys: Optional[set[str]] = None,
):
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_available_online(on_fail="skip")
@@ -99,10 +98,23 @@ def _test_processing_correctness(
        }

        mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-        prompt = dummy_inputs.get_dummy_processor_inputs(
-            model_config.max_model_len,
-            mm_counts,
-        ).prompt_text
+
+        # Mistral chat outputs tokens directly, rather than text prompts
+        if isinstance(tokenizer, MistralTokenizer):
+            images = mm_data.get("image", [])
+            request = ChatCompletionRequest(messages=[
+                UserMessage(content=[
+                    TextChunk(text=""),
+                    *(ImageChunk(image=image) for image in images),
+                ]),
+            ])
+            res = tokenizer.mistral.encode_chat_completion(request)
+            prompt = res.tokens
+        else:
+            prompt = dummy_inputs.get_dummy_processor_inputs(
+                model_config.max_model_len,
+                mm_counts,
+            ).prompt

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
@@ -112,124 +124,66 @@ def _test_processing_correctness(
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

-        if isinstance(tokenizer, MistralTokenizer):
-            _test_processing_correctness_mistral(
-                model_config,
-                tokenizer,
-                prompt,
-                mm_data,
-                baseline_processor,
-                cached_processor,
-                batch_idx,
-                ignore_mm_keys=ignore_mm_keys,
-            )
-        else:
-            _test_processing_correctness_hf(
-                model_config,
-                tokenizer,
-                prompt,
-                mm_data,
-                baseline_processor,
-                cached_processor,
-                batch_idx,
-                ignore_mm_keys=ignore_mm_keys,
-            )
-
-
-def _test_processing_correctness_hf(
+        _test_processing_correctness_one(
+            model_config,
+            tokenizer,
+            prompt,
+            mm_data,
+            baseline_processor,
+            cached_processor,
+            batch_idx,
+        )
+
+
+# For some multimodal models, tokenizer will always add bos_token
+# at the beginning of prompt by default, causing hf_processor outputs
+# incorrect token ids. So we need use `add_special_tokens=False` here
+# to leave bos_token to be added by the processor.
+_ADD_SPECIAL_TOKENS_OVERRIDES = {
+    "mllama": False,
+    "ovis": False,
+    "ultravox": False,
+    "whisper": False,
+}
+
+_IGNORE_MM_KEYS = {
+    # In Ultravox, the audio_features can be different depending on padding
+    # The slight difference should not be a problem though, since
+    # attention_mask lets us ignore the difference.
+    "ultravox": {"audio_features"},
+}
+
+
+def _test_processing_correctness_one(
    model_config: ModelConfig,
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prompt: str,
+    tokenizer: AnyTokenizer,
+    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    baseline_processor: BaseMultiModalProcessor,
    cached_processor: BaseMultiModalProcessor,
    batch_idx: int,
-    ignore_mm_keys: Optional[set[str]] = None,
):
-    if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
-                                             "whisper"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    model_type = model_config.hf_config.model_type
+    ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
+
+    if isinstance(prompt, str):
+        text_prompt = prompt
+        token_prompt = encode_tokens(
+            tokenizer,
+            prompt,
+            add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
+        )
    else:
-        token_prompt = tokenizer.encode(prompt)
-
-    baseline_result = baseline_processor.apply(
-        prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-    cached_result = cached_processor.apply(
-        prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
-    _assert_inputs_equal(
-        baseline_result,
-        cached_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
+        # Mistral does not support decode_tokens with skip_special_tokens=False
+        text_prompt = None
+        token_prompt = prompt

    baseline_tokenized_result = baseline_processor.apply(
        token_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

-    _assert_inputs_equal(
-        baseline_result,
-        baseline_tokenized_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
-
-    cached_tokenized_result = cached_processor.apply(
-        token_prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
-    _assert_inputs_equal(
-        cached_result,
-        cached_tokenized_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
-
-
-def _test_processing_correctness_mistral(
-    model_config: ModelConfig,
-    tokenizer: MistralTokenizer,
-    prompt: str,
-    mm_data: MultiModalDataDict,
-    baseline_processor: BaseMultiModalProcessor,
-    cached_processor: BaseMultiModalProcessor,
-    batch_idx: int,
-    ignore_mm_keys: Optional[set[str]] = None,
-):
-    images = mm_data.get("image", [])
-    if not isinstance(images, list):
-        images = [images]
-
-    request = ChatCompletionRequest(messages=[
-        UserMessage(content=[
-            TextChunk(text=prompt),
-            *(ImageChunk(image=image) for image in images),
-        ]),
-    ])
-    res = tokenizer.mistral.encode_chat_completion(request)
-    token_prompt = res.tokens
-
-    # Mistral chat outputs tokens directly, rather than text prompts
-    baseline_tokenized_result = baseline_processor.apply(
-        token_prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
    cached_tokenized_result = cached_processor.apply(
        token_prompt,
        mm_data=mm_data,
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
        baseline_tokenized_result,
        cached_tokenized_result,
        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+        msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
    )

+    if text_prompt is not None:
+        baseline_text_result = baseline_processor.apply(
+            text_prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+        cached_text_result = cached_processor.apply(
+            text_prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        _assert_inputs_equal(
+            baseline_text_result,
+            cached_text_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
+        )
+
+        _assert_inputs_equal(
+            baseline_text_result,
+            baseline_tokenized_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
+                f"{token_prompt=}, {mm_data=})",
+        )
+
+        _assert_inputs_equal(
+            cached_text_result,
+            cached_tokenized_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
+                f"{token_prompt=}, {mm_data=})",
+        )
+

# yapf: disable
@pytest.mark.parametrize("model_id", [
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
    "AIDC-AI/Ovis2-1B",
    "google/paligemma-3b-mix-224",
    "google/paligemma2-3b-ft-docci-448",
+    "microsoft/Phi-3.5-vision-instruct",
    "microsoft/Phi-4-multimodal-instruct",
    "mistralai/Pixtral-12B-2409",
    "mistral-community/pixtral-12b",
@@ -303,41 +293,6 @@ def test_processing_correctness(
    num_batches: int,
    simplify_rate: float,
):
-    ignore_mm_keys = None
-    if 'ultravox' in model_id:
-        # In Ultravox, the audio_features can be different depending on padding
-        # The slight difference should not be a problem though, since
-        # attention_mask lets us ignore the difference.
-        ignore_mm_keys = {"audio_features"}
-
-    _test_processing_correctness(
-        model_id,
-        hit_rate=hit_rate,
-        num_batches=num_batches,
-        simplify_rate=simplify_rate,
-        ignore_mm_keys=ignore_mm_keys,
-    )
-
-
-# yapf: disable
-@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
-@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
-@pytest.mark.parametrize("num_batches", [32])
-@pytest.mark.parametrize("simplify_rate", [1.0])
-# yapf: enable
-def test_processing_correctness_phi3v(
-    model_id: str,
-    hit_rate: float,
-    num_batches: int,
-    simplify_rate: float,
-):
-    # HACK - this is an attempted workaround for the following bug
-    # https://github.com/huggingface/transformers/issues/34307
-    from transformers import AutoImageProcessor  # noqa: F401
-    from transformers import AutoProcessor  # noqa: F401
-
-    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
-
    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
    if ignore_mm_keys is None:
        ignore_mm_keys = set()

-    if msg is None:
-        assert "mm_kwargs" in a and "mm_kwargs" in b
-    else:
-        assert "mm_kwargs" in a and "mm_kwargs" in b, msg
+    assert "mm_kwargs" in a and "mm_kwargs" in b, msg

    for key in ignore_mm_keys:
        a["mm_kwargs"].pop(key, None)
        b["mm_kwargs"].pop(key, None)

-    if msg is None:
-        assert a == b
-    else:
-        assert a == b, msg
+    assert a == b, msg
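
A minimal sketch of the `bos_token` duplication that the new `_ADD_SPECIAL_TOKENS_OVERRIDES` table guards against, assuming any Hugging Face tokenizer that prepends BOS by default; the model id below is only an illustrative choice and is not one used by this test:

```python
from transformers import AutoTokenizer

# Illustrative tokenizer only (assumption); any tokenizer that prepends a BOS
# token by default shows the same behaviour.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

with_bos = tok.encode("Hello", add_special_tokens=True)
without_bos = tok.encode("Hello", add_special_tokens=False)

assert with_bos[0] == tok.bos_token_id       # BOS already inserted by encode()
assert without_bos[0] != tok.bos_token_id    # left for the processor to add
assert with_bos[1:] == without_bos           # token content otherwise identical
```

Passing `add_special_tokens=False` for the listed model types (mllama, ovis, ultravox, whisper) leaves the single BOS insertion to the multimodal processor, so the text and token prompt paths produce the same ids.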