Commit 96e95c7

fix token selection indexing

Signed-off-by: Leo Tian <leo.tian@centml.ai>

1 parent 65a7e0d commit 96e95c7

File tree

1 file changed: +3 −2 lines


vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -1438,8 +1438,9 @@ def execute_model(

             # Fill with -1 first (or PLACEHOLDER_ID)
             # tokens selected for every row (valid or not)
-            selected_tokens = valid_sampled_token_ids_gpu[:batch,
-                                                          last_valid_indices]
+            selected_tokens = torch.gather(
+                valid_sampled_token_ids_gpu, 1,
+                last_valid_indices.unsqueeze(1)).squeeze(1)

             next_token_ids_gpu = torch.where(
                 last_valid_indices != -1, selected_tokens,
```
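The bug the hunk fixes can be reproduced on small tensors: advanced indexing with a row slice and a 1-D index tensor broadcasts the index over every row, producing a full cross product instead of one token per row, while `torch.gather` along dim 1 picks `last_valid_indices[i]` from row `i`. A minimal sketch with hypothetical shapes and values (the tensor contents below are illustrative, not from the commit):

```python
import torch

# Hypothetical 2-row batch of sampled token ids, one row per request.
valid_sampled_token_ids = torch.tensor([[10, 11, 12],
                                        [20, 21, 22]])
# Per-row position of the last valid token.
last_valid_indices = torch.tensor([2, 0])

# Buggy form: the index tensor is applied to every row of the slice,
# yielding a (2, 2) matrix rather than one token per row.
wrong = valid_sampled_token_ids[:2, last_valid_indices]
assert wrong.shape == (2, 2)  # [[12, 10], [22, 20]]

# Fixed form: gather selects element last_valid_indices[i] from row i.
right = torch.gather(valid_sampled_token_ids, 1,
                     last_valid_indices.unsqueeze(1)).squeeze(1)
assert right.tolist() == [12, 20]  # one token per row
```

`gather` needs its index tensor to match the input's rank, hence the `unsqueeze(1)` to shape `(batch, 1)` and the trailing `squeeze(1)` back to a flat vector.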
