 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import io
+import math
 import time
 from collections.abc import AsyncGenerator
 from math import ceil
 from typing import Final, Optional, Union, cast
 
+import numpy as np
 from fastapi import Request
 
 from vllm.config import ModelConfig

 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
+OVERLAP_CHUNK_SECOND = 1
+MIN_ENERGY_WINDOW_SIZE = 1600  # 1600 ~ 100ms for 16000 Hz audio
 
 
 class OpenAIServingTranscription(OpenAIServing):
@@ -178,7 +182,7 @@ async def _preprocess_transcription(
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> tuple[PromptType, float]:
+    ) -> tuple[list[PromptType], float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
@@ -206,22 +210,22 @@ async def _preprocess_transcription(
             y, sr = librosa.load(bytes_)
 
         duration = librosa.get_duration(y=y, sr=sr)
-        if duration > self.max_audio_clip_s:
-            raise ValueError(
-                f"Maximum clip duration ({self.max_audio_clip_s}s) "
-                "exceeded.")
-
-        prompt = {
-            "encoder_prompt": {
-                "prompt": "",
-                "multi_modal_data": {
-                    "audio": (y, sr),
+        chunks = [y] if duration < 30 else self._split_audio(y, sr)
+        prompts = []
+        for i, chunk in enumerate(chunks):
+            prompt = {
+                "encoder_prompt": {
+                    "prompt": "",
+                    "multi_modal_data": {
+                        "audio": (chunk, sr),
+                    },
                 },
-            },
-            "decoder_prompt":
-            f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
-        }
-        return cast(PromptType, prompt), duration
+                "decoder_prompt":
+                f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
+                if i == 0 else ""
+            }
+            prompts.append(cast(PromptType, prompt))
+        return prompts, duration
 
     # TODO (varun) : Make verbose response work!
     async def create_transcription(
@@ -268,7 +272,7 @@ async def create_transcription(
                 "Currently do not support PromptAdapter for Transcription."
             )
 
-        prompt, duration_s = await self._preprocess_transcription(
+        prompts, duration_s = await self._preprocess_transcription(
            request=request,
            audio_data=audio_data,
        )
@@ -277,7 +281,8 @@ async def create_transcription(
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-        result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
+        list_result_generator: Optional[list[AsyncGenerator[RequestOutput,
+                                                             None]]] = None
         try:
             # Unlike most decoder-only models, whisper generation length is not
             # constrained by the size of the input audio, which is mapped to a
@@ -288,32 +293,36 @@ async def create_transcription(
 
             self._log_inputs(
                 request_id,
-                prompt['decoder_prompt'],  # type: ignore
+                prompts[0]['decoder_prompt'],  # type: ignore
                 params=sampling_params,
                 lora_request=None,
                 prompt_adapter_request=None)
 
-            result_generator = self.engine_client.generate(
-                prompt,
-                sampling_params,
-                request_id,
-            )
+            list_result_generator = [
+                self.engine_client.generate(
+                    prompt,
+                    sampling_params,
+                    request_id,
+                ) for prompt in prompts
+            ]
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
         if request.stream:
             return self.transcription_stream_generator(request,
-                                                       result_generator,
+                                                       list_result_generator,
                                                        request_id,
                                                        request_metadata,
                                                        duration_s)
         # Non-streaming response.
         try:
-            assert result_generator is not None
-            async for op in result_generator:
-                result = op
-            return TranscriptionResponse(text=result.outputs[0].text)
+            assert list_result_generator is not None
+            text = ""
+            for result_generator in list_result_generator:
+                async for op in result_generator:
+                    text += op.outputs[0].text
+            return TranscriptionResponse(text=text)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
@@ -322,7 +331,7 @@ async def create_transcription(
 
     async def transcription_stream_generator(
             self, request: TranscriptionRequest,
-            result_generator: AsyncGenerator[RequestOutput, None],
+            list_result_generator: list[AsyncGenerator[RequestOutput, None]],
             request_id: str, request_metadata: RequestResponseMetadata,
             audio_duration_s: float) -> AsyncGenerator[str, None]:
         created_time = int(time.time())
@@ -335,60 +344,65 @@ async def transcription_stream_generator(
         include_usage = request.stream_include_usage \
             if request.stream_include_usage else False
         include_continuous_usage = request.stream_continuous_usage_stats \
-            if include_usage and request.stream_continuous_usage_stats \
-            else False
+                if include_usage and request.stream_continuous_usage_stats \
+                else False
 
         try:
-            async for res in result_generator:
-                # On first result.
-                if res.prompt_token_ids is not None:
-                    # Do not account the 4-tokens `<|startoftranscript|>..`
-                    # Could be negative when language token is not specified.
-                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
-                    # NOTE(NickLucche) user can't pass encoder prompts directly
-                    # at least not to Whisper. One indicator of the encoder
-                    # amount of processing is the log-mel spectrogram length.
-                    num_prompt_tokens += ceil(audio_duration_s *
-                                              self.model_sr / self.hop_length)
-
-                # We need to do it here, because if there are exceptions in
-                # the result_generator, it needs to be sent as the FIRST
-                # response (by the try...catch).
-
-                # Just one output (n=1) supported.
-                assert len(res.outputs) == 1
-                output = res.outputs[0]
-
-                delta_message = DeltaMessage(content=output.text)
-                completion_tokens += len(output.token_ids)
-
-                if output.finish_reason is None:
-                    # Still generating, send delta update.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=delta_message)
-                else:
-                    # Model is finished generating.
-                    choice_data = TranscriptionResponseStreamChoice(
-                        delta=delta_message,
-                        finish_reason=output.finish_reason,
-                        stop_reason=output.stop_reason)
-
-                chunk = TranscriptionStreamResponse(id=request_id,
-                                                    object=chunk_object_type,
-                                                    created=created_time,
-                                                    choices=[choice_data],
-                                                    model=model_name)
-
-                # handle usage stats if requested & if continuous
-                if include_continuous_usage:
-                    chunk.usage = UsageInfo(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=num_prompt_tokens + completion_tokens,
-                    )
-
-                data = chunk.model_dump_json(exclude_unset=True)
-                yield f"data: {data}\n\n"
+            for result_generator in list_result_generator:
+                async for res in result_generator:
+                    # On first result.
+                    if res.prompt_token_ids is not None:
+                        # Do not account the 4-tokens `<|startoftranscript|>..`
+                        # Could be negative when language token
+                        # is not specified.
+                        num_prompt_tokens = max(
+                            len(res.prompt_token_ids) - 4, 0)
+                        # NOTE(NickLucche) user can't pass encoder
+                        # prompts directly at least not to Whisper.
+                        # One indicator of the encoder amount of processing
+                        # is the log-mel spectrogram length.
+                        num_prompt_tokens += ceil(
+                            audio_duration_s * self.model_sr / self.hop_length)
+
+                    # We need to do it here, because if there are exceptions in
+                    # the result_generator, it needs to be sent as the FIRST
+                    # response (by the try...catch).
+
+                    # Just one output (n=1) supported.
+                    assert len(res.outputs) == 1
+                    output = res.outputs[0]
+
+                    delta_message = DeltaMessage(content=output.text)
+                    completion_tokens += len(output.token_ids)
+
+                    if output.finish_reason is None:
+                        # Still generating, send delta update.
+                        choice_data = TranscriptionResponseStreamChoice(
+                            delta=delta_message)
+                    else:
+                        # Model is finished generating.
+                        choice_data = TranscriptionResponseStreamChoice(
+                            delta=delta_message,
+                            finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason)
+
+                    chunk = TranscriptionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+
+                    # handle usage stats if requested & if continuous
+                    if include_continuous_usage:
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=num_prompt_tokens + completion_tokens,
+                        )
+
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"
 
         # Once the final token is handled, if stream_options.include_usage
         # is sent, send the usage.
@@ -422,3 +436,52 @@ async def transcription_stream_generator(
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
+
+    def _split_audio(self, audio_data: np.ndarray,
+                     sample_rate: int) -> list[np.ndarray]:
+        chunk_size = sample_rate * self.max_audio_clip_s
+        overlap_size = sample_rate * OVERLAP_CHUNK_SECOND
+        chunks = []
+        i = 0
+        while i < audio_data.shape[-1]:
+            if i + chunk_size >= audio_data.shape[-1]:
+                # handle last chunk
+                chunks.append(audio_data[..., i:])
+                break
+
+            # Find the best split point in the overlap region
+            search_start = i + chunk_size - overlap_size
+            search_end = min(i + chunk_size, audio_data.shape[-1])
+            split_point = self._find_split_point(audio_data, search_start,
+                                                 search_end)
+
+            # Extract chunk up to the split point
+            chunks.append(audio_data[..., i:split_point])
+            i = split_point
+        return chunks
+
+    def _find_split_point(self, wav: np.ndarray, start_idx: int,
+                          end_idx: int) -> int:
+        """Find the best point to split audio by
+        looking for silence or low amplitude.
+        Args:
+            wav: Audio tensor [1, T]
+            start_idx: Start index of search region
+            end_idx: End index of search region
+        Returns:
+            Index of best splitting point
+        """
+        segment = wav[start_idx:end_idx]
+
+        # Calculate RMS energy in small windows
+        min_energy = math.inf
+        quietest_idx = 0
+        for i in range(0,
+                       len(segment) - MIN_ENERGY_WINDOW_SIZE,
+                       MIN_ENERGY_WINDOW_SIZE):
+            window = segment[i:i + MIN_ENERGY_WINDOW_SIZE]
+            energy = (window**2).mean()**0.5
+            if energy < min_energy:
+                quietest_idx = i + start_idx
+                min_energy = energy
+        return quietest_idx
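
For illustration, below is a minimal standalone sketch of the same split-on-silence idea, assuming 16 kHz mono float audio; the constants and the `split_on_silence` helper are hypothetical names for this example only, not part of the change:

```python
import numpy as np

SAMPLE_RATE = 16000  # assumed Whisper input rate
MAX_CLIP_S = 30      # target chunk length, mirrors max_audio_clip_s
OVERLAP_S = 1        # trailing region searched for a quiet split point
ENERGY_WIN = 1600    # ~100 ms RMS window at 16 kHz


def split_on_silence(audio: np.ndarray) -> list[np.ndarray]:
    """Cut `audio` into ~MAX_CLIP_S chunks, splitting each chunk at the
    quietest 100 ms window inside its trailing 1 s."""
    chunk = SAMPLE_RATE * MAX_CLIP_S
    overlap = SAMPLE_RATE * OVERLAP_S
    out, i = [], 0
    while i < audio.shape[-1]:
        if i + chunk >= audio.shape[-1]:
            out.append(audio[i:])  # last (short) chunk
            break
        # lowest-RMS window inside [i + chunk - overlap, i + chunk)
        seg = audio[i + chunk - overlap:i + chunk]
        rms = [
            np.sqrt(np.mean(seg[j:j + ENERGY_WIN]**2))
            for j in range(0, len(seg) - ENERGY_WIN, ENERGY_WIN)
        ]
        split = i + chunk - overlap + int(np.argmin(rms)) * ENERGY_WIN
        out.append(audio[i:split])
        i = split
    return out


# e.g. 65 s of noise -> two ~30 s chunks plus a short remainder
print([len(c) / SAMPLE_RATE for c in split_on_silence(
    np.random.randn(65 * SAMPLE_RATE).astype(np.float32))])
```

Splitting at the quietest window, rather than at a hard 30 s boundary, makes it less likely that a word is cut in half between consecutive Whisper prompts.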