Commit 5048a88

[whisper] fix compatibility issues (#3647)
1 parent 9c5158d commit 5048a88

File tree

1 file changed (+193 −72 lines):
  • intel_extension_for_pytorch/transformers/generation/utils.py
intel_extension_for_pytorch/transformers/generation/utils.py

Lines changed: 193 additions & 72 deletions
@@ -126,42 +126,118 @@ def _pad_to_max_length(
     current_segments,
     pad_token_id,
     device,
-    padding="right",
+    padding_side="right",
+    padding="longest",
     bos_token_tensor=None,
     cut_off_length=None,
+    return_token_timestamps=False,
+    force_unique_generate_call=False,
 ):
     max_total_length = 0
     sequences = []
-    if padding not in ["right", "left"]:
-        raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")
+    token_timestamps_list = []
+
+    if padding_side not in ["right", "left"]:
+        raise ValueError(
+            f"`padding_side` must be either 'right' or 'left', not {padding_side}"
+        )
+
+    if padding not in ["longest", "max_length"]:
+        raise ValueError(
+            f"`padding` must be either 'longest' or 'max_length', not {padding}"
+        )
+    elif padding == "max_length" and cut_off_length is None:
+        raise ValueError(
+            "`cut_off_length` must be specified when `padding='max_length'`"
+        )
+
+    if force_unique_generate_call:
+        sequences_list = []
+        timestamps_list = []
+        for segments in current_segments:
+            result = segments[0]["result"]
+            sequences_list.append(
+                result if isinstance(result, torch.Tensor) else result["sequences"]
+            )
+            if return_token_timestamps:
+                timestamps_list.append(result["token_timestamps"])
+
+        sequences = torch.stack(sequences_list, dim=0)
+        if return_token_timestamps:
+            token_timestamps = torch.stack(timestamps_list, dim=0)
+            return sequences, token_timestamps
+        return sequences

     for current_segment_list in current_segments:
         if (
             current_segment_list is not None
             and len([d["tokens"] for d in current_segment_list]) > 0
         ):
             sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+            if return_token_timestamps:
+                token_timestamps = torch.cat(
+                    [
+                        d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]]
+                        for d in current_segment_list
+                    ],
+                    dim=-1,
+                )

             if cut_off_length is not None:
                 sequence = sequence[-cut_off_length:]
+                if return_token_timestamps:
+                    token_timestamps = token_timestamps[-cut_off_length:]

             if bos_token_tensor is not None:
                 sequence = torch.cat([bos_token_tensor, sequence])
-
+                if return_token_timestamps:
+                    token_timestamps = torch.cat(
+                        [
+                            torch.ones_like(bos_token_tensor, device=device) * 0.0,
+                            token_timestamps,
+                        ]
+                    )
             sequences.append(sequence)
+            if return_token_timestamps:
+                token_timestamps_list.append(token_timestamps)
             max_total_length = max(max_total_length, len(sequences[-1]))
         elif bos_token_tensor is not None:
             sequences.append(bos_token_tensor)
+            if return_token_timestamps:
+                token_timestamps_list.append(
+                    torch.ones_like(bos_token_tensor, device=device) * 0.0
+                )
         else:
             sequences.append(torch.tensor([], device=device))
+            if return_token_timestamps:
+                token_timestamps_list.append(torch.tensor([], device=device))

+    max_total_length = (
+        cut_off_length + 1 if padding == "max_length" else max_total_length
+    )
     for i in range(len(current_segments)):
         pad_length = max_total_length - len(sequences[i])
-        pad = (0, pad_length) if padding == "right" else (pad_length, 0)
+        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+
         sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+        if return_token_timestamps:
+            token_timestamps_list[i] = F.pad(
+                token_timestamps_list[i],
+                pad=pad,
+                value=(
+                    token_timestamps_list[i][-1]
+                    if len(token_timestamps_list[i]) > 0
+                    else 0.0
+                ),
+            )

     sequences = torch.stack(sequences, dim=0)
-    return sequences
+
+    if return_token_timestamps:
+        token_timestamps = torch.stack(token_timestamps_list, dim=0)
+        return sequences, token_timestamps
+    else:
+        return sequences


 def whisper_generate(
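
Not part of the commit: a minimal, standalone sketch of the padding behaviour the new `padding_side` and `padding` arguments in this hunk select. The helper name and the token ids below are illustrative assumptions, not the library's API.

import torch
import torch.nn.functional as F

def pad_segments(sequences, pad_token_id, padding_side="right", padding="longest", cut_off_length=None):
    # `padding_side` picks where pad tokens go; `padding` picks "longest"
    # (pad to the longest segment) vs "max_length" (pad to cut_off_length + 1),
    # mirroring the options added to _pad_to_max_length above.
    if padding == "max_length" and cut_off_length is None:
        raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
    max_len = cut_off_length + 1 if padding == "max_length" else max(len(s) for s in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_len - len(seq)
        pad = (0, pad_len) if padding_side == "right" else (pad_len, 0)
        padded.append(F.pad(seq, pad=pad, value=pad_token_id))
    return torch.stack(padded, dim=0)

# Two segments of different lengths, right-padded to the longest one
# (token ids are placeholder values).
a = torch.tensor([50258, 50359, 1234])
b = torch.tensor([50258, 50359])
print(pad_segments([a, b], pad_token_id=50257))
# tensor([[50258, 50359,  1234],
#         [50258, 50359, 50257]])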
@@ -186,9 +262,11 @@ def whisper_generate(
     num_segment_frames: Optional[int] = None,
     attention_mask: Optional[torch.Tensor] = None,
     time_precision: float = 0.02,
+    time_precision_features: float = 0.01,
     return_token_timestamps: Optional[bool] = None,
     return_segments: bool = False,
     return_dict_in_generate: Optional[bool] = None,
+    force_unique_generate_call: Optional[bool] = None,
     **kwargs,
 ):
     # 0. deprecate old inputs
@@ -270,11 +348,23 @@ def whisper_generate(
         else input_features.device
     )
     begin_index = init_tokens.shape[1]
+    num_beams = kwargs.get(
+        "num_beams",
+        (
+            generation_config.num_beams
+            if hasattr(generation_config, "num_beams")
+            and generation_config.num_beams is not None
+            else 1
+        ),
+    )
+    if "assistant_model" in kwargs:
+        # speculative decoding: the model should be able to return eos token
+        generation_config.begin_suppress_tokens = None
     logits_processor = self._retrieve_logit_processors(
         generation_config=generation_config,
         logits_processor=logits_processor,
         begin_index=begin_index,  # begin index is index of first generated decoder token
-        num_beams=kwargs.get("num_beams", 1),
+        num_beams=num_beams,
         device=device,
     )

@@ -321,7 +411,23 @@ def whisper_generate(
         batch_size=cur_bsz,
         generation_config=generation_config,
     )
-
+    # 5bis speculative decoding: ensure the assistant model does only one call to generate
+    # and therefore returns decoder input token ids and eos token id
+    # we set a flag in the generation config to force the model to make only one call to generate
+    # and return the decoder input token ids and eos token id
+    if "assistant_model" in kwargs:
+        assistant_model = kwargs["assistant_model"]
+        assistant_model.generation_config.force_unique_generate_call = True
+
+    if force_unique_generate_call is None:
+        if hasattr(generation_config, "force_unique_generate_call"):
+            force_unique_generate_call = generation_config.force_unique_generate_call
+        elif hasattr(self.generation_config, "force_unique_generate_call"):
+            force_unique_generate_call = (
+                self.generation_config.force_unique_generate_call
+            )
+        else:
+            force_unique_generate_call = False
     # 6 Transcribe audio until we reach the end of all input audios
     while (seek < max_frames).any():
         # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically
@@ -336,7 +442,11 @@ def whisper_generate(
             cur_bsz=cur_bsz,
             batch_idx_map=batch_idx_map,
         )
-        time_offset = seek * time_precision / input_stride
+        time_offset = (
+            seek.to(torch.float32 if device.type == "mps" else torch.float64)
+            * time_precision
+            / input_stride
+        )
         seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

         # 6.2 cut out next 30s segment from input features
@@ -355,6 +465,9 @@ def whisper_generate(
             transformers.generation.logits_process.SuppressTokensLogitsProcessor,
             "suppress_tokens",
         )
+        extra_kwargs = {}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0"):
+            extra_kwargs["timestamp_begin"] = timestamp_begin

         decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
             cur_bsz=cur_bsz,
@@ -367,6 +480,7 @@ def whisper_generate(
             config=self.config,
             device=init_tokens.device,
             suppress_tokens=suppress_tokens,
+            **extra_kwargs,
             kwargs=kwargs,
         )

@@ -419,7 +533,11 @@ def whisper_generate(
             if should_skip[i]:
                 seek[prev_i] += seek_num_frames[prev_i]
                 continue
-
+            extra_kwargs = {}
+            if version.parse(transformers.__version__) >= version.parse("4.48.0"):
+                extra_kwargs["decoder_input_ids"] = decoder_input_ids
+            if version.parse(transformers.__version__) >= version.parse("4.47.0"):
+                extra_kwargs["time_precision_features"] = time_precision_features
             segments, segment_offset = self._retrieve_segment(
                 seek_sequence=seek_sequence,
                 seek_outputs=seek_outputs,
@@ -431,14 +549,13 @@ def whisper_generate(
                 prev_idx=prev_i,
                 idx=i,
                 return_token_timestamps=return_token_timestamps,
+                **extra_kwargs,
             )
-
+            seek[prev_i] += segment_offset
             current_segments[prev_i] += segments

-            if is_shortform:
-                seek[prev_i] += max_frames[i]
-            else:
-                seek[prev_i] += segment_offset
+        if force_unique_generate_call:
+            break

     # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
     # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -451,65 +568,69 @@ def whisper_generate(
         else current_segments
     )

-    sequences = _pad_to_max_length(
-        final_segments,
-        generation_config.pad_token_id,
+    # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False,
+    # meaning we are sure only one call to generate has been made,
+    # -> we can return a ModelOutput
+    # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
+    if (
+        return_dict_in_generate
+        and generation_config.return_dict_in_generate
+        and (force_unique_generate_call or not return_timestamps)
+    ):
+        # only one call to generate_with_fallback, we can return a ModelOutput
+        outputs = self._stack_split_outputs(
+            seek_outputs, model_output_type, self.device, kwargs
+        )
+        if num_return_sequences > 1:
+            if (
+                hasattr(outputs, "encoder_attentions")
+                and outputs.encoder_attentions is not None
+            ):
+                outputs.encoder_attentions = tuple(
+                    outputs.encoder_attentions[i][::num_return_sequences]
+                    for i in range(len(outputs.encoder_attentions))
+                )
+            if (
+                hasattr(outputs, "encoder_hidden_states")
+                and outputs.encoder_hidden_states is not None
+            ):
+                outputs.encoder_hidden_states = tuple(
+                    outputs.encoder_hidden_states[i][::num_return_sequences]
+                    for i in range(len(outputs.encoder_hidden_states))
+                )
+        return outputs
+
+    padded_outputs = _pad_to_max_length(
+        current_segments=final_segments,
+        pad_token_id=generation_config.pad_token_id,
         device=self.device,
-        padding="right",
+        padding_side="right",
+        return_token_timestamps=return_token_timestamps,
+        force_unique_generate_call=force_unique_generate_call,
     )

-    # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
-    if return_segments:
-        return {"sequences": sequences, "segments": final_segments}
-
-    if is_shortform:
-        # add eos token:
-        if (
-            generation_config.max_new_tokens is None
-            and generation_config.max_length is None
-        ):
-            eos_tokens = torch.full(
-                (sequences.shape[0], 1), generation_config.eos_token_id
-            )
-            sequences = torch.cat([sequences, eos_tokens], dim=-1)
-
-        if return_token_timestamps:
-            outputs = {}
-            outputs["sequences"] = sequences
-            outputs["token_timestamps"] = torch.stack(
-                [d["token_timestamps"] for d in seek_outputs], dim=0
-            )
-        elif hasattr(self.config, "token_latency") and self.config.token_latency:
-            outputs = (sequences, seek_outputs[0])
-        else:
-            outputs = sequences
-
-        if return_dict_in_generate and generation_config.return_dict_in_generate:
-            dict_outputs = self._stack_split_outputs(
-                seek_outputs, model_output_type, sequences.device, kwargs
-            )
-
-            if num_return_sequences > 1:
-                if (
-                    hasattr(dict_outputs, "encoder_attentions")
-                    and dict_outputs.encoder_attentions is not None
-                ):
-                    dict_outputs.encoder_attentions = tuple(
-                        dict_outputs.encoder_attentions[i][::num_return_sequences]
-                        for i in range(len(dict_outputs.encoder_attentions))
-                    )
-                if (
-                    hasattr(dict_outputs, "encoder_hidden_states")
-                    and dict_outputs.encoder_hidden_states is not None
-                ):
-                    dict_outputs.encoder_hidden_states = tuple(
-                        dict_outputs.encoder_hidden_states[i][::num_return_sequences]
-                        for i in range(len(dict_outputs.encoder_hidden_states))
-                    )
-            if return_token_timestamps:
-                dict_outputs["token_timestamps"] = outputs["token_timestamps"]
-            return dict_outputs
+    if return_dict_in_generate and generation_config.return_dict_in_generate:
+        return_segments = True
+    elif not return_segments and not return_token_timestamps:
+        if hasattr(self.config, "token_latency") and self.config.token_latency:
+            return (padded_outputs, seek_outputs[0])
+        return padded_outputs
+
+    if return_token_timestamps:
+        sequences, token_timestamps = padded_outputs
+        outputs = {
+            "sequences": sequences,
+            "token_timestamps": token_timestamps,
+        }
+    elif hasattr(self.config, "token_latency") and self.config.token_latency:
+        outputs = (sequences, seek_outputs[0])
+    else:
+        sequences = padded_outputs
+        outputs = {
+            "sequences": sequences,
+        }

-        return outputs
+    if return_segments:
+        outputs["segments"] = final_segments

-    return sequences
+    return outputs
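
Not part of the commit: a small self-contained sketch of the compatibility gating the diff applies when calling into transformers, i.e. only forwarding keyword arguments (`timestamp_begin`, `time_precision_features`, `decoder_input_ids`) that the installed version understands. The helper and example values below are assumptions for illustration; the only dependency is `packaging`, which transformers already requires.

from packaging import version

def build_extra_kwargs(installed_version, timestamp_begin, time_precision_features, decoder_input_ids):
    # Mirror the gating pattern used in the diff: newer keyword arguments are
    # only forwarded when the installed transformers version accepts them.
    extra_kwargs = {}
    if version.parse(installed_version) >= version.parse("4.47.0"):
        extra_kwargs["timestamp_begin"] = timestamp_begin
        extra_kwargs["time_precision_features"] = time_precision_features
    if version.parse(installed_version) >= version.parse("4.48.0"):
        extra_kwargs["decoder_input_ids"] = decoder_input_ids
    return extra_kwargs

print(build_extra_kwargs("4.46.3", 50364, 0.01, None))  # {} -> older signature left untouched
print(build_extra_kwargs("4.48.1", 50364, 0.01, None))  # all three keys forwarded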
