
Commit e795d72

NickLucche and ywang96 authored
[Frontend] Add /v1/audio/translations OpenAI API endpoint (#19615)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 8359f4c commit e795d72

File tree

10 files changed: +1127 -461 lines

docs/serving/openai_compatible_server.md

Lines changed: 30 additions & 0 deletions
@@ -57,6 +57,8 @@ We currently support the following OpenAI APIs:
     - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
 - [Transcriptions API][transcriptions-api] (`/v1/audio/transcriptions`)
     - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).
+- [Translation API][translations-api] (`/v1/audio/translations`)
+    - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).
 
 In addition, we have the following custom APIs:
@@ -374,6 +376,34 @@ The following extra parameters are supported:
 ```python
 --8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params"
 ```
+
+[](){ #translations-api }
+
+### Translations API
+
+Our Translation API is compatible with [OpenAI's Translations API](https://platform.openai.com/docs/api-reference/audio/createTranslation);
+you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
+Whisper models can translate audio from one of the 55 supported non-English languages into English.
+Please note that the popular `openai/whisper-large-v3-turbo` model does not support translation.
+
+!!! note
+    To use the Translation API, please install with extra audio dependencies using `pip install vllm[audio]`.
+
+Code example: <gh-file:examples/online_serving/openai_translation_client.py>
+
+#### Extra Parameters
+
+The following [sampling parameters][sampling-params] are supported.
+
+```python
+--8<-- "vllm/entrypoints/openai/protocol.py:translation-sampling-params"
+```
+
+The following extra parameters are supported:
+
+```python
+--8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params"
+```
 
 [](){ #tokenizer-api }
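For orientation, here is a minimal sketch of a non-streaming call against the new endpoint, assuming a vLLM server is already running at `localhost:8000` with a Whisper checkpoint that supports translation; the audio file name below is a placeholder:

```python
from openai import OpenAI

# A server started without --api-key accepts any placeholder key.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# "sample_it.wav" is a placeholder for any supported non-English audio file.
with open("sample_it.wav", "rb") as f:
    translation = client.audio.translations.create(
        model="openai/whisper-small",  # whisper-large-v3-turbo cannot translate
        file=f,
        response_format="json",
        temperature=0.0,
    )
print("translation result:", translation.text)
```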

examples/online_serving/openai_transcription_client.py

Lines changed: 20 additions & 21 deletions
@@ -26,23 +26,12 @@
 
 from vllm.assets.audio import AudioAsset
 
-mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path()
-winning_call = AudioAsset("winning_call").get_local_path()
 
-# Modify OpenAI's API key and API base to use vLLM's API server.
-openai_api_key = "EMPTY"
-openai_api_base = "http://localhost:8000/v1"
-client = OpenAI(
-    api_key=openai_api_key,
-    base_url=openai_api_base,
-)
-
-
-def sync_openai():
+def sync_openai(audio_path: str, client: OpenAI):
     """
     Perform synchronous transcription using OpenAI-compatible API.
     """
-    with open(str(mary_had_lamb), "rb") as f:
+    with open(audio_path, "rb") as f:
         transcription = client.audio.transcriptions.create(
             file=f,
             model="openai/whisper-large-v3",
@@ -58,8 +47,7 @@ def sync_openai():
     print("transcription result:", transcription.text)
 
 
-# OpenAI Transcription API client does not support streaming.
-async def stream_openai_response():
+async def stream_openai_response(audio_path: str, base_url: str, api_key: str):
     """
     Perform streaming transcription using vLLM's raw HTTP streaming API.
     """
@@ -68,11 +56,12 @@ async def stream_openai_response():
         "stream": True,
         "model": "openai/whisper-large-v3",
     }
-    url = openai_api_base + "/audio/transcriptions"
-    headers = {"Authorization": f"Bearer {openai_api_key}"}
+    url = base_url + "/audio/transcriptions"
+    headers = {"Authorization": f"Bearer {api_key}"}
     print("transcription result:", end=" ")
+    # OpenAI Transcription API client does not support streaming.
     async with httpx.AsyncClient() as client:
-        with open(str(winning_call), "rb") as f:
+        with open(audio_path, "rb") as f:
             async with client.stream(
                 "POST", url, files={"file": f}, data=data, headers=headers
             ) as response:
@@ -93,10 +82,20 @@ async def stream_openai_response():
 
 
 def main():
-    sync_openai()
-
+    mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path())
+    winning_call = str(AudioAsset("winning_call").get_local_path())
+
+    # Modify OpenAI's API key and API base to use vLLM's API server.
+    openai_api_key = "EMPTY"
+    openai_api_base = "http://localhost:8000/v1"
+    client = OpenAI(
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    sync_openai(mary_had_lamb, client)
     # Run the asynchronous function
-    asyncio.run(stream_openai_response())
+    asyncio.run(stream_openai_response(winning_call, openai_api_base, openai_api_key))
 
 
 if __name__ == "__main__":
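To try the example end to end, serve a Whisper model first (e.g. `vllm serve openai/whisper-large-v3`) and then run the script; the refactor above only threads the audio path, base URL, and client through as parameters, leaving behavior unchanged.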
examples/online_serving/openai_translation_client.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import json
+
+import httpx
+from openai import OpenAI
+
+from vllm.assets.audio import AudioAsset
+
+
+def sync_openai(audio_path: str, client: OpenAI):
+    with open(audio_path, "rb") as f:
+        translation = client.audio.translations.create(
+            file=f,
+            model="openai/whisper-large-v3",
+            response_format="json",
+            temperature=0.0,
+            # Additional params not provided by OpenAI API.
+            extra_body=dict(
+                language="it",
+                seed=4419,
+                repetition_penalty=1.3,
+            ),
+        )
+    print("translation result:", translation.text)
+
+
+async def stream_openai_response(audio_path: str, base_url: str, api_key: str):
+    data = {
+        "language": "it",
+        "stream": True,
+        "model": "openai/whisper-large-v3",
+    }
+    url = base_url + "/audio/translations"
+    headers = {"Authorization": f"Bearer {api_key}"}
+    print("translation result:", end=" ")
+    # OpenAI translation API client does not support streaming.
+    async with httpx.AsyncClient() as client:
+        with open(audio_path, "rb") as f:
+            async with client.stream(
+                "POST", url, files={"file": f}, data=data, headers=headers
+            ) as response:
+                async for line in response.aiter_lines():
+                    # Each line is a JSON object prefixed with 'data: '
+                    if line:
+                        if line.startswith("data: "):
+                            line = line[len("data: ") :]
+                        # Last chunk, stream ends
+                        if line.strip() == "[DONE]":
+                            break
+                        # Parse the JSON response
+                        chunk = json.loads(line)
+                        # Extract and print the content
+                        content = chunk["choices"][0].get("delta", {}).get("content")
+                        print(content, end="")
+
+
+def main():
+    foscolo = str(AudioAsset("azacinto_foscolo").get_local_path())
+
+    # Modify OpenAI's API key and API base to use vLLM's API server.
+    openai_api_key = "EMPTY"
+    openai_api_base = "http://localhost:8000/v1"
+    client = OpenAI(
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+    sync_openai(foscolo, client)
+    # Run the asynchronous function
+    asyncio.run(stream_openai_response(foscolo, openai_api_base, openai_api_key))
+
+
+if __name__ == "__main__":
+    main()
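Like the transcription client above, the new example pairs a synchronous request through the OpenAI SDK with a raw-HTTP streaming variant: the official client does not expose streaming for translations, so the script requests it via the extra `stream` parameter and parses the `data:`-prefixed server-sent-event lines itself.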

tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ async def test_long_audio_request(mary_had_lamb):
 
     mary_had_lamb.seek(0)
     audio, sr = librosa.load(mary_had_lamb)
+    # Add small silence after each audio for repeatability in the split process
+    audio = np.pad(audio, (0, 1600))
     repeated_audio = np.tile(audio, 10)
     # Repeated audio to buffer
     buffer = io.BytesIO()
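For context, `librosa.load` resamples to 22.05 kHz by default, so the 1600-sample pad adds roughly 70 ms of silence after each repetition of the clip.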
tests/entrypoints/openai/test_translation_validation.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import io
+# imports for guided decoding tests
+import json
+from unittest.mock import patch
+
+import librosa
+import numpy as np
+import pytest
+import soundfile as sf
+from openai._base_client import AsyncAPIClient
+
+from vllm.assets.audio import AudioAsset
+
+from ...utils import RemoteOpenAIServer
+
+
+@pytest.fixture
+def foscolo():
+    # Test translation it->en
+    path = AudioAsset('azacinto_foscolo').get_local_path()
+    with open(str(path), "rb") as f:
+        yield f
+
+
+# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
+@pytest.mark.asyncio
+async def test_basic_audio(foscolo):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        translation = await client.audio.translations.create(
+            model=model_name,
+            file=foscolo,
+            response_format="text",
+            # TODO remove once language detection is implemented
+            extra_body=dict(language="it"),
+            temperature=0.0)
+        out = json.loads(translation)['text'].strip()
+        assert "Nor will I ever touch the sacred" in out
+
+
+@pytest.mark.asyncio
+async def test_audio_prompt(foscolo):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    # Condition whisper on starting text
+    prompt = "Nor have I ever"
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.translations.create(
+            model=model_name,
+            file=foscolo,
+            prompt=prompt,
+            extra_body=dict(language="it"),
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(transcription)['text']
+        assert "Nor will I ever touch the sacred" not in out
+        assert prompt not in out
+
+
+@pytest.mark.asyncio
+async def test_non_asr_model(foscolo):
+    # text to text model
+    model_name = "JackFram/llama-68m"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        res = await client.audio.translations.create(model=model_name,
+                                                     file=foscolo,
+                                                     temperature=0.0)
+        assert res.code == 400 and not res.text
+        assert res.message == "The model does not support Translations API"
+
+
+@pytest.mark.asyncio
+async def test_streaming_response(foscolo):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    translation = ""
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        res_no_stream = await client.audio.translations.create(
+            model=model_name,
+            file=foscolo,
+            response_format="json",
+            extra_body=dict(language="it"),
+            temperature=0.0)
+        # Unfortunately this only works when the openai client is patched
+        # to use streaming mode, not exposed in the translation api.
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.translations.create(model=model_name,
+                                                          file=foscolo,
+                                                          temperature=0.0,
+                                                          extra_body=dict(
+                                                              stream=True,
+                                                              language="it"))
+            # Reconstruct from chunks and validate
+            async for chunk in res:
+                # just a chunk
+                text = chunk.choices[0]['delta']['content']
+                translation += text
+
+        assert translation == res_no_stream.text
+
+
+@pytest.mark.asyncio
+async def test_stream_options(foscolo):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.translations.create(
+                model=model_name,
+                file=foscolo,
+                temperature=0.0,
+                extra_body=dict(language="it",
+                                stream=True,
+                                stream_include_usage=True,
+                                stream_continuous_usage_stats=True))
+            final = False
+            continuous = True
+            async for chunk in res:
+                if not len(chunk.choices):
+                    # final usage sent
+                    final = True
+                else:
+                    continuous = continuous and hasattr(chunk, 'usage')
+            assert final and continuous
+
+
+@pytest.mark.asyncio
+async def test_long_audio_request(foscolo):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+
+    foscolo.seek(0)
+    audio, sr = librosa.load(foscolo)
+    repeated_audio = np.tile(audio, 2)
+    # Repeated audio to buffer
+    buffer = io.BytesIO()
+    sf.write(buffer, repeated_audio, sr, format='WAV')
+    buffer.seek(0)
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        translation = await client.audio.translations.create(
+            model=model_name,
+            file=buffer,
+            extra_body=dict(language="it"),
+            response_format="text",
+            temperature=0.0)
+        out = json.loads(translation)['text'].strip().lower()
+        # TODO investigate higher model uncertainty for longer translations.
+        assert out.count("nor will i ever") == 2
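Note that the two streaming tests patch `AsyncAPIClient.post` to force `stream=True`, since the OpenAI client does not expose a streaming mode for the translations API; the chunks are then reassembled and compared against the non-streaming result.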
