
Commit 65ac4e1

zou3519 authored and Chen-zexi committed
[Bugfix] Fix torch.compile x LoRA for PyTorch 2.8 (#20823)
Signed-off-by: rzou <zou3519@gmail.com>
1 parent 08b7dcf commit 65ac4e1

File tree

1 file changed: +8 -6 lines changed


vllm/lora/layers.py

Lines changed: 8 additions & 6 deletions
@@ -240,17 +240,19 @@ def set_lora(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
 
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        num_tokens = x.shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
+        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
+
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))
 
         full_output_org = full_output
         if full_output.ndim == 3:
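As background for the NB comment in the hunk above, here is a minimal standalone sketch (not vLLM code; lookup_table and forward_with_slice are hypothetical names) of the pattern the fix switches to: selecting the first x.shape[0] entries of a preallocated per-token buffer with a plain slice rather than torch.narrow, so that torch.compile can keep the token count dynamic instead of specializing on it.

import torch

# Stand-in for a preallocated per-token index buffer, analogous in shape to
# punica_wrapper._embeddings_indices ([2, max_num_tokens]).
lookup_table = torch.zeros(2, 8192, dtype=torch.long)

@torch.compile(dynamic=True)
def forward_with_slice(x: torch.Tensor) -> torch.Tensor:
    # Plain slicing keeps num_tokens symbolic under torch.compile, whereas
    # (per the commit message) torch.narrow can trigger a dynamic-shape
    # specialization at this point.
    num_tokens = x.shape[0]
    indices = lookup_table[1][:num_tokens]
    return x + indices

# Calls with different token counts are intended to reuse one compiled graph.
forward_with_slice(torch.zeros(16, dtype=torch.long))
forward_with_slice(torch.zeros(32, dtype=torch.long))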
