Commit 26ab61d

optimization to skip computing extra metadata if all requests are on decode
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent: aa670a0

File tree

2 files changed: +13 -1 lines changed

- tests/v1/e2e/test_kv_sharing_skip_prefill.py
- vllm/v1/worker/gpu_model_runner.py

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 1 addition & 1 deletion

@@ -278,7 +278,7 @@ def test_kv_sharing_skip_prefill(
     test_prompts: list[list[dict[str, Any]]],
 ):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=42)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     prompts = [prompt[0]['content'] for prompt in test_prompts]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
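
For context, the updated parameters run greedy decoding with a 100-token budget, which gives each request a longer decode-only phase. A minimal sketch of how such sampling parameters are used with vLLM's offline API; the model name and prompt below are illustrative and not taken from this test:

```python
# Sketch: greedy decoding with a 100-token budget, as in the updated test.
# The model name and prompt are illustrative, not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```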

vllm/v1/worker/gpu_model_runner.py

Lines changed: 12 additions & 0 deletions

@@ -602,6 +602,18 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
         """
         if not self.cache_config.kv_sharing_skip_prefill:
             return None
+
+        num_decode_reqs = 0
+        for req_index in range(self.input_batch.num_reqs):
+            if self.input_batch.num_computed_tokens_cpu[
+                    req_index] >= self.input_batch.num_prompt_tokens[
+                        req_index]:
+                num_decode_reqs += 1
+
+        if self.input_batch.num_reqs == num_decode_reqs:
+            # All requests are on decode; skip calculating decode-only indices
+            return None
+
         num_decodes = logits_indices.shape[0]
         # TODO(sarckk): With chunked prefills, logits_indices contains
         # indices for partial requests though we do not sample any token
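
The added block short-circuits `_calc_decode_indices` when the batch contains no prefill work: a request counts as decoding once its computed-token count has reached its prompt length. A standalone sketch of that check follows; `InputBatchStub` and `all_requests_on_decode` are hypothetical stand-ins for illustration, while `num_reqs`, `num_computed_tokens_cpu`, and `num_prompt_tokens` are the field names used by the diff above:

```python
# Illustrative sketch of the early-exit check added in this commit.
# InputBatchStub is a hypothetical stand-in for vLLM's input batch; only the
# fields referenced by the new code are modeled here.
import numpy as np


class InputBatchStub:
    def __init__(self, num_computed_tokens: list[int],
                 num_prompt_tokens: list[int]):
        self.num_reqs = len(num_prompt_tokens)
        self.num_computed_tokens_cpu = np.array(num_computed_tokens)
        self.num_prompt_tokens = np.array(num_prompt_tokens)


def all_requests_on_decode(batch: InputBatchStub) -> bool:
    # A request is in its decode phase once it has computed at least as many
    # tokens as its prompt contains; otherwise it is still prefilling.
    num_decode_reqs = 0
    for req_index in range(batch.num_reqs):
        if (batch.num_computed_tokens_cpu[req_index]
                >= batch.num_prompt_tokens[req_index]):
            num_decode_reqs += 1
    return num_decode_reqs == batch.num_reqs


# Request 0 has computed 3 of its 5 prompt tokens (still prefilling),
# so the decode-only fast path must not trigger.
mixed = InputBatchStub(num_computed_tokens=[3, 8], num_prompt_tokens=[5, 8])
assert not all_requests_on_decode(mixed)

# Both requests have consumed their full prompts: a decode-only batch,
# so the extra metadata computation can be skipped.
decode_only = InputBatchStub(num_computed_tokens=[5, 9],
                             num_prompt_tokens=[5, 8])
assert all_requests_on_decode(decode_only)
```

The same predicate could be written as a single vectorized comparison, e.g. `bool(np.all(batch.num_computed_tokens_cpu[:batch.num_reqs] >= batch.num_prompt_tokens[:batch.num_reqs]))`; the commit keeps a plain Python loop, matching the style of the surrounding runner code.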
