
Commit cfff7ca

[Whisper] Pipeline: handle long form generation (#35750)
* handle long form generation
* add warning
* correct incorrect in place token change
* update test to catch edge case
* make style
* update warning
* add doc
1 parent 02ecdcf commit cfff7ca
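
For context, a minimal usage sketch of what this change targets: transcribing audio longer than 30 seconds through the ASR pipeline, where Whisper's own sequential long-form generation produces concatenated segments that the pipeline now decodes correctly. The checkpoint name and audio file below are illustrative assumptions, not part of the commit.

# Illustrative sketch only: long-form transcription through the pipeline,
# letting Whisper's sequential generation handle audio longer than 30 s.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",  # assumed checkpoint; any Whisper model works
    return_timestamps=True,
)

# "long_audio.wav" stands in for a recording longer than 30 seconds
result = asr("long_audio.wav")
print(result["text"])
print(result["chunks"])  # per-segment timestamps decoded by _decode_asr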

File tree

4 files changed (+64 additions, −17 deletions):

src/transformers/models/whisper/generation_whisper.py
src/transformers/models/whisper/tokenization_whisper.py
src/transformers/pipelines/automatic_speech_recognition.py
tests/models/whisper/test_modeling_whisper.py

src/transformers/models/whisper/generation_whisper.py

Lines changed: 24 additions & 9 deletions
@@ -136,7 +136,18 @@ def _pad_to_max_length(
     cut_off_length=None,
     return_token_timestamps=False,
     force_unique_generate_call=False,
+    skip_ending_double_timestamps=False,
+    timestamp_begin=None,
 ):
+    """
+    skip_ending_double_timestamps: when the segement ended with two timestamp tokens, whether to ignore the last timestamp token
+    see https://github.com/huggingface/transformers/pull/35750
+
+    _pad_to_max_length is used in different contexts:
+    1. At the end of generation: we need to keep both ending timestamp tokens in the segment (see https://github.com/huggingface/transformers/pull/34537).
+    2. In the middle of generation, e.g. when condition_on_prev_tokens=True and we want to use the last generated tokens as decoder_input_ids:
+       we must skip one of the double ending timestamp tokens (see https://github.com/huggingface/transformers/pull/35750).
+    """
     max_total_length = 0
     sequences = []
     token_timestamps_list = []
@@ -166,7 +177,17 @@ def _pad_to_max_length(
 
     for current_segment_list in current_segments:
         if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-            sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+            sequences_list = []
+            for d in current_segment_list:
+                if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin:
+                    # the segment finishes with two timestamp tokens
+                    # we need to ignore the last timestamp token
+                    # see https://github.com/huggingface/transformers/pull/34537
+                    sequences_list.append(d["tokens"][:-1])
+                else:
+                    sequences_list.append(d["tokens"])
+            sequence = torch.cat(sequences_list, dim=-1)
+
             if return_token_timestamps:
                 token_timestamps = torch.cat(
                     [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
@@ -1809,14 +1830,6 @@ def _prepare_decoder_input_ids(
         # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
         active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
 
-        for segments in active_segments:
-            for seg in segments:
-                if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
-                    # the segment finishes with two timestamp tokens
-                    # we need to ignore the last timestamp token
-                    # see https://github.com/huggingface/transformers/pull/34537
-                    seg["tokens"] = seg["tokens"][:-1]
-
         if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
             prev_ids = prompt_ids
         else:
@@ -1833,6 +1846,8 @@ def _prepare_decoder_input_ids(
             padding=padding,
             bos_token_tensor=prev_ids,
             cut_off_length=cut_off_length,
+            skip_ending_double_timestamps=True,
+            timestamp_begin=timestamp_begin,
         )
         decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
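
To make the two call contexts concrete, here is a small self-contained sketch (not from the repo) of the token handling that the new skip_ending_double_timestamps flag switches on; the timestamp_begin value and segment contents are made up for illustration.

# Toy illustration of the skip_ending_double_timestamps behaviour above.
import torch

TIMESTAMP_BEGIN = 50365  # assumed: first timestamp token id in Whisper's vocabulary

def concat_segment_tokens(segments, skip_ending_double_timestamps):
    pieces = []
    for tokens in segments:
        if skip_ending_double_timestamps and len(tokens) > 2 and tokens[-2] >= TIMESTAMP_BEGIN:
            # segment ends with two timestamp tokens -> drop the duplicated last one
            pieces.append(tokens[:-1])
        else:
            pieces.append(tokens)
    return torch.cat(pieces, dim=-1)

segment = torch.tensor([50365, 2221, 13, 50629, 50629])  # text framed by timestamps, double ending
# 1. end of generation: both ending timestamp tokens are kept
print(concat_segment_tokens([segment], skip_ending_double_timestamps=False))
# 2. building decoder_input_ids with condition_on_prev_tokens=True: one ending timestamp is dropped
print(concat_segment_tokens([segment], skip_ending_double_timestamps=True))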

src/transformers/models/whisper/tokenization_whisper.py

Lines changed: 25 additions & 2 deletions
@@ -910,7 +910,7 @@ def _convert_to_list(token_ids):
     return token_ids
 
 
-def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
+def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
     """
     Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
     the various options not allowed in other seq2seq models
@@ -962,6 +962,12 @@ def new_chunk():
         last_timestamp = None
         first_timestamp = timestamp_begin
 
+        # long form generation: we need to handle the case where the call to generate returns concatenated segments,
+        # with underlying multiple calls to generate
+        cur_max_timestamp = 0.0
+        prev_segments_len = 0.0
+        penultimate_timestamp = 0.0
+
         if "stride" in output:
             chunk_len, stride_left, stride_right = output["stride"]
             # Offset the timings to account for the other `model_outputs`.
@@ -1024,7 +1030,24 @@ def new_chunk():
                 pass
             elif token >= timestamp_begin:
                 # 3/ Timestamp token
-                time = (token - timestamp_begin) * time_precision + time_offset
+
+                timestamp = float((token - timestamp_begin) * time_precision)
+                if timestamp < cur_max_timestamp:
+                    # next segment has started
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                    )
+                    if last_was_single_ending:
+                        prev_segments_len += time_precision * segment_size
+                    else:
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+
+                penultimate_timestamp = cur_max_timestamp
+                cur_max_timestamp = timestamp
+
+                time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len
+
                 time = round(time, 2)
                 if last_timestamp and token >= last_timestamp:
                     # Whisper outputted a timestamp token, but it falls within
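
The added bookkeeping above can be hard to follow in diff form. The following stripped-down sketch (simplified: it ignores the stride/time_offset handling, and the token values are invented) shows how prev_segments_len keeps decoded timestamps monotonically increasing when generate returns several concatenated 30-second segments:

# Simplified sketch of the new offset logic in _decode_asr.
TIME_PRECISION = 0.02    # seconds per timestamp token step (Whisper default)
SEGMENT_SIZE = 1500      # frames per 30 s segment, matching the new default argument
TIMESTAMP_BEGIN = 50365  # assumed first timestamp token id

def absolute_times(token_ids):
    cur_max_timestamp = 0.0
    prev_segments_len = 0.0
    penultimate_timestamp = 0.0
    times = []
    for i, token in enumerate(token_ids):
        if token < TIMESTAMP_BEGIN:
            continue  # ordinary text token
        timestamp = float((token - TIMESTAMP_BEGIN) * TIME_PRECISION)
        if timestamp < cur_max_timestamp:
            # the next concatenated segment has started
            last_was_single_ending = i >= 2 and not (
                token_ids[i - 1] >= TIMESTAMP_BEGIN and token_ids[i - 2] >= TIMESTAMP_BEGIN
            )
            if last_was_single_ending:
                prev_segments_len += TIME_PRECISION * SEGMENT_SIZE
            else:
                cur_max_timestamp = penultimate_timestamp
                prev_segments_len += penultimate_timestamp
        penultimate_timestamp = cur_max_timestamp
        cur_max_timestamp = timestamp
        times.append(round(timestamp + prev_segments_len, 2))
    return times

# two segments concatenated by generate; the second resets to token 50365 (0.0 s),
# but the decoded times keep increasing instead of jumping back to 0
tokens = [50365, 111, 50565, 50565, 222, 50665, 50665, 50365, 333, 50465]
print(absolute_times(tokens))  # [0.0, 4.0, 4.0, 6.0, 6.0, 6.0, 8.0]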

src/transformers/pipelines/automatic_speech_recognition.py

Lines changed: 10 additions & 3 deletions
@@ -283,13 +283,20 @@ def _sanitize_parameters(
         # No parameters on this pipeline right now
         preprocess_params = {}
         if chunk_length_s is not None:
-            if self.type == "seq2seq" and not ignore_warning:
-                logger.warning(
+            if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
+                type_warning = (
                     "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                     " be entirely accurate and will have caveats. More information:"
                     " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
-                    " ignore_warning=True)"
+                    " ignore_warning=True)."
                 )
+                if self.type == "seq2seq_whisper":
+                    type_warning += (
+                        " To use Whisper for long-form transcription, use rather the model's `generate` method directly "
+                        "as the model relies on it's own chunking mechanism (cf. Whisper original paper, section 3.8. "
+                        "Long-form Transcription)."
+                    )
+                logger.warning(type_warning)
             preprocess_params["chunk_length_s"] = chunk_length_s
         if stride_length_s is not None:
             preprocess_params["stride_length_s"] = stride_length_s
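
In practice the updated warning distinguishes two ways of calling the pipeline on long audio; a hedged sketch of both follows (the checkpoint name is an assumption):

# Illustrative only: the two call patterns the warning above distinguishes.
from transformers import pipeline

# 1) Pipeline chunking: triggers the Whisper-specific warning unless silenced.
asr_chunked = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",  # assumed checkpoint
    chunk_length_s=30,
    ignore_warning=True,  # opt out, as the warning message itself suggests
)

# 2) No chunk_length_s: Whisper's own sequential long-form generation is used,
#    as described in section 3.8 (Long-form Transcription) of the Whisper paper.
asr_long_form = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")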

tests/models/whisper/test_modeling_whisper.py

Lines changed: 5 additions & 3 deletions
@@ -2031,11 +2031,13 @@ def test_large_timestamp_generation(self):
         ).input_features
         input_features = input_features.to(torch_device)
 
-        generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
+        generated_ids = model.generate(
+            input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True
+        ).to("cpu")
 
         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
-            50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
+            [50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50431]
         ])
         # fmt: on
         torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
@@ -2078,7 +2080,7 @@ def test_large_timestamp_generation(self):
             },
             {
                 "text": (" and can discover"),
-                "timestamp": (28.68, 29.98),
+                "timestamp": (28.68, 30.0),
             },
         ],
     }
