@@ -1631,6 +1631,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
1631
1631
ModelInputForGPUWithSamplingMetadata )
1632
1632
_builder_cls : Type [ModelInputForGPUBuilder ] = ModelInputForGPUBuilder
1633
1633
1634
+ _fake_sample_output : Optional [SamplerOutput ] = None
1635
+
1634
1636
def make_model_input_from_broadcasted_tensor_dict (
1635
1637
self ,
1636
1638
tensor_dict : Dict [str , Any ],
@@ -1822,11 +1824,16 @@ def execute_model(
1822
1824
if model_input .async_callback is not None :
1823
1825
model_input .async_callback ()
1824
1826
1825
- # Sample the next token.
1826
- output : SamplerOutput = self .model .sample (
1827
- logits = logits ,
1828
- sampling_metadata = model_input .sampling_metadata ,
1829
- )
1827
+ # in the producer side of pd disagg scenario, the next tokens are
1828
+ # not needed. So we skip it
1829
+ if self .need_skip_sampling () and self ._fake_sample_output is not None :
1830
+ output = self ._fake_sample_output
1831
+ else :
1832
+ # Sample the next token.
1833
+ output : SamplerOutput = self .model .sample (
1834
+ logits = logits ,
1835
+ sampling_metadata = model_input .sampling_metadata ,
1836
+ )
1830
1837
if (self .observability_config is not None
1831
1838
and self .observability_config .collect_model_forward_time
1832
1839
and output is not None ):
@@ -1859,6 +1866,13 @@ def execute_model(
1859
1866
1860
1867
output .hidden_states = hidden_states
1861
1868
1869
+ # save a fake output
1870
+ if (self ._fake_sample_output is None
1871
+ and output is not None
1872
+ and self .need_skip_sampling ()):
1873
+
1874
+ self ._fake_sample_output = output
1875
+
1862
1876
return [output ]
1863
1877
1864
1878
def need_recv_kv (self , model_input , kv_caches ) -> bool :
@@ -1889,6 +1903,16 @@ def need_recv_kv(self, model_input, kv_caches) -> bool:
1889
1903
return self .vllm_config .kv_transfer_config .is_kv_consumer and (
1890
1904
not is_profile_run ) and is_prefill_run
1891
1905
1906
+ def need_skip_sampling (self ) -> bool :
1907
+ """
1908
+ check whether skip the step of sampling.
1909
+ """
1910
+
1911
+ if self .vllm_config .kv_transfer_config is None :
1912
+ return False
1913
+
1914
+ return self .vllm_config .kv_transfer_config .get_from_extra_config ("skip_sampling" , False )
1915
+
1892
1916
def need_send_kv (self , model_input , kv_caches ) -> bool :
1893
1917
"""Check if we need to send kv-cache to the other worker.
1894
1918
We need to send KV when
0 commit comments