Commit a88cf11

optimization to skip computing extra metadata if all requests are on decode
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent dae0441 commit a88cf11

2 files changed: 13 additions, 1 deletion

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ def test_kv_sharing_skip_prefill(
     test_prompts: list[list[dict[str, Any]]],
 ):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=42)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     prompts = [prompt[0]['content'] for prompt in test_prompts]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
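
Reading the test change: raising max_tokens from 42 to 100 presumably keeps requests generating long enough after prefill that the batch becomes decode-only, exercising the new early-exit path below. A minimal sketch of the sampling setup, using vLLM's public SamplingParams:

```python
from vllm import SamplingParams

# Greedy decoding with a longer generation budget, so that after the
# initial prefill every request in the batch is in the decode phase.
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
```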

vllm/v1/worker/gpu_model_runner.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,18 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
582582
"""
583583
if not self.cache_config.kv_sharing_skip_prefill:
584584
return None
585+
586+
num_decode_reqs = 0
587+
for req_index in range(self.input_batch.num_reqs):
588+
if self.input_batch.num_computed_tokens_cpu[
589+
req_index] >= self.input_batch.num_prompt_tokens[
590+
req_index]:
591+
num_decode_reqs += 1
592+
593+
if self.input_batch.num_reqs == num_decode_reqs:
594+
# All requests are on decode, skip calculate decode only indices
595+
return None
596+
585597
num_decodes = logits_indices.shape[0]
586598
# TODO(sarckk): With chunked prefills, logits_indices contains
587599
# indices for partial requests though we do not sample any token
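
For illustration only: the per-request loop above amounts to a single element-wise comparison. A minimal sketch, assuming num_computed_tokens_cpu and num_prompt_tokens are NumPy arrays with at least num_reqs valid entries (the helper name is hypothetical, not part of the commit):

```python
import numpy as np

def all_requests_on_decode(num_computed_tokens_cpu: np.ndarray,
                           num_prompt_tokens: np.ndarray,
                           num_reqs: int) -> bool:
    # A request is in the decode phase once its whole prompt has been
    # computed, i.e. num_computed_tokens >= num_prompt_tokens.
    return bool(np.all(num_computed_tokens_cpu[:num_reqs]
                       >= num_prompt_tokens[:num_reqs]))
```

When this holds for every request, _calc_decode_indices returns None immediately and skips building the decode-only index metadata.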
