Commit 4e2daf5

[Doc] Add qwen2-audio eager mode tutorial (#1371)
### What this PR does / why we need it?

Add qwen2-audio eager mode tutorial.

Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 1025344 commit 4e2daf5

File tree

docs/source/tutorials/index.md
docs/source/tutorials/single_npu_audio.md
examples/offline_inference_audio_language.py

3 files changed: +151 -67 lines changed

docs/source/tutorials/index.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 :maxdepth: 1
 single_npu
 single_npu_multimodal
+single_npu_audio
 multi_npu
 multi_npu_quantization
 single_node_300i

docs/source/tutorials/single_npu_audio.md

Lines changed: 125 additions & 0 deletions

# Single NPU (Qwen2-Audio 7B)

## Run vllm-ascend on Single NPU

### Offline Inference on Single NPU

Run the docker container:

```{code-block} bash
:substitutions:
# Update the vllm-ascend image
export IMAGE=quay.io/ascend/vllm-ascend:|vllm_ascend_version|
docker run --rm \
--name vllm-ascend \
--device /dev/davinci0 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
-v /etc/ascend_install.info:/etc/ascend_install.info \
-v /root/.cache:/root/.cache \
-p 8000:8000 \
-it $IMAGE bash
```

Set up the environment variables:

```bash
# Use vllm v1 engine
export VLLM_USE_V1=1

# Load model from ModelScope to speed up download
export VLLM_USE_MODELSCOPE=True

# Set `max_split_size_mb` to reduce memory fragmentation and avoid out of memory
export PYTORCH_NPU_ALLOC_CONF=max_split_size_mb:256
```

:::{note}
`max_split_size_mb` prevents the native allocator from splitting blocks larger than this size (in MB). This can reduce fragmentation and may allow some borderline workloads to complete without running out of memory. You can find more details [<u>here</u>](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/800alpha003/apiref/envref/envref_07_0061.html).
:::
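
If you prefer to configure these values from Python instead of the shell, a minimal sketch is shown below. It assumes (this is an assumption of the sketch, not something the tutorial states) that the variables are set before `vllm` and `torch_npu` are imported, since they are read during library initialization:

```python
# Sketch: set the same environment variables programmatically.
# Assumption: this runs before vllm / torch_npu are imported, so the values
# are visible when the libraries initialize.
import os

os.environ.setdefault("VLLM_USE_V1", "1")
os.environ.setdefault("VLLM_USE_MODELSCOPE", "True")
os.environ.setdefault("PYTORCH_NPU_ALLOC_CONF", "max_split_size_mb:256")
```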

Install packages required for audio processing:

```bash
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install librosa soundfile
```
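
As an optional sanity check, you can decode one of vLLM's bundled demo clips before running the full inference script. This is only a sketch and assumes the container has network access to fetch the asset:

```python
# Optional sanity check (sketch): decoding a bundled demo clip exercises the
# audio stack (librosa / soundfile) that the inference script relies on.
# Note: AudioAsset fetches the clip over the network on first use.
from vllm.assets.audio import AudioAsset

audio, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
print(f"Decoded {len(audio) / sample_rate:.1f}s of audio at {sample_rate} Hz")
```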

Run the following script to execute offline inference on a single NPU:

```python
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = {
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?"
}


def prepare_inputs(audio_count: int):
    audio_in_prompt = "".join([
        f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for idx in range(audio_count)
    ])
    question = question_per_audio_count[audio_count]
    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")

    mm_data = {
        "audio":
        [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
    }

    # Merge text prompt and audio data into inputs
    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
    return inputs


def main(audio_count: int):
    # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
    # lower-end GPUs.
    # Unless specified, these settings have been tested to work on a single L4.
    # `limit_mm_per_prompt`: the max num items for each modality per prompt.
    llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct",
              max_model_len=4096,
              max_num_seqs=5,
              limit_mm_per_prompt={"audio": audio_count},
              enforce_eager=True)

    inputs = prepare_inputs(audio_count)

    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=None)

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    audio_count = 2
    main(audio_count)
```

If the script runs successfully, you should see output like the following:

```bash
The sport referenced is baseball, and the nursery rhyme is 'Mary Had a Little Lamb'.
```
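
The upstream example this tutorial adapts also demonstrated batch inference by repeating the same request. A minimal sketch of the same idea is shown below; it is a hypothetical variant, not part of the committed tutorial, and it would replace the `inputs = prepare_inputs(audio_count)` and `outputs = llm.generate(...)` lines inside `main()` above:

```python
    # Hypothetical batch variant (inside main() above, after llm and
    # sampling_params are constructed): llm.generate() accepts a list of
    # inputs and batches them, up to max_num_seqs at a time. temperature=0.2
    # keeps sampled outputs from being identical even though every prompt in
    # the batch is the same.
    num_prompts = 4  # hypothetical batch size
    batch_inputs = [prepare_inputs(audio_count)] * num_prompts
    outputs = llm.generate(batch_inputs, sampling_params=sampling_params)
```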

### Online Serving on Single NPU

Currently, vLLM's OpenAI-compatible server does not support audio inputs; you can find more details [<u>here</u>](https://github.com/vllm-project/vllm/issues/19977).

examples/offline_inference_audio_language.py

Lines changed: 25 additions & 67 deletions

@@ -26,74 +26,51 @@
 
 from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
-from vllm.utils import FlexibleArgumentParser
 
 audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
 question_per_audio_count = {
-    0: "What is 1+1?",
     1: "What is recited in the audio?",
     2: "What sport and what nursery rhyme are referenced?"
 }
 
-# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
-# lower-end GPUs.
-# Unless specified, these settings have been tested to work on a single L4.
-
-
-# Qwen2-Audio
-def run_qwen2_audio(question: str, audio_count: int):
-    model_name = "Qwen/Qwen2-Audio-7B-Instruct"
-
-    llm = LLM(model=model_name,
-              max_model_len=4096,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
 
+def prepare_inputs(audio_count: int):
     audio_in_prompt = "".join([
-        f"Audio {idx+1}: "
-        f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
+        f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+        for idx in range(audio_count)
     ])
-
+    question = question_per_audio_count[audio_count]
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
               "<|im_start|>user\n"
               f"{audio_in_prompt}{question}<|im_end|>\n"
               "<|im_start|>assistant\n")
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
 
+    mm_data = {
+        "audio":
+        [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
+    }
 
-model_example_map = {"qwen2_audio": run_qwen2_audio}
+    # Merge text prompt and audio data into inputs
+    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
+    return inputs
 
 
-def main(args):
-    model = args.model_type
-    if model not in model_example_map:
-        raise ValueError(f"Model type {model} is not supported.")
+def main(audio_count: int):
+    # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+    # lower-end GPUs.
+    # Unless specified, these settings have been tested to work on a single L4.
+    # `limit_mm_per_prompt`: the max num items for each modality per prompt.
+    llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct",
+              max_model_len=4096,
+              max_num_seqs=5,
+              limit_mm_per_prompt={"audio": audio_count},
+              enforce_eager=True)
 
-    audio_count = args.num_audios
-    llm, prompt, stop_token_ids = model_example_map[model](
-        question_per_audio_count[audio_count], audio_count)
+    inputs = prepare_inputs(audio_count)
 
-    # We set temperature to 0.2 so that outputs can be different
-    # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
                                      max_tokens=64,
-                                     stop_token_ids=stop_token_ids)
-
-    mm_data = {}
-    if audio_count > 0:
-        mm_data = {
-            "audio": [
-                asset.audio_and_sample_rate
-                for asset in audio_assets[:audio_count]
-            ]
-        }
-
-    assert args.num_prompts > 0
-    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
-    if args.num_prompts > 1:
-        # Batch inference
-        inputs = [inputs] * args.num_prompts  # type: ignore
+                                     stop_token_ids=None)
 
     outputs = llm.generate(inputs, sampling_params=sampling_params)
 
@@ -103,24 +80,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description='Demo on using vLLM for offline inference with '
-        'audio language models')
-    parser.add_argument('--model-type',
-                        '-m',
-                        type=str,
-                        default="qwen2_audio",
-                        choices=model_example_map.keys(),
-                        help='Huggingface "model_type".')
-    parser.add_argument('--num-prompts',
-                        type=int,
-                        default=1,
-                        help='Number of prompts to run.')
-    parser.add_argument("--num-audios",
-                        type=int,
-                        default=1,
-                        choices=[0, 1, 2],
-                        help="Number of audio items per prompt.")
-
-    args = parser.parse_args()
-    main(args)
+    audio_count = 2
+    main(audio_count)
