
Commit 6c663df

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)
[misc] LoRA - Skip LoRA kernels when not required (#15152)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent 33437bc commit 6c663df

4 files changed: 113 additions, 33 deletions

vllm/lora/ops/triton_ops/lora_expand.py

Lines changed: 12 additions & 1 deletion
@@ -136,6 +136,7 @@ def _lora_expand(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
@@ -157,11 +158,19 @@ def _lora_expand(
             identifies the the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         offset_start (int, optional): Offset start for output_tensor.
             Defaults to 0.
         add_inputs (bool, optional): Whether to add the input tensor to the
             output tensor. Defaults to False.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     for weight in lora_b_weights:
         assert weight.dtype in [torch.float16, torch.bfloat16]
@@ -170,6 +179,8 @@ def _lora_expand(
     assert output_tensor.is_contiguous()
 
     # metadata sanity check.
+    M = inputs.size(1)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -181,7 +192,6 @@
                                                 inputs.device)
 
     K = lora_b_weights[0].shape[-1]  # K= rank
-    M = inputs.size(1)
     ADD_INPUTS = add_inputs
     MAX_LORAS = lora_ids.size(0)
     CAST_TYPE = False
@@ -263,6 +273,7 @@ def _lora_expand_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
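The guard added above works under torch.compile because the body of a registered torch op is opaque to the tracer: the .item() check on the one-element CPU flag runs on the host on every call instead of being baked in as a trace-time constant, which is what would happen to a plain Python bool argument. Below is a minimal standalone sketch of that pattern; it is not vLLM code, the demo::maybe_add_lora op name and shapes are invented, and it assumes a PyTorch version that provides torch.library.custom_op.

import torch


# The op body is opaque to torch.compile, so the host-side flag check still
# runs on every invocation of a compiled graph.
@torch.library.custom_op("demo::maybe_add_lora", mutates_args=("output", ))
def maybe_add_lora(output: torch.Tensor, delta: torch.Tensor,
                   no_lora_flag_cpu: torch.Tensor) -> None:
    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA: skip the work entirely.
        return
    output.add_(delta)


# Shape-only stand-in used during tracing, mirroring _lora_expand_fake above.
@maybe_add_lora.register_fake
def _(output: torch.Tensor, delta: torch.Tensor,
      no_lora_flag_cpu: torch.Tensor) -> None:
    return None


out = torch.zeros(4)
delta = torch.ones(4)
fn = torch.compile(lambda o, d, flag: torch.ops.demo.maybe_add_lora(o, d, flag))

fn(out, delta, torch.tensor([True]))   # no-LoRA batch: op returns early, out stays 0
fn(out, delta, torch.tensor([False]))  # LoRA batch: the add actually runs
print(out)                             # tensor([1., 1., 1., 1.])

Routing the flag through a CPU tensor rather than a Python scalar is what keeps the early exit per-call instead of per-trace.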

vllm/lora/ops/triton_ops/lora_kernel_metadata.py

Lines changed: 36 additions & 6 deletions
@@ -17,6 +17,17 @@ class LoRAKernelMeta:
     num_tokens_per_lora: torch.Tensor
     lora_token_start_loc: torch.Tensor
 
+    # The V1 architecture uses the traced torch.compile graphs to execute
+    # a forward pass. Things to note about this process,
+    # 1. The tracing infers all python scalar datatype objects into a constant
+    # value.
+    # 2. The tracing cannot handle dynamic control flow. (dynamic control flow
+    # is an experimental feature in pytorch)
+    # 3. The internals of torch.ops functions are not traced.
+    # We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
+    # to early exit from inside the lora_expand / lora_shrink torch operation.
+    no_lora_flag_cpu: torch.Tensor
+
     @staticmethod
     def make(max_loras: int, max_num_tokens: int,
              device: Union[torch.device, str]) -> "LoRAKernelMeta":
@@ -47,17 +58,24 @@ def make(max_loras: int, max_num_tokens: int,
         lora_token_start_loc = torch.zeros(max_loras + 2,
                                            dtype=torch.int32,
                                            device=device)
+
+        no_lora_flag_cpu = torch.tensor([False],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
             active_lora_ids=active_lora_ids,
             num_tokens_per_lora=num_tokens_per_lora,
-            lora_token_start_loc=lora_token_start_loc)
+            lora_token_start_loc=lora_token_start_loc,
+            no_lora_flag_cpu=no_lora_flag_cpu)
 
     def _reset(self):
         self.active_lora_ids.fill_(-1)
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
+        self.no_lora_flag_cpu.fill_(False)
 
     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -70,6 +88,14 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
 
         self._reset()
 
+        # Check and record no-lora case.
+        no_lora = torch.all(token_lora_mapping == -1)
+        self.no_lora_flag_cpu[0] = no_lora
+
+        if no_lora:
+            # Early exit. LoRA kernels will not be run.
+            return
+
         num_tokens = token_lora_mapping.size(0)
 
         # copy token lora mapping
@@ -100,7 +126,7 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
     def meta_args(
         self, token_nums: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor]:
+               torch.Tensor, torch.Tensor]:
         """
         This function returns the kernel metadata required for the current
         forward pass execution of the kernel. The function returns all the
@@ -111,7 +137,11 @@ def meta_args(
             token_nums (int): Number of input tokens in the current forward
                 pass.
         """
-        return (self.token_lora_mapping[:token_nums],
-                self.token_indices_sorted_by_lora_ids[:token_nums],
-                self.num_tokens_per_lora, self.lora_token_start_loc,
-                self.active_lora_ids)
+        return (
+            self.token_lora_mapping[:token_nums],
+            self.token_indices_sorted_by_lora_ids[:token_nums],
+            self.num_tokens_per_lora,
+            self.lora_token_start_loc,
+            self.active_lora_ids,
+            self.no_lora_flag_cpu,
+        )
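A hedged usage sketch of the metadata flow above: prepare_tensors records the no-LoRA case and skips the rest of the GPU metadata preparation, and meta_args now returns the flag as a sixth tensor for lora_shrink / lora_expand to check before doing any work. The batch size, int32 dtype, and CUDA device below are illustrative assumptions, not values taken from this diff.

import torch

from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta

meta = LoRAKernelMeta.make(max_loras=8, max_num_tokens=4096, device="cuda")

# A batch in which no token maps to a LoRA adapter (-1 everywhere).
token_lora_mapping = torch.full((256, ), -1, dtype=torch.int32, device="cuda")
meta.prepare_tensors(token_lora_mapping)  # records the no-LoRA case and returns early

# meta_args() now yields six tensors; the trailing no_lora_flag_cpu lets the
# lora_shrink / lora_expand ops return before touching any GPU metadata.
*_, no_lora_flag_cpu = meta.meta_args(token_nums=256)
assert no_lora_flag_cpu.item()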

vllm/lora/ops/triton_ops/lora_shrink.py

Lines changed: 12 additions & 1 deletion
@@ -106,6 +106,7 @@ def _lora_shrink(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     scaling: float,
 ) -> None:
     """
@@ -126,8 +127,16 @@
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         scaling (float): Scaling factor.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype == lora_a_weights[0].dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     for weight in lora_a_weights:
@@ -138,6 +147,8 @@
     assert output_tensor.is_contiguous()
 
     # metadata sanity check
+    M = inputs.size(0)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -146,7 +157,6 @@
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
      lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
     N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
-    M = inputs.size(0)
     NUM_SLICES = len(lora_a_weights)
     MAX_LORAS = lora_ids.size(0)
 
@@ -218,6 +228,7 @@ def _lora_shrink_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     scaling: float,
 ) -> None:
     return

vllm/worker/model_runner.py

Lines changed: 53 additions & 25 deletions
@@ -1242,6 +1242,29 @@ def profile_run(self) -> None:
         max_num_seqs = self.scheduler_config.max_num_seqs
         self._dummy_run(max_num_batched_tokens, max_num_seqs)
 
+    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
+        assert num_loras > 0
+        assert self.lora_manager is not None
+
+        dummy_lora_requests: list[LoRARequest] = []
+        with self.lora_manager.dummy_lora_cache():
+            for idx in range(num_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+        return dummy_lora_requests
+
+    def _remove_dummy_loras(self):
+        # Remove dummy loras.
+        assert self.lora_manager is not None
+        self.remove_all_loras()
+
     def _dummy_run(self,
                    max_num_batched_tokens: int,
                    max_num_seqs: int = 1) -> None:
@@ -1251,28 +1274,20 @@ def _dummy_run(self,
             SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
 
         # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
+        # that will have unique loras, and therefore the max amount of
+        # memory consumption. Create dummy lora request copies from the
+        # lora request passed in, which contains a lora from the lora
+        # warmup path.
         dummy_lora_requests: List[LoRARequest] = []
         dummy_lora_requests_per_seq: List[LoRARequest] = []
         if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
+            dummy_lora_requests = self._add_dummy_loras(
+                self.lora_config.max_loras)
+            assert len(dummy_lora_requests) == self.lora_config.max_loras
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
 
         # Profile memory usage with max_num_sequences sequences and the
         # total number of tokens equal to max_num_batched_tokens.
@@ -1354,9 +1369,8 @@ def _dummy_run(self,
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
+            self._remove_dummy_loras()
+
         return
 
     def remove_all_loras(self):
@@ -1479,6 +1493,16 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                                            dtype=self.model_config.dtype,
                                            device=self.device)
 
+        dummy_lora_id: Optional[int] = None
+        dummy_lora_request: LoRARequest = []
+        if self.lora_config:
+            # The goal is to capture the LoRA kernels in cuda graphs.
+            # for this purpose, as single dummy lora is sufficient.
+            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
+            assert len(dummy_lora_requests) == 1
+            dummy_lora_request = dummy_lora_requests[0]
+            dummy_lora_id = dummy_lora_request.lora_int_id
+
         with self.attn_state.graph_capture(max_batch_size), graph_capture(
                 self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
@@ -1503,10 +1527,11 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     attn_metadata.enable_kv_scales_calculation = False
                     if self.lora_config:
                         lora_mapping = LoRAMapping(
-                            **dict(index_mapping=[0] * batch_size,
-                                   prompt_mapping=[0] * batch_size,
+                            **dict(index_mapping=[dummy_lora_id] * batch_size,
+                                   prompt_mapping=[dummy_lora_id] * batch_size,
                                    is_prefill=False))
-                        self.set_active_loras(set(), lora_mapping)
+                        self.set_active_loras(set([dummy_lora_request]),
+                                              lora_mapping)
 
                     if self.prompt_adapter_config:
                         prompt_adapter_mapping = PromptAdapterMapping(
@@ -1562,6 +1587,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     self.graph_runners[virtual_engine][batch_size] = (
                         graph_runner)
 
+        if self.lora_config:
+            self._remove_dummy_loras()
+
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
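The reason capture_model now activates a real dummy LoRA (rather than mapping everything to id 0 with no active adapters) is the basic CUDA Graphs property that a replay only re-launches the kernels recorded during capture; if the LoRA kernels never run at capture time, they can never appear in the replayed graph. A minimal standalone sketch of that property follows; it is not vLLM code and assumes a CUDA device is available.

import torch

static_x = torch.zeros(8, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_x += 1  # recorded into the graph rather than executed immediately

g.replay()
print(static_x[0].item())  # 1.0: only the captured add is launched

g.replay()
print(static_x[0].item())  # 2.0: each replay re-launches the same captured kernels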
