
Commit 0d8b18c

paulyu12 and paulyu committed
[Platform][Worker][ModelRunner] Add LoRA & Multi-LoRA support (vllm-project#521)
### What this PR does / why we need it?

According to the RFC "[RFC]: Join the MultiLora and MultiLora Dynammic Serving feature develop" (vllm-project#396) and the "vLLM Ascend Roadmap Q2 2025" (vllm-project#448), this PR adds the relevant code to support (1) Multi-LoRA and (2) Multi-LoRA Dynamic Serving. LoRA reference: https://docs.vllm.ai/en/latest/features/lora.html

### Does this PR introduce _any_ user-facing change?

The following OpenAI-compatible HTTP APIs will be supported:

/v1/load_lora_adapter
/v1/unload_lora_adapter

### How was this patch tested?

git clone https://github.com/vllm-project/vllm.git
cd vllm/examples/offline_inference/ && python3 multilora_inference.py

---------

Signed-off-by: paulyu <paulyu0307@gmail.com>
Co-authored-by: paulyu <paulyu0307@gmail.com>
1 parent 15314cc commit 0d8b18c
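As a minimal sketch of the dynamic-serving endpoints listed above, assuming the lora_name/lora_path request body described in the vLLM LoRA documentation (the server URL, adapter name, and adapter path below are placeholders, not part of this PR):

    import requests

    # Register a LoRA adapter at runtime (name and path are illustrative).
    requests.post(
        "http://localhost:8000/v1/load_lora_adapter",
        json={"lora_name": "my_adapter", "lora_path": "/path/to/my_adapter"},
    )

    # Remove the adapter once it is no longer needed.
    requests.post(
        "http://localhost:8000/v1/unload_lora_adapter",
        json={"lora_name": "my_adapter"},
    )

For offline Multi-LoRA inference, as exercised by the multilora_inference.py script mentioned in the testing section, requests carry a LoRARequest per the LoRA reference linked above; a rough sketch with placeholder model and adapter paths:

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Base model and adapter path are placeholders for illustration.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=2)
    outputs = llm.generate(
        ["Write a SQL query that lists all users."],
        SamplingParams(temperature=0, max_tokens=64),
        lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter"),
    )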

File tree

4 files changed, +484 −14 lines changed
vllm_ascend/lora/punica_wrapper/punica_npu.py

Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@

# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Optional, Tuple, Union

import torch
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
                                     bgmv_shrink, sgmv_expand,
                                     sgmv_expand_slice, sgmv_shrink)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase


# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperNPU(PunicaWrapperBase):
    """
    PunicaWrapperNPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
        computation, which is suitable for the GEMM of lora_b.
        """
        expand_slice_fun: Callable = (self._expand_slice_prefill
                                      if self.is_prefill else
                                      self._expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weights
            output_slices (Tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Defaults to True.
        """
        # The embedding layer only needs the expand op
        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
        expand_fun(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0) + lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = tuple(
                torch.zeros(
                    (x.size(0), r), dtype=torch.float32, device=x.device)
                for _ in range(len(output_slices)))
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        # LogitsProcessorWithLoRA always uses bgmv.
        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)
        y = y.view_as(y_org)
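To make the shrink/expand split above concrete: per the docstring semantics of add_lora_linear, each LoRA slice is a rank-r down-projection followed by an up-projection. A minimal plain-torch sketch of that math, with illustrative shapes and values, independent of the sgmv/bgmv NPU ops used in the wrapper:

    import torch

    num_tokens, hidden, rank, out_size = 4, 64, 8, 64   # illustrative sizes
    x = torch.randn(num_tokens, hidden)                  # input activations
    lora_a = torch.randn(hidden, rank)                   # lora_a: hidden -> rank
    lora_b = torch.randn(rank, out_size)                 # lora_b: rank -> output
    scale = 0.5

    y = torch.zeros(num_tokens, out_size)                # base-layer output slice
    buffer = (x @ lora_a) * scale                        # "shrink" (sgmv/bgmv_shrink)
    y += buffer @ lora_b                                 # "expand" (sgmv/bgmv_expand)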

vllm_ascend/platform.py

Lines changed: 4 additions & 0 deletions
@@ -171,6 +171,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
         return "vllm_ascend.attention.attention.AscendAttentionBackend"

+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None