@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
57
57
)
58
58
flags .DEFINE_string (
59
59
'output_name_prefix' ,
60
- f' { model_name } ' ,
60
+ model_name ,
61
61
'The prefix of the output tflite model name.' ,
62
62
)
63
63
flags .DEFINE_multi_integer (
@@ -91,6 +91,7 @@ def convert_to_tflite(
91
91
output_name_prefix : str ,
92
92
prefill_seq_len : Union [int , list [int ]],
93
93
pixel_values_size : torch .Size = None ,
94
+ pixel_seq_len : int = 0 ,
94
95
quantize : bool = True ,
95
96
config : cfg .ModelConfig = None ,
96
97
lora_ranks : Optional [list [int ]] = None ,
@@ -133,12 +134,18 @@ def convert_to_tflite(
133
134
use. If a list, the model will have multiple prefill signatures.
134
135
pixel_values_size (torch.Size, optional): The size of pixel values to pass
135
136
to the model. If None, the model is not expected to take pixel values.
137
+ pixel_seq_len (int, optional): The length of pixel tokens, or pixel
138
+ embeddings generated by the image encoder with pixel values. The actual
139
+ prefill length is prefill_seq_len plus pixel_seq_len when pixel
140
+ values are passed.
136
141
quantize (bool, optional): Whether the model should be quantized. Defaults
137
142
to True.
138
143
config (cfg.ModelConfig, optional): The model config used to configure KV
139
144
cache. If None, it uses the config of the pytorch_model.
140
145
lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
141
146
no LoRA signatures will be added.
147
+ export_config (ExportConfig, optional): The export configuration. If None,
148
+ it uses the default export configuration.
142
149
"""
143
150
# pylint: disable=protected-access
144
151
torch ._dynamo .config .cache_size_limit = 64
@@ -173,6 +180,7 @@ def convert_to_tflite(
173
180
output_file ,
174
181
prefill_seq_lens ,
175
182
pixel_values_size ,
183
+ pixel_seq_len ,
176
184
quantize ,
177
185
config ,
178
186
loras ,
@@ -185,6 +193,7 @@ def _export_helper(
185
193
output_file : str ,
186
194
prefill_seq_lens : list [int ],
187
195
pixel_values_size : torch .Size ,
196
+ pixel_seq_len : int ,
188
197
quantize : bool ,
189
198
config : cfg .ModelConfig ,
190
199
loras : list [None | lora_utils .LoRA ],
@@ -197,11 +206,18 @@ def _export_helper(
197
206
prefill_tokens_list .append (torch .full ((1 , seq_len ), 0 , dtype = torch .int ))
198
207
prefill_input_pos_list .append (torch .arange (0 , seq_len , dtype = torch .int ))
199
208
200
- prefill_pixel_values = (
201
- torch .full (pixel_values_size , 0 , dtype = torch .float32 )
202
- if pixel_values_size
203
- else None
204
- )
209
+ prefill_pixel_values = None
210
+ prefill_tokens_list_with_pixel = []
211
+ prefill_input_pos_list_with_pixel = []
212
+ if pixel_values_size is not None :
213
+ prefill_pixel_values = torch .full (pixel_values_size , 0 , dtype = torch .float32 )
214
+ for seq_len in prefill_seq_lens :
215
+ prefill_tokens_list_with_pixel .append (
216
+ torch .full ((1 , seq_len + pixel_seq_len ), 0 , dtype = torch .int )
217
+ )
218
+ prefill_input_pos_list_with_pixel .append (
219
+ torch .arange (0 , seq_len + pixel_seq_len , dtype = torch .int )
220
+ )
205
221
206
222
if export_config .prefill_mask is None :
207
223
prefill_masks = None
@@ -238,13 +254,11 @@ def _export_helper(
238
254
for lora in loras :
239
255
for i in range (len (prefill_seq_lens )):
240
256
prefill_seq_len = prefill_seq_lens [i ]
241
- prefill_tokens = prefill_tokens_list [i ]
242
- prefill_input_pos = prefill_input_pos_list [i ]
243
257
prefill_signature_name = f'prefill_{ prefill_seq_len } '
244
258
245
259
sample_kwargs = {
246
- 'tokens' : prefill_tokens ,
247
- 'input_pos' : prefill_input_pos ,
260
+ 'tokens' : prefill_tokens_list [ i ] ,
261
+ 'input_pos' : prefill_input_pos_list [ i ] ,
248
262
'kv_cache' : prefill_kv ,
249
263
}
250
264
if prefill_masks is not None :
@@ -261,13 +275,13 @@ def _export_helper(
261
275
)
262
276
263
277
if prefill_pixel_values is not None :
278
+ sample_kwargs ['tokens' ] = prefill_tokens_list_with_pixel [i ]
279
+ sample_kwargs ['input_pos' ] = prefill_input_pos_list_with_pixel [i ]
280
+ sample_kwargs ['pixel_values' ] = prefill_pixel_values
264
281
converter .add_signature (
265
282
prefill_signature_name + '_pixel' ,
266
283
mod ,
267
- sample_kwargs = {
268
- ** sample_kwargs ,
269
- 'pixel_values' : prefill_pixel_values ,
270
- },
284
+ sample_kwargs = sample_kwargs ,
271
285
)
272
286
273
287
sample_kwargs = {
0 commit comments