
Commit 2619605

protobird-git authored and copybara-github committed

Update the converter to support multi-signature in multimodal models.

- Pass pixel_seq_len explicitly to specify how many tokens would be reserved for images.

PiperOrigin-RevId: 748706889

1 parent 8594417

3 files changed: +36 -14 lines
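
Concretely, callers of converter.convert_to_tflite now pass pixel_seq_len alongside pixel_values_size. A minimal sketch of the updated call, assuming a multimodal PyTorch model has already been built; every concrete value below is a placeholder, and only the keyword names come from this commit:

import torch
from ai_edge_torch.generative.utilities import converter

converter.convert_to_tflite(
    pytorch_model,                       # assumed: a pre-built multimodal model
    output_path='/tmp/',                 # placeholder output directory
    output_name_prefix='my_model',       # placeholder model name prefix
    prefill_seq_len=[128, 512],          # placeholder prefill lengths
    pixel_values_size=torch.Size([1, 3, 224, 224]),  # placeholder image shape
    pixel_seq_len=256,                   # tokens reserved for image embeddings
    quantize=True,
)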

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py (1 addition, 0 deletions)

@@ -48,6 +48,7 @@ def main(_):
       pixel_values_size=torch.Size(
           [1, config.channels, config.image_size, config.image_size]
       ),
+      pixel_seq_len=(config.image_size // config.patch_size) ** 2,
       quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
       export_config=ExportConfig(),
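
The value passed for PaliGemma is the standard ViT patch count: the image is cut into a square grid of patch_size x patch_size patches, and each patch becomes one image token. A quick worked example with assumed numbers (this diff does not show the config values):

# Assumed values for illustration; the real ones come from PaliGemma's
# image-encoder config.
image_size = 224   # assumed input resolution
patch_size = 14    # assumed ViT patch size
pixel_seq_len = (image_size // patch_size) ** 2
print(pixel_seq_len)  # (224 // 14) ** 2 == 16 ** 2 == 256 image tokens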

ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py (7 additions, 0 deletions)

@@ -43,6 +43,9 @@ def main(_):
   )

   grid_thw = pytorch_model.image_encoder.get_grid_thw()
+  spatial_merge_size = (
+      pytorch_model.config.image_encoder_config.spatial_merge_size
+  )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -51,6 +54,10 @@ def main(_):
       pixel_values_size=(
           pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
       ),
+      pixel_seq_len=(
+          (grid_thw[0][1] // spatial_merge_size)
+          * (grid_thw[0][2] // spatial_merge_size)
+      ),
       quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
       export_config=ExportConfig(),
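
For Qwen-VL the count is derived from the encoder's (t, h, w) patch grid instead: neighboring patches are merged spatially before reaching the decoder, so each spatial axis is divided by spatial_merge_size. A sketch with assumed numbers (the actual grid and merge factor are model-specific and not shown here):

# Assumed grid for illustration: one image split into a 32 x 32 patch grid.
grid_thw = [(1, 32, 32)]   # (t, h, w) as reported by the image encoder
spatial_merge_size = 2     # assumed merge factor from the encoder config
pixel_seq_len = (grid_thw[0][1] // spatial_merge_size) * (
    grid_thw[0][2] // spatial_merge_size
)
print(pixel_seq_len)  # (32 // 2) * (32 // 2) == 256 tokens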

ai_edge_torch/generative/utilities/converter.py (28 additions, 14 deletions)

@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
   )
   flags.DEFINE_string(
       'output_name_prefix',
-      f'{model_name}',
+      model_name,
       'The prefix of the output tflite model name.',
   )
   flags.DEFINE_multi_integer(
@@ -91,6 +91,7 @@ def convert_to_tflite(
     output_name_prefix: str,
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
+    pixel_seq_len: int = 0,
     quantize: bool = True,
     config: cfg.ModelConfig = None,
     lora_ranks: Optional[list[int]] = None,
@@ -133,12 +134,18 @@ def convert_to_tflite(
       use. If a list, the model will have multiple prefill signatures.
     pixel_values_size (torch.Size, optional): The size of pixel values to pass
       to the model. If None, the model is not expected to take pixel values.
+    pixel_seq_len (int, optional): The length of pixel tokens, or pixel
+      embeddings generated by the image encoder with pixel values. The actual
+      length of prefill_seq_len will be added by pixel_seq_len when pixel
+      values are passed.
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
     config (cfg.ModelConfig, optional): The model config used to configure KV
       cache. If None, it uses the config of the pytorch_model.
     lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
       no LoRA signatures will be added.
+    export_config (ExportConfig, optional): The export configuration. If None,
+      it uses the default export configuration.
   """
   # pylint: disable=protected-access
   torch._dynamo.config.cache_size_limit = 64
@@ -173,6 +180,7 @@ def convert_to_tflite(
       output_file,
       prefill_seq_lens,
       pixel_values_size,
+      pixel_seq_len,
       quantize,
       config,
       loras,
@@ -185,6 +193,7 @@ def _export_helper(
     output_file: str,
     prefill_seq_lens: list[int],
     pixel_values_size: torch.Size,
+    pixel_seq_len: int,
     quantize: bool,
     config: cfg.ModelConfig,
     loras: list[None | lora_utils.LoRA],
@@ -197,11 +206,18 @@ def _export_helper(
     prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
     prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))

-  prefill_pixel_values = (
-      torch.full(pixel_values_size, 0, dtype=torch.float32)
-      if pixel_values_size
-      else None
-  )
+  prefill_pixel_values = None
+  prefill_tokens_list_with_pixel = []
+  prefill_input_pos_list_with_pixel = []
+  if pixel_values_size is not None:
+    prefill_pixel_values = torch.full(pixel_values_size, 0, dtype=torch.float32)
+    for seq_len in prefill_seq_lens:
+      prefill_tokens_list_with_pixel.append(
+          torch.full((1, seq_len + pixel_seq_len), 0, dtype=torch.int)
+      )
+      prefill_input_pos_list_with_pixel.append(
+          torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
+      )

   if export_config.prefill_mask is None:
     prefill_masks = None
@@ -238,13 +254,11 @@ def _export_helper(
   for lora in loras:
     for i in range(len(prefill_seq_lens)):
       prefill_seq_len = prefill_seq_lens[i]
-      prefill_tokens = prefill_tokens_list[i]
-      prefill_input_pos = prefill_input_pos_list[i]
       prefill_signature_name = f'prefill_{prefill_seq_len}'

       sample_kwargs = {
-          'tokens': prefill_tokens,
-          'input_pos': prefill_input_pos,
+          'tokens': prefill_tokens_list[i],
+          'input_pos': prefill_input_pos_list[i],
           'kv_cache': prefill_kv,
       }
       if prefill_masks is not None:
@@ -261,13 +275,13 @@ def _export_helper(
       )

       if prefill_pixel_values is not None:
+        sample_kwargs['tokens'] = prefill_tokens_list_with_pixel[i]
+        sample_kwargs['input_pos'] = prefill_input_pos_list_with_pixel[i]
+        sample_kwargs['pixel_values'] = prefill_pixel_values
         converter.add_signature(
             prefill_signature_name + '_pixel',
             mod,
-            sample_kwargs={
-                **sample_kwargs,
-                'pixel_values': prefill_pixel_values,
-            },
+            sample_kwargs=sample_kwargs,
         )

       sample_kwargs = {
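
Net effect of the converter change: each prefill length still gets a text-only signature traced from tokens of exactly that length, while the corresponding _pixel signature is now traced with tokens and positions stretched by pixel_seq_len, reserving room for the image embeddings. A minimal sketch of the dummy inputs, assuming a single prefill length of 128 and a pixel_seq_len of 256:

import torch

seq_len, pixel_seq_len = 128, 256  # assumed example values

# 'prefill_128' (text only): tokens of shape (1, 128), positions 0..127.
tokens = torch.full((1, seq_len), 0, dtype=torch.int)
input_pos = torch.arange(0, seq_len, dtype=torch.int)

# 'prefill_128_pixel': 256 extra slots for image embeddings, so tokens
# have shape (1, 384) and positions run 0..383.
tokens_pixel = torch.full((1, seq_len + pixel_seq_len), 0, dtype=torch.int)
input_pos_pixel = torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)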
