
Commit 860bd12

nguyenhoangthuan99 authored and yangw-dev committed
[Frontend] add chunking audio for > 30s audio (vllm-project#19597)
Signed-off-by: nguyenhoangthuan99 <thuanhppro12@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
1 parent 6d5cbb4 commit 860bd12
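
For context, this is what the change enables from a client's point of view: a clip longer than 30 s can now be sent to the OpenAI-compatible transcription endpoint and is transcribed in full instead of being rejected. A minimal client-side sketch follows, assuming a vLLM server is already serving openai/whisper-large-v3-turbo; the base URL, API key, and file name are illustrative, not part of this commit.

# Illustrative only: assumes e.g. `vllm serve openai/whisper-large-v3-turbo` is running
# on localhost:8000 and that long_clip.wav is any recording longer than 30 seconds.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("long_clip.wav", "rb") as audio_file:
    transcription = client.audio.transcriptions.create(
        model="openai/whisper-large-v3-turbo",
        file=audio_file,
        language="en",
        response_format="text",
        temperature=0.0)
print(transcription)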

File tree

tests/entrypoints/openai/test_transcription_validation.py
vllm/entrypoints/openai/serving_transcription.py

2 files changed: +168 -95 lines

tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 23 additions & 13 deletions
@@ -74,19 +74,29 @@ async def test_bad_requests(mary_had_lamb):
                                                   language="hh",
                                                   temperature=0.0)

-    # Expect audio too long: repeat the timeseries
-    mary_had_lamb.seek(0)
-    audio, sr = librosa.load(mary_had_lamb)
-    repeated_audio = np.tile(audio, 10)
-    # Repeated audio to buffer
-    buffer = io.BytesIO()
-    sf.write(buffer, repeated_audio, sr, format='WAV')
-    buffer.seek(0)
-    with pytest.raises(openai.BadRequestError):
-        await client.audio.transcriptions.create(model=model_name,
-                                                  file=buffer,
-                                                  language="en",
-                                                  temperature=0.0)
+
+@pytest.mark.asyncio
+async def test_long_audio_request(mary_had_lamb):
+    model_name = "openai/whisper-large-v3-turbo"
+    server_args = ["--enforce-eager"]
+
+    mary_had_lamb.seek(0)
+    audio, sr = librosa.load(mary_had_lamb)
+    repeated_audio = np.tile(audio, 10)
+    # Repeated audio to buffer
+    buffer = io.BytesIO()
+    sf.write(buffer, repeated_audio, sr, format='WAV')
+    buffer.seek(0)
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=buffer,
+            language="en",
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert out.count("Mary had a little lamb") == 10


 @pytest.mark.asyncio

vllm/entrypoints/openai/serving_transcription.py

Lines changed: 145 additions & 82 deletions
@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import io
+import math
 import time
 from collections.abc import AsyncGenerator
 from math import ceil
 from typing import Final, Optional, Union, cast

+import numpy as np
 from fastapi import Request

 from vllm.config import ModelConfig
@@ -143,6 +145,8 @@
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
+OVERLAP_CHUNK_SECOND = 1
+MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio


 class OpenAIServingTranscription(OpenAIServing):
@@ -178,7 +182,7 @@ async def _preprocess_transcription(
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> tuple[PromptType, float]:
+    ) -> tuple[list[PromptType], float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
@@ -206,22 +210,22 @@ async def _preprocess_transcription(
             y, sr = librosa.load(bytes_)

         duration = librosa.get_duration(y=y, sr=sr)
-        if duration > self.max_audio_clip_s:
-            raise ValueError(
-                f"Maximum clip duration ({self.max_audio_clip_s}s) "
-                "exceeded.")
-
-        prompt = {
-            "encoder_prompt": {
-                "prompt": "",
-                "multi_modal_data": {
-                    "audio": (y, sr),
+        chunks = [y] if duration < 30 else self._split_audio(y, sr)
+        prompts = []
+        for i, chunk in enumerate(chunks):
+            prompt = {
+                "encoder_prompt": {
+                    "prompt": "",
+                    "multi_modal_data": {
+                        "audio": (chunk, sr),
+                    },
                 },
-            },
-            "decoder_prompt":
-            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
-        }
-        return cast(PromptType, prompt), duration
+                "decoder_prompt":
+                f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
+                if i == 0 else ""
+            }
+            prompts.append(cast(PromptType, prompt))
+        return prompts, duration

     # TODO (varun) : Make verbose response work !
     async def create_transcription(
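
Note on the hunk above: the conditional attaches the Whisper task tokens only to the first chunk's decoder prompt; every subsequent chunk is sent with an empty decoder prompt. A tiny standalone illustration of the resulting prompt structure, using dummy chunk placeholders and a hypothetical language token:

# Dummy stand-ins: two "chunks" and a language token, purely to show the prompt shape.
chunks = ["chunk-0", "chunk-1"]
lang_token = "<|en|>"
request_prompt = ""

prompts = [{
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {
            "audio": (chunk, 16000),
        },
    },
    "decoder_prompt":
    f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request_prompt}"
    if i == 0 else ""
} for i, chunk in enumerate(chunks)]

print(prompts[0]["decoder_prompt"])        # full task prompt on the first chunk
print(repr(prompts[1]["decoder_prompt"]))  # '' on every subsequent chunk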
@@ -268,7 +272,7 @@ async def create_transcription(
                 "Currently do not support PromptAdapter for Transcription."
             )

-            prompt, duration_s = await self._preprocess_transcription(
+            prompts, duration_s = await self._preprocess_transcription(
                 request=request,
                 audio_data=audio_data,
             )
@@ -277,7 +281,8 @@ async def create_transcription(
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

-        result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
+        list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
+                                                             None]]] = None
         try:
             # Unlike most decoder-only models, whisper generation length is not
             # constrained by the size of the input audio, which is mapped to a
@@ -288,32 +293,36 @@ async def create_transcription(

             self._log_inputs(
                 request_id,
-                prompt['decoder_prompt'],  # type: ignore
+                prompts[0]['decoder_prompt'],  # type: ignore
                 params=sampling_params,
                 lora_request=None,
                 prompt_adapter_request=None)

-            result_generator = self.engine_client.generate(
-                prompt,
-                sampling_params,
-                request_id,
-            )
+            list_result_generator = [
+                self.engine_client.generate(
+                    prompt,
+                    sampling_params,
+                    request_id,
+                ) for prompt in prompts
+            ]
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

         if request.stream:
             return self.transcription_stream_generator(request,
-                                                       result_generator,
+                                                       list_result_generator,
                                                        request_id,
                                                        request_metadata,
                                                        duration_s)
         # Non-streaming response.
         try:
-            assert result_generator is not None
-            async for op in result_generator:
-                result = op
-            return TranscriptionResponse(text=result.outputs[0].text)
+            assert list_result_generator is not None
+            text = ""
+            for result_generator in list_result_generator:
+                async for op in result_generator:
+                    text += op.outputs[0].text
+            return TranscriptionResponse(text=text)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
@@ -322,7 +331,7 @@ async def create_transcription(

     async def transcription_stream_generator(
             self, request: TranscriptionRequest,
-            result_generator: AsyncGenerator[RequestOutput, None],
+            list_result_generator: list[AsyncGenerator[RequestOutput, None]],
             request_id: str, request_metadata: RequestResponseMetadata,
             audio_duration_s: float) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
@@ -335,60 +344,65 @@ async def transcription_stream_generator(
         include_usage = request.stream_include_usage \
             if request.stream_include_usage else False
         include_continuous_usage = request.stream_continuous_usage_stats\
-            if include_usage and request.stream_continuous_usage_stats\
-            else False
+              if include_usage and request.stream_continuous_usage_stats\
+              else False

         try:
-            async for res in result_generator:
-                # On first result.
-                if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token is not specified.
-                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder prompts directly
-                    # at least not to Whisper. One indicator of the encoder
-                    # amount of processing is the log-mel spectogram length.
-                    num_prompt_tokens += ceil(audio_duration_s *
-                                              self.model_sr / self.hop_length)
-
-                # We need to do it here, because if there are exceptions in
-                # the result_generator, it needs to be sent as the FIRST
-                # response (by the try...catch).
-
-                # Just one output (n=1) supported.
-                assert len(res.outputs) == 1
-                output = res.outputs[0]
-
-                delta_message = DeltaMessage(content=output.text)
-                completion_tokens += len(output.token_ids)
-
-                if output.finish_reason is None:
-                    # Still generating, send delta update.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=delta_message)
-                else:
-                    # Model is finished generating.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=delta_message,
-                        finish_reason=output.finish_reason,
-                        stop_reason=output.stop_reason)
-
-                chunk = TranscriptionStreamResponse(id=request_id,
-                                                    object=chunk_object_type,
-                                                    created=created_time,
-                                                    choices=[choice_data],
-                                                    model=model_name)
-
-                # handle usage stats if requested & if continuous
-                if include_continuous_usage:
-                    chunk.usage = UsageInfo(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=num_prompt_tokens + completion_tokens,
-                    )
-
-                data = chunk.model_dump_json(exclude_unset=True)
-                yield f"data: {data}\n\n"
+            for result_generator in list_result_generator:
+                async for res in result_generator:
+                    # On first result.
+                    if res.prompt_token_ids is not None:
+                        # Do not account the 4-tokens `<|startoftranscript|>..`
+                        # Could be negative when language token
+                        # is not specified.
+                        num_prompt_tokens = max(
+                            len(res.prompt_token_ids) - 4, 0)
+                        # NOTE(NickLucche) user can't pass encoder
+                        # prompts directly at least not to Whisper.
+                        # One indicator of the encoder amount of processing
+                        # is the log-mel spectogram length.
+                        num_prompt_tokens += ceil(
+                            audio_duration_s * self.model_sr / self.hop_length)
+
+                    # We need to do it here, because if there are exceptions in
+                    # the result_generator, it needs to be sent as the FIRST
+                    # response (by the try...catch).
+
+                    # Just one output (n=1) supported.
+                    assert len(res.outputs) == 1
+                    output = res.outputs[0]
+
+                    delta_message = DeltaMessage(content=output.text)
+                    completion_tokens += len(output.token_ids)
+
+                    if output.finish_reason is None:
+                        # Still generating, send delta update.
+                        choice_data = TranscriptionResponseStreamChoice(
+                            delta=delta_message)
+                    else:
+                        # Model is finished generating.
+                        choice_data = TranscriptionResponseStreamChoice(
+                            delta=delta_message,
+                            finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason)
+
+                    chunk = TranscriptionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+
+                    # handle usage stats if requested & if continuous
+                    if include_continuous_usage:
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=num_prompt_tokens + completion_tokens,
+                        )
+
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"

         # Once the final token is handled, if stream_options.include_usage
         # is sent, send the usage.
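
Note on the prompt-token estimate in the hunk above: the encoder-side count is approximated from the number of log-mel frames. A rough worked example, assuming Whisper's typical 16 kHz sampling rate and a hop length of 160 samples (the actual values come from self.model_sr and self.hop_length):

from math import ceil

audio_duration_s = 30.0   # one full chunk
model_sr = 16000          # assumed sampling rate
hop_length = 160          # assumed hop length (10 ms per log-mel frame)
print(ceil(audio_duration_s * model_sr / hop_length))  # -> 3000 frames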
@@ -422,3 +436,52 @@ async def transcription_stream_generator(
                 yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
+
+    def _split_audio(self, audio_data: np.ndarray,
+                     sample_rate: int) -> list[np.ndarray]:
+        chunk_size = sample_rate * self.max_audio_clip_s
+        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunks = []
+        i = 0
+        while i < audio_data.shape[-1]:
+            if i + chunk_size >= audio_data.shape[-1]:
+                # handle last chunk
+                chunks.append(audio_data[..., i:])
+                break
+
+            # Find the best split point in the overlap region
+            search_start = i + chunk_size - overlap_size
+            search_end = min(i + chunk_size, audio_data.shape[-1])
+            split_point = self._find_split_point(audio_data, search_start,
+                                                 search_end)
+
+            # Extract chunk up to the split point
+            chunks.append(audio_data[..., i:split_point])
+            i = split_point
+        return chunks
+
+    def _find_split_point(self, wav: np.ndarray, start_idx: int,
+                          end_idx: int) -> int:
+        """Find the best point to split audio by
+        looking for silence or low amplitude.
+        Args:
+            wav: Audio tensor [1, T]
+            start_idx: Start index of search region
+            end_idx: End index of search region
+        Returns:
+            Index of best splitting point
+        """
+        segment = wav[start_idx:end_idx]
+
+        # Calculate RMS energy in small windows
+        min_energy = math.inf
+        quietest_idx = 0
+        for i in range(0,
+                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
+                       MIN_ENERGY_WINDOW_SIZE):
+            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+            energy = (window**2).mean()**0.5
+            if energy < min_energy:
+                quietest_idx = i + start_idx
+                min_energy = energy
+        return quietest_idx
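
The two helpers added above are easiest to sanity-check in isolation. Below is a self-contained sketch of the same silence-seeking split applied to a synthetic 70 s tone with quiet gaps; the free functions mirror _split_audio/_find_split_point with the constants above, and the 30 s maximum clip length is hard-coded here purely for illustration (in the server it comes from self.max_audio_clip_s).

import math

import numpy as np

MAX_CLIP_S = 30                # stands in for self.max_audio_clip_s (assumed value)
OVERLAP_CHUNK_SECOND = 1       # same constant as above
MIN_ENERGY_WINDOW_SIZE = 1600  # ~100 ms at 16 kHz, same constant as above


def find_split_point(wav: np.ndarray, start_idx: int, end_idx: int) -> int:
    """Return the start index of the lowest-RMS window inside [start_idx, end_idx)."""
    segment = wav[start_idx:end_idx]
    min_energy, quietest_idx = math.inf, start_idx
    for i in range(0, len(segment) - MIN_ENERGY_WINDOW_SIZE, MIN_ENERGY_WINDOW_SIZE):
        window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
        energy = (window**2).mean()**0.5
        if energy < min_energy:
            min_energy, quietest_idx = energy, i + start_idx
    return quietest_idx


def split_audio(audio: np.ndarray, sample_rate: int) -> list[np.ndarray]:
    """Split into <=30 s chunks, cutting at the quietest point of each 1 s overlap."""
    chunk_size = sample_rate * MAX_CLIP_S
    overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
    chunks, i = [], 0
    while i < audio.shape[-1]:
        if i + chunk_size >= audio.shape[-1]:
            chunks.append(audio[..., i:])  # last (short) chunk
            break
        split = find_split_point(audio, i + chunk_size - overlap_size,
                                 min(i + chunk_size, audio.shape[-1]))
        chunks.append(audio[..., i:split])
        i = split
    return chunks


# 70 s of quiet tone with silent gaps near 29.5 s and 58.5 s; both cuts land in the gaps.
sr = 16000
t = np.arange(70 * sr) / sr
wav = 0.1 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
wav[int(29.5 * sr):int(29.8 * sr)] = 0.0
wav[int(58.5 * sr):int(58.8 * sr)] = 0.0
print([round(len(c) / sr, 2) for c in split_audio(wav, sr)])  # -> [29.5, 29.0, 11.5]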
