@@ -126,42 +126,118 @@ def _pad_to_max_length(
     current_segments,
     pad_token_id,
     device,
-    padding="right",
+    padding_side="right",
+    padding="longest",
     bos_token_tensor=None,
     cut_off_length=None,
+    return_token_timestamps=False,
+    force_unique_generate_call=False,
 ):
     max_total_length = 0
     sequences = []
-    if padding not in ["right", "left"]:
-        raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")
+    token_timestamps_list = []
+
+    if padding_side not in ["right", "left"]:
+        raise ValueError(
+            f"`padding_side` must be either 'right' or 'left', not {padding_side}"
+        )
+
+    if padding not in ["longest", "max_length"]:
+        raise ValueError(
+            f"`padding` must be either 'longest' or 'max_length', not {padding}"
+        )
+    elif padding == "max_length" and cut_off_length is None:
+        raise ValueError(
+            "`cut_off_length` must be specified when `padding='max_length'`"
+        )
+
+    if force_unique_generate_call:
+        sequences_list = []
+        timestamps_list = []
+        for segments in current_segments:
+            result = segments[0]["result"]
+            sequences_list.append(
+                result if isinstance(result, torch.Tensor) else result["sequences"]
+            )
+            if return_token_timestamps:
+                timestamps_list.append(result["token_timestamps"])
+
+        sequences = torch.stack(sequences_list, dim=0)
+        if return_token_timestamps:
+            token_timestamps = torch.stack(timestamps_list, dim=0)
+            return sequences, token_timestamps
+        return sequences

     for current_segment_list in current_segments:
         if (
             current_segment_list is not None
             and len([d["tokens"] for d in current_segment_list]) > 0
         ):
             sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+            if return_token_timestamps:
+                token_timestamps = torch.cat(
+                    [
+                        d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]]
+                        for d in current_segment_list
+                    ],
+                    dim=-1,
+                )

             if cut_off_length is not None:
                 sequence = sequence[-cut_off_length:]
+                if return_token_timestamps:
+                    token_timestamps = token_timestamps[-cut_off_length:]

             if bos_token_tensor is not None:
                 sequence = torch.cat([bos_token_tensor, sequence])
-
+                if return_token_timestamps:
+                    token_timestamps = torch.cat(
+                        [
+                            torch.ones_like(bos_token_tensor, device=device) * 0.0,
+                            token_timestamps,
+                        ]
+                    )
             sequences.append(sequence)
+            if return_token_timestamps:
+                token_timestamps_list.append(token_timestamps)
             max_total_length = max(max_total_length, len(sequences[-1]))
         elif bos_token_tensor is not None:
             sequences.append(bos_token_tensor)
+            if return_token_timestamps:
+                token_timestamps_list.append(
+                    torch.ones_like(bos_token_tensor, device=device) * 0.0
+                )
         else:
             sequences.append(torch.tensor([], device=device))
+            if return_token_timestamps:
+                token_timestamps_list.append(torch.tensor([], device=device))

+    max_total_length = (
+        cut_off_length + 1 if padding == "max_length" else max_total_length
+    )
     for i in range(len(current_segments)):
         pad_length = max_total_length - len(sequences[i])
-        pad = (0, pad_length) if padding == "right" else (pad_length, 0)
+        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+
         sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+        if return_token_timestamps:
+            token_timestamps_list[i] = F.pad(
+                token_timestamps_list[i],
+                pad=pad,
+                value=(
+                    token_timestamps_list[i][-1]
+                    if len(token_timestamps_list[i]) > 0
+                    else 0.0
+                ),
+            )

     sequences = torch.stack(sequences, dim=0)
-    return sequences
+
+    if return_token_timestamps:
+        token_timestamps = torch.stack(token_timestamps_list, dim=0)
+        return sequences, token_timestamps
+    else:
+        return sequences


 def whisper_generate(
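
For orientation, here is a minimal sketch (not part of the patch) of how the reworked helper behaves; the segment dicts and token ids below are made up, but mirror the keys the function reads:

    # Illustrative call: two fake segment lists of different lengths.
    import torch

    current_segments = [
        [{"tokens": torch.tensor([1, 2, 3])}],
        [{"tokens": torch.tensor([1, 2, 3, 4, 5])}],
    ]
    padded = _pad_to_max_length(
        current_segments,
        pad_token_id=0,
        device=torch.device("cpu"),
        padding_side="right",
        padding="longest",  # pad to the longest sequence in the batch
    )
    # padded.shape -> torch.Size([2, 5]); row 0 is right-padded with pad_token_id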
@@ -186,9 +262,11 @@ def whisper_generate(
     num_segment_frames: Optional[int] = None,
     attention_mask: Optional[torch.Tensor] = None,
     time_precision: float = 0.02,
+    time_precision_features: float = 0.01,
     return_token_timestamps: Optional[bool] = None,
     return_segments: bool = False,
     return_dict_in_generate: Optional[bool] = None,
+    force_unique_generate_call: Optional[bool] = None,
     **kwargs,
 ):
     # 0. deprecate old inputs
@@ -270,11 +348,23 @@ def whisper_generate(
         else input_features.device
     )
     begin_index = init_tokens.shape[1]
+    num_beams = kwargs.get(
+        "num_beams",
+        (
+            generation_config.num_beams
+            if hasattr(generation_config, "num_beams")
+            and generation_config.num_beams is not None
+            else 1
+        ),
+    )
+    if "assistant_model" in kwargs:
+        # speculative decoding: the model should be able to return the eos token
+        generation_config.begin_suppress_tokens = None
     logits_processor = self._retrieve_logit_processors(
         generation_config=generation_config,
         logits_processor=logits_processor,
         begin_index=begin_index,  # begin index is index of first generated decoder token
-        num_beams=kwargs.get("num_beams", 1),
+        num_beams=num_beams,
         device=device,
     )

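
The num_beams plumbing above only resolves an explicit kwarg against the generation config before handing it to the logits processors. A compact equivalent, shown for illustration (this helper is hypothetical; it treats None as unset and assumes num_beams is never a meaningful 0):

    def resolve_num_beams(kwargs, generation_config):
        # kwargs win, then the config value, then the default of 1
        default = getattr(generation_config, "num_beams", None) or 1
        return kwargs.get("num_beams", default)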
@@ -321,7 +411,23 @@ def whisper_generate(
         batch_size=cur_bsz,
         generation_config=generation_config,
     )
-
+    # 5bis. speculative decoding: ensure the assistant model makes only one
+    # call to generate, so that it returns the decoder input token ids and
+    # the eos token id. We set a flag in its generation config to force
+    # that single call to generate.
+    if "assistant_model" in kwargs:
+        assistant_model = kwargs["assistant_model"]
+        assistant_model.generation_config.force_unique_generate_call = True
+
+    if force_unique_generate_call is None:
+        if hasattr(generation_config, "force_unique_generate_call"):
+            force_unique_generate_call = generation_config.force_unique_generate_call
+        elif hasattr(self.generation_config, "force_unique_generate_call"):
+            force_unique_generate_call = (
+                self.generation_config.force_unique_generate_call
+            )
+        else:
+            force_unique_generate_call = False
     # 6 Transcribe audio until we reach the end of all input audios
     while (seek < max_frames).any():
         # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically
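
The block above encodes a three-level precedence: an explicit force_unique_generate_call argument wins, then the per-call generation_config, then the model's own generation_config, defaulting to False. The same logic, condensed with nested getattr defaults (an illustrative rewrite, not part of the patch):

    if force_unique_generate_call is None:
        force_unique_generate_call = getattr(
            generation_config,
            "force_unique_generate_call",
            getattr(self.generation_config, "force_unique_generate_call", False),
        )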
@@ -336,7 +442,11 @@ def whisper_generate(
             cur_bsz=cur_bsz,
             batch_idx_map=batch_idx_map,
         )
-        time_offset = seek * time_precision / input_stride
+        time_offset = (
+            seek.to(torch.float32 if device.type == "mps" else torch.float64)
+            * time_precision
+            / input_stride
+        )
         seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

         # 6.2 cut out next 30s segment from input features
@@ -355,6 +465,9 @@ def whisper_generate(
             transformers.generation.logits_process.SuppressTokensLogitsProcessor,
             "suppress_tokens",
         )
+        extra_kwargs = {}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0"):
+            extra_kwargs["timestamp_begin"] = timestamp_begin

         decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
             cur_bsz=cur_bsz,
@@ -367,6 +480,7 @@ def whisper_generate(
             config=self.config,
             device=init_tokens.device,
             suppress_tokens=suppress_tokens,
+            **extra_kwargs,
             kwargs=kwargs,
         )

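
Both version-gated blocks in this patch follow the same idiom: keyword arguments that only exist in newer transformers releases are collected in a dict and splatted into the call, so older installs never see a keyword their API would reject. A generic sketch of the pattern (the helper and cutoff are placeholders, not part of the patch):

    from packaging import version
    import transformers

    def call_with_gated_kwargs(fn, gated_kwargs, cutoff="4.47.0", **always_kwargs):
        # only forward kwargs the installed transformers release understands
        kwargs = dict(always_kwargs)
        if version.parse(transformers.__version__) >= version.parse(cutoff):
            kwargs.update(gated_kwargs)
        return fn(**kwargs)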
@@ -419,7 +533,11 @@ def whisper_generate(
             if should_skip[i]:
                 seek[prev_i] += seek_num_frames[prev_i]
                 continue
-
+            extra_kwargs = {}
+            if version.parse(transformers.__version__) >= version.parse("4.48.0"):
+                extra_kwargs["decoder_input_ids"] = decoder_input_ids
+            if version.parse(transformers.__version__) >= version.parse("4.47.0"):
+                extra_kwargs["time_precision_features"] = time_precision_features
             segments, segment_offset = self._retrieve_segment(
                 seek_sequence=seek_sequence,
                 seek_outputs=seek_outputs,
@@ -431,14 +549,13 @@ def whisper_generate(
                 prev_idx=prev_i,
                 idx=i,
                 return_token_timestamps=return_token_timestamps,
+                **extra_kwargs,
             )
-
+            seek[prev_i] += segment_offset
             current_segments[prev_i] += segments

-            if is_shortform:
-                seek[prev_i] += max_frames[i]
-            else:
-                seek[prev_i] += segment_offset
+        if force_unique_generate_call:
+            break

     # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
     # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -451,65 +568,69 @@ def whisper_generate(
         else current_segments
     )

-    sequences = _pad_to_max_length(
-        final_segments,
-        generation_config.pad_token_id,
+    # if return_dict_in_generate=True and either we forced a unique call to
+    # generate or return_timestamps=False, we know only one call to generate
+    # was made and can return a ModelOutput; otherwise,
+    # return_dict_in_generate is applied to the 'result' of each segment in final_segments
+    if (
+        return_dict_in_generate
+        and generation_config.return_dict_in_generate
+        and (force_unique_generate_call or not return_timestamps)
+    ):
+        # only one call to generate_with_fallback, we can return a ModelOutput
+        outputs = self._stack_split_outputs(
+            seek_outputs, model_output_type, self.device, kwargs
+        )
+        if num_return_sequences > 1:
+            if (
+                hasattr(outputs, "encoder_attentions")
+                and outputs.encoder_attentions is not None
+            ):
+                outputs.encoder_attentions = tuple(
+                    outputs.encoder_attentions[i][::num_return_sequences]
+                    for i in range(len(outputs.encoder_attentions))
+                )
+            if (
+                hasattr(outputs, "encoder_hidden_states")
+                and outputs.encoder_hidden_states is not None
+            ):
+                outputs.encoder_hidden_states = tuple(
+                    outputs.encoder_hidden_states[i][::num_return_sequences]
+                    for i in range(len(outputs.encoder_hidden_states))
+                )
+        return outputs
+
+    padded_outputs = _pad_to_max_length(
+        current_segments=final_segments,
+        pad_token_id=generation_config.pad_token_id,
         device=self.device,
-        padding="right",
+        padding_side="right",
+        return_token_timestamps=return_token_timestamps,
+        force_unique_generate_call=force_unique_generate_call,
     )

-    # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
-    if return_segments:
-        return {"sequences": sequences, "segments": final_segments}
-
-    if is_shortform:
-        # add eos token:
-        if (
-            generation_config.max_new_tokens is None
-            and generation_config.max_length is None
-        ):
-            eos_tokens = torch.full(
-                (sequences.shape[0], 1), generation_config.eos_token_id
-            )
-            sequences = torch.cat([sequences, eos_tokens], dim=-1)
-
-        if return_token_timestamps:
-            outputs = {}
-            outputs["sequences"] = sequences
-            outputs["token_timestamps"] = torch.stack(
-                [d["token_timestamps"] for d in seek_outputs], dim=0
-            )
-        elif hasattr(self.config, "token_latency") and self.config.token_latency:
-            outputs = (sequences, seek_outputs[0])
-        else:
-            outputs = sequences
-
-        if return_dict_in_generate and generation_config.return_dict_in_generate:
-            dict_outputs = self._stack_split_outputs(
-                seek_outputs, model_output_type, sequences.device, kwargs
-            )
-
-            if num_return_sequences > 1:
-                if (
-                    hasattr(dict_outputs, "encoder_attentions")
-                    and dict_outputs.encoder_attentions is not None
-                ):
-                    dict_outputs.encoder_attentions = tuple(
-                        dict_outputs.encoder_attentions[i][::num_return_sequences]
-                        for i in range(len(dict_outputs.encoder_attentions))
-                    )
-                if (
-                    hasattr(dict_outputs, "encoder_hidden_states")
-                    and dict_outputs.encoder_hidden_states is not None
-                ):
-                    dict_outputs.encoder_hidden_states = tuple(
-                        dict_outputs.encoder_hidden_states[i][::num_return_sequences]
-                        for i in range(len(dict_outputs.encoder_hidden_states))
-                    )
-            if return_token_timestamps:
-                dict_outputs["token_timestamps"] = outputs["token_timestamps"]
-            return dict_outputs
+    if return_dict_in_generate and generation_config.return_dict_in_generate:
+        return_segments = True
+    elif not return_segments and not return_token_timestamps:
+        if hasattr(self.config, "token_latency") and self.config.token_latency:
+            return (padded_outputs, seek_outputs[0])
+        return padded_outputs
+
+    if return_token_timestamps:
+        sequences, token_timestamps = padded_outputs
+        outputs = {
+            "sequences": sequences,
+            "token_timestamps": token_timestamps,
+        }
+    elif hasattr(self.config, "token_latency") and self.config.token_latency:
+        outputs = (padded_outputs, seek_outputs[0])
+    else:
+        sequences = padded_outputs
+        outputs = {
+            "sequences": sequences,
+        }

-        return outputs
+    if return_segments:
+        outputs["segments"] = final_segments

-    return sequences
+    return outputs
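
Taken together, the new return path means a caller asking for segments or token timestamps now always receives a dict with padded tensors. A hedged usage sketch (assuming whisper_generate is bound as the model's generate method; the model and feature setup are not part of this patch):

    outputs = model.generate(
        input_features,
        return_timestamps=True,
        return_token_timestamps=True,
        return_segments=True,
    )
    outputs["sequences"]         # (batch, max_len) right-padded token ids
    outputs["token_timestamps"]  # same shape, one timestamp per token
    outputs["segments"]          # per-sample lists of segment dicts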