
Commit 5c188cf

varun-sundar-rabindranath (Varun Sundar Rabindranath) authored and committed
[misc] LoRA - Skip LoRA kernels when not required (vllm-project#15152)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent cae98da commit 5c188cf

File tree: 4 files changed, +113 additions, -33 deletions

  vllm/lora/ops/triton_ops/lora_expand.py
  vllm/lora/ops/triton_ops/lora_kernel_metadata.py
  vllm/lora/ops/triton_ops/lora_shrink.py
  vllm/worker/model_runner.py


vllm/lora/ops/triton_ops/lora_expand.py

Lines changed: 12 additions & 1 deletion
@@ -136,6 +136,7 @@ def _lora_expand(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
@@ -157,11 +158,19 @@ def _lora_expand(
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         offset_start (int, optional): Offset start for output_tensor.
             Defaults to 0.
         add_inputs (bool, optional): Whether to add the input tensor to the
             output tensor. Defaults to False.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     for weight in lora_b_weights:
         assert weight.dtype in [torch.float16, torch.bfloat16]
@@ -170,6 +179,8 @@ def _lora_expand(
     assert output_tensor.is_contiguous()

     # metadata sanity check.
+    M = inputs.size(1)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -181,7 +192,6 @@ def _lora_expand(
         inputs.device)

     K = lora_b_weights[0].shape[-1]  # K= rank
-    M = inputs.size(1)
     ADD_INPUTS = add_inputs
     MAX_LORAS = lora_ids.size(0)
     CAST_TYPE = False
@@ -263,6 +273,7 @@ def _lora_expand_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
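
To make the early exit concrete, here is a minimal, self-contained sketch of the pattern _lora_expand now follows. It is not the vLLM kernel: toy_lora_expand and its plain-matmul body are made up for illustration, and only the flag check mirrors the diff above.

import torch

def toy_lora_expand(inputs: torch.Tensor,
                    lora_b: torch.Tensor,
                    output: torch.Tensor,
                    no_lora_flag_cpu: torch.Tensor,
                    add_inputs: bool = False) -> None:
    # Same guard as in the diff: a 1-element CPU bool tensor that is True
    # when no token in the batch maps to a LoRA adapter.
    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA -> skip the kernel entirely.
        return
    # Stand-in for the real Triton expand kernel.
    expanded = inputs @ lora_b
    if add_inputs:
        output += expanded
    else:
        output.copy_(expanded)

# Usage: with the flag set, the output buffer is left untouched.
x = torch.randn(4, 8)
b = torch.randn(8, 16)
out = torch.zeros(4, 16)
no_lora = torch.tensor([True], dtype=torch.bool, device="cpu")
toy_lora_expand(x, b, out, no_lora)
assert torch.all(out == 0)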

vllm/lora/ops/triton_ops/lora_kernel_metadata.py

Lines changed: 36 additions & 6 deletions
@@ -17,6 +17,17 @@ class LoRAKernelMeta:
     num_tokens_per_lora: torch.Tensor
     lora_token_start_loc: torch.Tensor

+    # The V1 architecture uses the traced torch.compile graphs to execute
+    # a forward pass. Things to note about this process,
+    # 1. The tracing infers all python scalar datatype objects into a constant
+    #    value.
+    # 2. The tracing cannot handle dynamic control flow. (dynamic control flow
+    #    is an experimental feature in pytorch)
+    # 3. The internals of torch.ops functions are not traced.
+    # We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
+    # to early exit from inside the lora_expand / lora_shrink torch operation.
+    no_lora_flag_cpu: torch.Tensor
+
     @staticmethod
     def make(max_loras: int, max_num_tokens: int,
              device: Union[torch.device, str]) -> "LoRAKernelMeta":
@@ -47,17 +58,24 @@ def make(max_loras: int, max_num_tokens: int,
         lora_token_start_loc = torch.zeros(max_loras + 2,
                                            dtype=torch.int32,
                                            device=device)
+
+        no_lora_flag_cpu = torch.tensor([False],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
             active_lora_ids=active_lora_ids,
             num_tokens_per_lora=num_tokens_per_lora,
-            lora_token_start_loc=lora_token_start_loc)
+            lora_token_start_loc=lora_token_start_loc,
+            no_lora_flag_cpu=no_lora_flag_cpu)

     def _reset(self):
         self.active_lora_ids.fill_(-1)
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
+        self.no_lora_flag_cpu.fill_(False)

     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -70,6 +88,14 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:

         self._reset()

+        # Check and record no-lora case.
+        no_lora = torch.all(token_lora_mapping == -1)
+        self.no_lora_flag_cpu[0] = no_lora
+
+        if no_lora:
+            # Early exit. LoRA kernels will not be run.
+            return
+
         num_tokens = token_lora_mapping.size(0)

         # copy token lora mapping
@@ -100,7 +126,7 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
     def meta_args(
         self, token_nums: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor]:
+               torch.Tensor, torch.Tensor]:
         """
         This function returns the kernel metadata required for the current
         forward pass execution of the kernel. The function returns all the
@@ -111,7 +137,11 @@ def meta_args(
             token_nums (int): Number of input tokens in the current forward
                 pass.
         """
-        return (self.token_lora_mapping[:token_nums],
-                self.token_indices_sorted_by_lora_ids[:token_nums],
-                self.num_tokens_per_lora, self.lora_token_start_loc,
-                self.active_lora_ids)
+        return (
+            self.token_lora_mapping[:token_nums],
+            self.token_indices_sorted_by_lora_ids[:token_nums],
+            self.num_tokens_per_lora,
+            self.lora_token_start_loc,
+            self.active_lora_ids,
+            self.no_lora_flag_cpu,
+        )
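
A rough sketch of the metadata side, assuming (as in the diff) that a token_lora_mapping value of -1 means "this token has no LoRA". ToyLoRAMeta is a hypothetical stand-in for LoRAKernelMeta, trimmed down to the flag handling:

import torch

class ToyLoRAMeta:

    def __init__(self, max_num_tokens: int):
        self.token_lora_mapping = torch.empty(max_num_tokens,
                                              dtype=torch.int32)
        # A tensor rather than a Python bool so the value stays dynamic:
        # per the comment above, torch.compile folds Python scalars into
        # constants, while a check inside a torch.ops function is not traced.
        self.no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool)

    def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
        # Record whether any token actually maps to a LoRA adapter.
        no_lora = torch.all(token_lora_mapping == -1)
        self.no_lora_flag_cpu[0] = no_lora
        if no_lora:
            # Nothing else to prepare; the kernels will early-exit.
            return
        num_tokens = token_lora_mapping.size(0)
        self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping)

# Usage: a batch where every token is mapped to -1 (no LoRA).
meta = ToyLoRAMeta(max_num_tokens=8)
meta.prepare_tensors(torch.full((4,), -1, dtype=torch.int32))
assert bool(meta.no_lora_flag_cpu.item())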

vllm/lora/ops/triton_ops/lora_shrink.py

Lines changed: 12 additions & 1 deletion
@@ -106,6 +106,7 @@ def _lora_shrink(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     scaling: float,
 ) -> None:
     """
@@ -126,8 +127,16 @@ def _lora_shrink(
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         scaling (float): Scaling factor.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype == lora_a_weights[0].dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     for weight in lora_a_weights:
@@ -138,6 +147,8 @@ def _lora_shrink(
     assert output_tensor.is_contiguous()

     # metadata sanity check
+    M = inputs.size(0)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -146,7 +157,6 @@ def _lora_shrink(
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
      lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
     N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
-    M = inputs.size(0)
     NUM_SLICES = len(lora_a_weights)
     MAX_LORAS = lora_ids.size(0)

@@ -218,6 +228,7 @@ def _lora_shrink_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     scaling: float,
 ) -> None:
     return
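
The class comment in lora_kernel_metadata.py is what makes the .item() check legal under torch.compile: the internals of torch.ops functions are not traced, so data-dependent control flow can live inside them. Below is a hedged sketch of that mechanism, assuming PyTorch 2.4+ (torch.library.custom_op) and a hypothetical toy::lora_shrink op; it is not vLLM's actual registration code.

import torch

@torch.library.custom_op("toy::lora_shrink", mutates_args=("output",))
def toy_lora_shrink(inputs: torch.Tensor,
                    lora_a: torch.Tensor,
                    output: torch.Tensor,
                    no_lora_flag_cpu: torch.Tensor,
                    scaling: float) -> None:
    # This body is opaque to torch.compile, so the data-dependent branch
    # below never reaches the tracer.
    if no_lora_flag_cpu.item():
        return
    # Stand-in for the real Triton shrink kernel: project onto the LoRA rank.
    output.copy_(scaling * (inputs @ lora_a.t()))

@toy_lora_shrink.register_fake
def _(inputs, lora_a, output, no_lora_flag_cpu, scaling) -> None:
    # Metadata-only implementation used while tracing; nothing to return
    # because the op mutates `output` in place.
    return

# Usage: callable directly or via torch.ops.toy.lora_shrink.
x = torch.randn(4, 16)
a = torch.randn(8, 16)  # [rank, hidden]
out = torch.zeros(4, 8)
flag = torch.tensor([False], dtype=torch.bool)
toy_lora_shrink(x, a, out, flag, 1.0)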

vllm/worker/model_runner.py

Lines changed: 53 additions & 25 deletions
@@ -1243,6 +1243,29 @@ def profile_run(self) -> None:
         max_num_seqs = self.scheduler_config.max_num_seqs
         self._dummy_run(max_num_batched_tokens, max_num_seqs)

+    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
+        assert num_loras > 0
+        assert self.lora_manager is not None
+
+        dummy_lora_requests: list[LoRARequest] = []
+        with self.lora_manager.dummy_lora_cache():
+            for idx in range(num_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+        return dummy_lora_requests
+
+    def _remove_dummy_loras(self):
+        # Remove dummy loras.
+        assert self.lora_manager is not None
+        self.remove_all_loras()
+
     def _dummy_run(self,
                    max_num_batched_tokens: int,
                    max_num_seqs: int = 1) -> None:
@@ -1252,28 +1275,20 @@ def _dummy_run(self,
             SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)

         # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
+        # that will have unique loras, and therefore the max amount of
+        # memory consumption. Create dummy lora request copies from the
+        # lora request passed in, which contains a lora from the lora
+        # warmup path.
         dummy_lora_requests: List[LoRARequest] = []
         dummy_lora_requests_per_seq: List[LoRARequest] = []
         if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
+            dummy_lora_requests = self._add_dummy_loras(
+                self.lora_config.max_loras)
+            assert len(dummy_lora_requests) == self.lora_config.max_loras
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]

         # Profile memory usage with max_num_sequences sequences and the
         # total number of tokens equal to max_num_batched_tokens.
@@ -1355,9 +1370,8 @@ def _dummy_run(self,
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
+            self._remove_dummy_loras()
+
         return

     def remove_all_loras(self):
@@ -1480,6 +1494,16 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                                     dtype=self.model_config.dtype,
                                     device=self.device)

+        dummy_lora_id: Optional[int] = None
+        dummy_lora_request: Optional[LoRARequest] = None
+        if self.lora_config:
+            # The goal is to capture the LoRA kernels in cuda graphs.
+            # For this purpose, a single dummy lora is sufficient.
+            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
+            assert len(dummy_lora_requests) == 1
+            dummy_lora_request = dummy_lora_requests[0]
+            dummy_lora_id = dummy_lora_request.lora_int_id
+
         with self.attn_state.graph_capture(max_batch_size), graph_capture(
                 self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
@@ -1504,10 +1528,11 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     attn_metadata.enable_kv_scales_calculation = False
                     if self.lora_config:
                         lora_mapping = LoRAMapping(
-                            **dict(index_mapping=[0] * batch_size,
-                                   prompt_mapping=[0] * batch_size,
+                            **dict(index_mapping=[dummy_lora_id] * batch_size,
+                                   prompt_mapping=[dummy_lora_id] * batch_size,
                                    is_prefill=False))
-                        self.set_active_loras(set(), lora_mapping)
+                        self.set_active_loras(set([dummy_lora_request]),
+                                              lora_mapping)

                     if self.prompt_adapter_config:
                         prompt_adapter_mapping = PromptAdapterMapping(
@@ -1563,6 +1588,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     self.graph_runners[virtual_engine][batch_size] = (
                         graph_runner)

+        if self.lora_config:
+            self._remove_dummy_loras()
+
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time