
Commit cc32a06

optimization to skip computing extra metadata if all requests are on decode
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent fd764de commit cc32a06

File tree

2 files changed: +13 -1 lines changed

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 1 addition & 1 deletion

@@ -278,7 +278,7 @@ def test_kv_sharing_skip_prefill(
     test_prompts: list[list[dict[str, Any]]],
 ):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=42)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     prompts = [prompt[0]['content'] for prompt in test_prompts]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE

vllm/v1/worker/gpu_model_runner.py

Lines changed: 12 additions & 0 deletions

@@ -585,6 +585,18 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor):
         """
         if not self.cache_config.kv_sharing_skip_prefill:
             return None
+
+        num_decode_reqs = 0
+        for req_index in range(self.input_batch.num_reqs):
+            if self.input_batch.num_computed_tokens_cpu[
+                    req_index] >= self.input_batch.num_prompt_tokens[
+                        req_index]:
+                num_decode_reqs += 1
+
+        if self.input_batch.num_reqs == num_decode_reqs:
+            # All requests are on decode; skip calculating decode-only indices
+            return None
+
         num_decodes = logits_indices.shape[0]
         # TODO(sarckk): With chunked prefills, logits_indices contains
         # indices for partial requests though we do not sample any token
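
The early exit added above can be illustrated in isolation. Below is a minimal, hypothetical sketch of the same check: FakeInputBatch is a stand-in class (not vLLM's actual InputBatch) exposing only the num_reqs, num_computed_tokens_cpu, and num_prompt_tokens fields the diff reads, and all_requests_on_decode mirrors the counting loop.

import numpy as np


class FakeInputBatch:
    """Hypothetical stand-in for vLLM's InputBatch, exposing only the
    three fields the new check reads."""

    def __init__(self, num_computed_tokens, num_prompt_tokens):
        self.num_reqs = len(num_prompt_tokens)
        self.num_computed_tokens_cpu = np.asarray(num_computed_tokens)
        self.num_prompt_tokens = np.asarray(num_prompt_tokens)


def all_requests_on_decode(batch) -> bool:
    # A request is on decode once its computed tokens cover the full
    # prompt; from then on it generates one token per step.
    num_decode_reqs = 0
    for req_index in range(batch.num_reqs):
        if (batch.num_computed_tokens_cpu[req_index]
                >= batch.num_prompt_tokens[req_index]):
            num_decode_reqs += 1
    return num_decode_reqs == batch.num_reqs


# All prompts fully computed -> the decode-only index metadata can be skipped.
assert all_requests_on_decode(
    FakeInputBatch(num_computed_tokens=[10, 7], num_prompt_tokens=[8, 7]))

# One request still mid-prefill (5 < 8) -> the skip does not apply.
assert not all_requests_on_decode(
    FakeInputBatch(num_computed_tokens=[5, 7], num_prompt_tokens=[8, 7]))

The counting loop runs over CPU-side arrays, so in the all-decode case the method presumably returns before doing any of the heavier decode-index work that follows.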
