Skip to content

Commit a7eca8f

Browse files
committed
skip sample on remote prefill worker
1 parent a3c9862 commit a7eca8f

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

vllm/worker/model_runner.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
16311631
ModelInputForGPUWithSamplingMetadata)
16321632
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
16331633

1634+
_fake_sample_output: Optional[SamplerOutput] = None
1635+
16341636
def make_model_input_from_broadcasted_tensor_dict(
16351637
self,
16361638
tensor_dict: Dict[str, Any],
@@ -1822,11 +1824,16 @@ def execute_model(
18221824
if model_input.async_callback is not None:
18231825
model_input.async_callback()
18241826

1825-
# Sample the next token.
1826-
output: SamplerOutput = self.model.sample(
1827-
logits=logits,
1828-
sampling_metadata=model_input.sampling_metadata,
1829-
)
1827+
# in the producer side of pd disagg scenario, the next tokens are
1828+
# not needed. So we skip it
1829+
if self.need_skip_sampling() and self._fake_sample_output is not None:
1830+
output = self._fake_sample_output
1831+
else:
1832+
# Sample the next token.
1833+
output: SamplerOutput = self.model.sample(
1834+
logits=logits,
1835+
sampling_metadata=model_input.sampling_metadata,
1836+
)
18301837
if (self.observability_config is not None
18311838
and self.observability_config.collect_model_forward_time
18321839
and output is not None):
@@ -1859,6 +1866,13 @@ def execute_model(
18591866

18601867
output.hidden_states = hidden_states
18611868

1869+
# save a fake output
1870+
if (self._fake_sample_output is None
1871+
and output is not None
1872+
and self.need_skip_sampling()):
1873+
1874+
self._fake_sample_output = output
1875+
18621876
return [output]
18631877

18641878
def need_recv_kv(self, model_input, kv_caches) -> bool:
@@ -1889,6 +1903,16 @@ def need_recv_kv(self, model_input, kv_caches) -> bool:
18891903
return self.vllm_config.kv_transfer_config.is_kv_consumer and (
18901904
not is_profile_run) and is_prefill_run
18911905

1906+
def need_skip_sampling(self) -> bool:
1907+
"""
1908+
check whether skip the step of sampling.
1909+
"""
1910+
1911+
if self.vllm_config.kv_transfer_config is None:
1912+
return False
1913+
1914+
return self.vllm_config.kv_transfer_config.get_from_extra_config("skip_sampling", False)
1915+
18921916
def need_send_kv(self, model_input, kv_caches) -> bool:
18931917
"""Check if we need to send kv-cache to the other worker.
18941918
We need to send KV when

0 commit comments

Comments
 (0)