Commit 777cfa7

Author: paulyu · committed
[Platform][Worker][ModelRunner] Add LoRA & Multi-LoRA support
Signed-off-by: paulyu <paulyu0307@gmail.com>
1 parent 5fa70b6 commit 777cfa7

File tree

4 files changed: +488, -14 lines changed

vllm_ascend/lora/punica_wrapper/punica_npu.py (new file)

Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Optional, Tuple, Union

import torch

from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
                                     bgmv_shrink, sgmv_expand,
                                     sgmv_expand_slice, sgmv_shrink)

from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase


# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperNPU(PunicaWrapperBase):
    """
    PunicaWrapperNPU is designed to manage and provide metadata for the punica
    kernel. Its main function is to maintain the state information for
    Multi-LoRA and to provide the interface for the PyTorch punica ops.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
        computation, which is suitable for the GEMM of lora_b.
        """

        expand_slice_fun: Callable = (self._expand_slice_prefill
                                      if self.is_prefill else
                                      self._expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `_shrink_decode` function
        should be called.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """

        x = x.view(-1, x.shape[-1])
        # TODO: fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weights
            output_slices (Tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Defaults to True.
        """

        # The embedding layer only needs the expand op
        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
        expand_fun(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                ).squeeze(0) + lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = tuple(
                torch.zeros(
                    (x.size(0), r), dtype=torch.float32, device=x.device)
                for _ in range(len(output_slices)))
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        # LogitsProcessorWithLoRA always uses bgmv.
        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)
        y = y.view_as(y_org)
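To make the shrink/expand split above concrete, here is a minimal sketch (not part of the commit) of the computation that add_lora_linear performs for a single slice and a single adapter, written with plain PyTorch. The tensor shapes, the rank r, and the single-adapter simplification are illustrative assumptions; the wrapper itself dispatches to the sgmv/bgmv ops and handles per-token adapter indices.

# Illustrative only (assumed shapes, one LoRA adapter, one output slice); not part of the commit.
import torch

num_tokens, hidden, r, out_features = 4, 64, 8, 64
x = torch.randn(num_tokens, hidden)
lora_a = torch.randn(hidden, r)        # shrink weight (lora_a)
lora_b = torch.randn(r, out_features)  # expand weight (lora_b)
scale = 1.0

y = torch.zeros(num_tokens, out_features)  # base layer output, updated in place
buffer = (x @ lora_a) * scale              # shrink step (kept in float32 by the wrapper)
y += buffer @ lora_b                       # expand step, accumulated on top of the base output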

vllm_ascend/platform.py

Lines changed: 4 additions & 0 deletions
@@ -143,6 +143,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
         return "vllm_ascend.attention.attention.AscendAttentionBackend"
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
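The string returned by get_punica_wrapper is a dotted path to the wrapper class, which is resolved on the selected platform at runtime. As a hedged sketch of what such resolution could look like (importlib is used here purely for illustration; vLLM's actual loader may differ, and resolve_qualname is a hypothetical helper):

# Hypothetical helper for illustration only; not part of this commit.
import importlib

def resolve_qualname(qualname: str):
    """Resolve 'package.module.ClassName' to the class object."""
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# wrapper_cls = resolve_qualname(
#     "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")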

0 commit comments
