Commit e7e3e6d

Voxtral (#20970)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 4ffd963 commit e7e3e6d

File tree: 14 files changed (+913 lines, -47 lines)

examples/offline_inference/audio_language.py

Lines changed: 75 additions & 10 deletions

@@ -10,7 +10,7 @@
 
 import os
 from dataclasses import asdict
-from typing import NamedTuple, Optional
+from typing import Any, NamedTuple, Optional
 
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
@@ -30,7 +30,9 @@
 
 class ModelRequestData(NamedTuple):
     engine_args: EngineArgs
-    prompt: str
+    prompt: Optional[str] = None
+    prompt_token_ids: Optional[dict[str, list[int]]] = None
+    multi_modal_data: Optional[dict[str, Any]] = None
     stop_token_ids: Optional[list[int]] = None
     lora_requests: Optional[list[LoRARequest]] = None
 
@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.
 
 
+# Voxtral
+def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
+    from mistral_common.audio import Audio
+    from mistral_common.protocol.instruct.messages import (
+        AudioChunk,
+        RawAudio,
+        TextChunk,
+        UserMessage,
+    )
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    model_name = "mistralai/Voxtral-Mini-3B-2507"
+    tokenizer = MistralTokenizer.from_hf_hub(model_name)
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"audio": audio_count},
+        config_format="mistral",
+        load_format="mistral",
+        tokenizer_mode="mistral",
+        enforce_eager=True,
+        enable_chunked_prefill=False,
+    )
+
+    text_chunk = TextChunk(text=question)
+    audios = [
+        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
+        for i in range(audio_count)
+    ]
+    audio_chunks = [
+        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
+    ]
+
+    messages = [UserMessage(content=[*audio_chunks, text_chunk])]
+
+    req = ChatCompletionRequest(messages=messages, model=model_name)
+
+    tokens = tokenizer.encode_chat_completion(req)
+    prompt_ids, audios = tokens.tokens, tokens.audios
+
+    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
+
+    multi_modal_data = {"audio": audios_and_sr}
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt_token_ids=prompt_ids,
+        multi_modal_data=multi_modal_data,
+    )
+
+
 # Granite Speech
 def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
     # NOTE - the setting in this example are somehat different than what is
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
 
 
 model_example_map = {
+    "voxtral": run_voxtral,
     "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
@@ -311,16 +368,24 @@ def main(args):
         temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
     )
 
-    mm_data = {}
-    if audio_count > 0:
-        mm_data = {
-            "audio": [
-                asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
-            ]
-        }
+    mm_data = req_data.multi_modal_data
+    if not mm_data:
+        mm_data = {}
+        if audio_count > 0:
+            mm_data = {
+                "audio": [
+                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
+                ]
+            }
 
     assert args.num_prompts > 0
-    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
+    inputs = {"multi_modal_data": mm_data}
+
+    if req_data.prompt:
+        inputs["prompt"] = req_data.prompt
+    else:
+        inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
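
Note: to see how the new fields are consumed outside the example script, below is a minimal offline sketch of the Voxtral flow, with the request passed to vLLM as prompt_token_ids plus multi_modal_data instead of a text prompt. The LLM()/generate() calls, the local WAV path, and the sampling settings are illustrative assumptions, not part of this diff.

# Sketch only: offline Voxtral request built with mistral_common and handed to
# vLLM as pre-tokenized input. The audio file and generation settings are assumed.
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
                                                        TextChunk, UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

from vllm import LLM, SamplingParams

model_name = "mistralai/Voxtral-Mini-3B-2507"
tokenizer = MistralTokenizer.from_hf_hub(model_name)

# Hypothetical local audio file; anything mistral_common can decode works here.
audio = Audio.from_file("sample.wav", strict=False)
request = ChatCompletionRequest(
    messages=[
        UserMessage(content=[
            AudioChunk(input_audio=RawAudio.from_audio(audio)),
            TextChunk(text="What can you tell me about this audio?"),
        ])
    ],
    model=model_name,
)
tokens = tokenizer.encode_chat_completion(request)

llm = LLM(model=model_name, tokenizer_mode="mistral", config_format="mistral",
          load_format="mistral", max_model_len=8192, enforce_eager=True)
outputs = llm.generate(
    {
        "prompt_token_ids": tokens.tokens,
        "multi_modal_data": {
            "audio": [(a.audio_array, a.sampling_rate) for a in tokens.audios]
        },
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)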

requirements/common.txt

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
 msgspec
 gguf >= 0.13.0
 importlib_metadata; python_version < '3.10'
-mistral_common[opencv] >= 1.6.2
+mistral_common[opencv] >= 1.8.0
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

requirements/nightly_torch_test.txt

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ jiwer # required for audio tests
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.6.2 # required for pixtral test
+mistral_common[opencv] >= 1.8.0 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test

requirements/test.in

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ torchvision==0.22.0
 transformers_stream_generator # required for qwen-vl test
 mamba_ssm # required for plamo2 test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.7.0 # required for pixtral test
+mistral_common[opencv] >= 1.8.0 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test

requirements/test.txt

Lines changed: 7 additions & 1 deletion

@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
     #   typepy
 mdurl==0.1.2
     # via markdown-it-py
-mistral-common==1.7.0
+mistral-common==1.8.0
     # via -r requirements/test.in
 more-itertools==10.5.0
     # via lm-eval
@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
     # via google-auth
 pybind11==2.13.6
     # via lm-eval
+pycountry==24.6.1
+    # via pydantic-extra-types
 pycparser==2.22
     # via cffi
 pycryptodomex==3.22.0
@@ -528,9 +530,12 @@ pydantic==2.11.5
     #   datamodel-code-generator
     #   mistral-common
     #   mteb
+    #   pydantic-extra-types
     #   ray
 pydantic-core==2.33.2
     # via pydantic
+pydantic-extra-types==2.10.5
+    # via mistral-common
 pygments==2.18.0
     # via rich
 pyparsing==3.2.0
@@ -835,6 +840,7 @@ typing-extensions==4.12.2
     #   pqdm
     #   pydantic
     #   pydantic-core
+    #   pydantic-extra-types
     #   torch
     #   typer
     #   typing-inspection

setup.py

Lines changed: 2 additions & 1 deletion

@@ -692,7 +692,8 @@ def _read_requirements(filename: str) -> list[str]:
        "tensorizer": ["tensorizer==2.10.1"],
        "fastsafetensors": ["fastsafetensors >= 0.1.10"],
        "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
-        "audio": ["librosa", "soundfile"],  # Required for audio processing
+        "audio": ["librosa", "soundfile",
+                  "mistral_common[audio]"],  # Required for audio processing
        "video": []  # Kept for backwards compatibility
    },
    cmdclass=cmdclass,
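
Note: with mistral_common[audio] added to the audio extra, a quick way to confirm an environment has the optional audio dependencies is a guarded import. This is a sketch only; it assumes an installation along the lines of pip install "vllm[audio]".

# Sketch: guarded import of the optional audio dependencies declared in the
# "audio" extra. Assumes something like `pip install "vllm[audio]"` was run.
try:
    import librosa  # noqa: F401  audio decoding
    import soundfile  # noqa: F401  audio file IO
    from mistral_common.audio import Audio  # noqa: F401  Voxtral audio support
except ImportError as exc:
    raise SystemExit(f"Audio extras missing ({exc}); try: pip install 'vllm[audio]'")
print("audio extras available")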

tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 23 additions & 5 deletions

@@ -17,6 +17,11 @@
 
 from ...utils import RemoteOpenAIServer
 
+MISTRAL_FORMAT_ARGS = [
+    "--tokenizer_mode", "mistral", "--config_format", "mistral",
+    "--load_format", "mistral"
+]
+
 
 @pytest.fixture
 def mary_had_lamb():
@@ -33,9 +38,18 @@ def winning_call():
 
 
 @pytest.mark.asyncio
-async def test_basic_audio(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
+@pytest.mark.parametrize(
+    "model_name",
+    ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
+async def test_basic_audio(mary_had_lamb, model_name):
     server_args = ["--enforce-eager"]
+
+    if model_name.startswith("mistralai"):
+        server_args += MISTRAL_FORMAT_ARGS
+
+        # TODO(PATRICK) - REMOVE AFTER RELEASE
+        return  # skip for now
+
     # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
@@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb):
 
 
 @pytest.mark.asyncio
-async def test_long_audio_request(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
+@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
+async def test_long_audio_request(mary_had_lamb, model_name):
     server_args = ["--enforce-eager"]
 
+    if model_name.startswith("openai"):
+        return
+
     mary_had_lamb.seek(0)
     audio, sr = librosa.load(mary_had_lamb)
     # Add small silence after each audio for repeatability in the split process
@@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb):
             response_format="text",
             temperature=0.0)
         out = json.loads(transcription)['text']
-        assert out.count("Mary had a little lamb") == 10
+        counts = out.count("Mary had a little lamb")
+        assert counts == 10, counts
 
 
 @pytest.mark.asyncio
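
Note: the parametrized test drives the OpenAI-compatible transcription endpoint through RemoteOpenAIServer. For reference, a standalone client call against an already-running server would look roughly like the sketch below; the base URL, API key, audio filename, and the serve command in the comment are placeholders, not taken from this diff.

# Sketch: standalone transcription request against a running vLLM server.
# Assumes a server was started along the lines of:
#   vllm serve mistralai/Voxtral-Mini-3B-2507 --tokenizer_mode mistral \
#       --config_format mistral --load_format mistral --enforce-eager
import asyncio

from openai import AsyncOpenAI


async def transcribe() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    with open("mary_had_lamb.ogg", "rb") as audio_file:  # placeholder file
        transcription = await client.audio.transcriptions.create(
            model="mistralai/Voxtral-Mini-3B-2507",
            file=audio_file,
            language="en",
            response_format="text",
            temperature=0.0,
        )
    print(transcription)


asyncio.run(transcribe())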

tests/models/registry.py

Lines changed: 2 additions & 1 deletion

@@ -440,6 +440,7 @@ def check_available_online(
                        tokenizer="Isotr0py/Florence-2-tokenizer",  # noqa: E501
                        trust_remote_code=True),  # noqa: E501
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
+    "VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501
 
     # [Cross-encoder]
@@ -513,4 +514,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
         raise ValueError(f"No example model defined for {model_id}")
 
 
-HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
+HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)

vllm/entrypoints/openai/speech_to_text.py

Lines changed: 1 addition & 0 deletions

@@ -112,6 +112,7 @@ async def _preprocess_speech_to_text(
            prompt = self.model_cls.get_generation_prompt(
                audio=chunk,
                stt_config=self.asr_config,
+                model_config=self.model_config,
                language=lang,
                task_type=self.task_type,
                request_prompt=request.prompt)

vllm/model_executor/models/interfaces.py

Lines changed: 2 additions & 1 deletion

@@ -722,7 +722,8 @@ class SupportsTranscription(Protocol):
 
    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
-                              stt_config: SpeechToTextConfig, language: str,
+                              stt_config: SpeechToTextConfig,
+                              model_config: ModelConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        """Get the prompt for the ASR model.
