
Commit 41060c6

[Core] Add Support for Default Modality Specific LoRAs [generate / chat completions] (#19126)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 3de2ed7 commit 41060c6

File tree: 9 files changed, +482 / -5 lines changed

docs/features/lora.md

Lines changed: 77 additions & 0 deletions

@@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo

## Default LoRA Models For Multimodal Models

Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be tedious to manage with the approaches above, since it requires the user to send a `LoRARequest` (offline) or to route requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.

To this end, we allow registration of default multimodal LoRAs to handle this automatically: users can map each modality to a LoRA adapter, and it is applied automatically whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if a request contains several modalities, each of which is registered to its own LoRA, none of them will be applied.

Example usage for offline inference:

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-3.3-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)


model = LLM(
    model=model_id,
    enable_lora=True,
    max_lora_rank=64,
    max_model_len=2048,
    limit_mm_per_prompt={"audio": 1},
    # Will always pass a `LoRARequest` with the `model_id`
    # whenever audio is contained in the request data.
    default_mm_loras={"audio": model_id},
    enforce_eager=True,
)

question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    }
}

outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
```
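
To inspect the result, you could then print the generated text from the returned request outputs (a small follow-up sketch, not part of the original example):

```python
# Follow-up sketch only: print the text generated for each request.
for output in outputs:
    print(output.outputs[0].text)
```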

You can also pass a JSON dictionary to `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:

```bash
vllm serve ibm-granite/granite-speech-3.3-2b \
    --max-model-len 2048 \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64
```

Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
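
For reference, a minimal client-side sketch of a chat completion request against the server started above; the base URL, API key, and audio URL are illustrative placeholders rather than part of this commit:

```python
# Sketch only: query a vLLM server that has a default audio LoRA registered.
# No LoRA needs to be named in the request; it is applied automatically
# because the message contains audio.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            # Hypothetical audio URL used purely for illustration.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
    max_completion_tokens=64,
    temperature=0.2,
)
print(chat_completion.choices[0].message.content)
```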
Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer

# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.

# Contains a modality specific lora alongside the base model
MULTIMODAL_MODEL_NAME = snapshot_download(
    "microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")

ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module):  # noqa: F811

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--max-model-len",
        "12800",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"speech={AUDIO_LORA_PATH}",
        "--max-lora-rank",
        "320",
        "--max-num-seqs",
        "2",
        "--trust-remote-code",
        "--gpu-memory-utilization",
        "0.8",
        "--default-mm-loras",
        f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
    ]

    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
    async with multimodal_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # base model with default lora should give the same response as lora model
    "model_name",
    [MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
    model_name: str,
    multi_modal_client: openai.AsyncOpenAI,
    audio_assets: AudioTestAssets,
):
    messages = [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": "Can you transcribe this audio?",
        }, {
            "type": "audio_url",
            "audio_url": {
                "url": audio_assets[0].url
            },
        }]
    }]

    chat_completion = await multi_modal_client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=128,
        temperature=0.0)

    assert len(chat_completion.choices) > 0

    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
    assert message.content == ACTIVE_MM_LORA_RESPONSE
```

tests/lora/test_default_mm_loras.py

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for applying default registered multimodal loras.
"""

import os

from huggingface_hub import snapshot_download

from vllm.lora.request import LoRARequest

from ..conftest import AudioTestAssets, VllmRunner

MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")

AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>"  # noqa: E501

# Responses are greedy decoded; we just check the end of
# the generated text. If the lora is inactive, this model
# generates commentary on the transcription.
RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501

VLLM_RUNNER_BASE_KWARGS = {
    "model_name": MODEL_PATH,
    "dtype": "half",
    "enable_lora": "True",
    "max_num_seqs": 2,
    "max_lora_rank": 320,
    "max_model_len": 12800,
    "gpu_memory_utilization": 0.8,
    "limit_mm_per_prompt": {
        "audio": 1
    },
    "enforce_eager": True,
}


def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
             **kwargs):
    inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]

    # Apply any additional kwargs as overrides to the base kwargs
    vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}

    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_with_default_lora = [
            vllm_model.generate_greedy(
                prompts,
                max_tokens=128,
                audios=audios,
                lora_request=lora_request,
            ) for prompts, audios in inputs
        ]

        assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
            expected_suffix)


def test_active_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that we can use the default audio lora."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_inactive_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that modalities are filtered properly."""
    # Default image lora won't be active since we only pass audio
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"image": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
    )


def test_default_mm_lora_succeeds_with_redundant_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that redundantly providing the lora works."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_default_mm_lora_fails_with_overridden_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that if the lora_request conflicts with default_mm_loras,
    we use the lora_request."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
        default_mm_loras={"audio": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )
```

vllm/config.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -33,6 +33,7 @@
 from vllm import version
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2989,6 +2990,16 @@ class LoRAConfig:
     trained with those scaling factors to be used at the same time. If not
     specified, only adapters trained with the base model scaling factor are
     allowed."""
+    default_mm_loras: Optional[dict[str, str]] = None
+    """Dictionary mapping specific modalities to LoRA model paths; this field
+    is only applicable to multimodal models and should be leveraged when a
+    model always expects a LoRA to be active when a given modality is present.
+    Note that currently, if a request provides multiple additional
+    modalities, each of which have their own LoRA, we do NOT apply
+    default_mm_loras because we currently only support one lora adapter
+    per prompt. When run in offline mode, the lora IDs for n modalities
+    will be automatically assigned to 1-n with the names of the modalities
+    in alphabetic order."""
     bias_enabled: bool = False
     """Enable bias for LoRA adapters."""
```
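
As a rough illustration of the alphabetical ID assignment described in the docstring above (a sketch with hypothetical adapter paths, not code from this commit), registering two modalities offline could resolve to per-modality `LoRARequest`s roughly like this:

```python
# Sketch only: lora_int_id values 1..n are assigned to the modality names in
# alphabetical order; the adapter paths below are hypothetical.
from vllm.lora.request import LoRARequest

default_mm_loras = {
    "image": "/path/to/vision-lora",
    "audio": "/path/to/speech-lora",
}

default_mm_lora_requests = {
    modality: LoRARequest(lora_name=modality, lora_int_id=i, lora_path=path)
    for i, (modality, path) in enumerate(sorted(default_mm_loras.items()),
                                         start=1)
}
# -> "audio" maps to lora_int_id=1, "image" maps to lora_int_id=2
```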

vllm/engine/arg_utils.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -395,6 +395,8 @@ class EngineArgs:
     enable_lora_bias: bool = LoRAConfig.bias_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
+    default_mm_loras: Optional[Dict[str, str]] = \
+        LoRAConfig.default_mm_loras
     fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -807,6 +809,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
                                 **lora_kwargs["fully_sharded_loras"])
+        lora_group.add_argument("--default-mm-loras",
+                                **lora_kwargs["default_mm_loras"])

         # PromptAdapter related configs
         prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
@@ -1284,10 +1288,16 @@ def create_engine_config(
             disable_hybrid_kv_cache_manager,
         )

+        if not model_config.is_multimodal_model and self.default_mm_loras:
+            raise ValueError(
+                "Default modality-specific LoRA(s) were provided for a "
+                "non multimodal model")
+
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
+            default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
             long_lora_scaling_factors=self.long_lora_scaling_factors,
```
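
For a concrete feel of the new validation in `create_engine_config`, a hedged sketch of rejecting `default_mm_loras` for a text-only model (the model name and adapter path are placeholders, and building the config requires the model's HF config to be available):

```python
# Sketch only: default_mm_loras on a non-multimodal model should raise.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-125m",  # text-only model used purely for illustration
    enable_lora=True,
    default_mm_loras={"audio": "/path/to/speech-lora"},  # hypothetical path
)

try:
    args.create_engine_config()
except ValueError as err:
    # Expected: "Default modality-specific LoRA(s) were provided for a
    # non multimodal model"
    print(err)
```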
