Skip to content

Commit 770e5dc

Browse files
authored
[full_graph] Fix query_start_loc padding (#19321)
Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai>
1 parent c57c941 commit 770e5dc

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,10 @@ def _prepare_inputs(
655655

656656
# Fill unused with -1. Needed for reshape_and_cache
657657
self.seq_lens[num_reqs:].fill_(0)
658-
self.query_start_loc[num_reqs + 1:].fill_(-1)
658+
# Note: pad query_start_loc to be non-decreasing, as kernels
659+
# like FlashAttention requires that
660+
self.query_start_loc[num_reqs + 1:].fill_(
661+
self.query_start_loc_cpu[num_reqs].item())
659662

660663
query_start_loc = self.query_start_loc[:num_reqs + 1]
661664
seq_lens = self.seq_lens[:num_reqs]

0 commit comments

Comments
 (0)