Skip to content

Commit 3879d9c

Browse files
authored
[CI] Fix sample backward compatibility problem (#648)
vllm-project/vllm@b411418 — this vllm commit changes the sample usage. This PR adapts the change for main and makes sure it works for 0.8.4 as well. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent d785e78 commit 3879d9c

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

vllm_ascend/worker/draft_model_runner.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ModelRunnerWrapperBase)
2929

3030
from vllm_ascend.attention.attention import AscendMetadata
31+
from vllm_ascend.utils import vllm_version_is
3132

3233
# A flag to enable debug prints for the updated input tensors
3334
# before each step.
@@ -286,10 +287,17 @@ def execute_model(
286287
if not self.is_driver_worker:
287288
return []
288289
# Sample the next token.
289-
output = self.model.sample(
290-
logits=logits,
291-
sampling_metadata=model_input.sampling_metadata,
292-
)
290+
if vllm_version_is("0.8.4"):
291+
output = self.model.sample(
292+
logits=logits,
293+
sampling_metadata=model_input.sampling_metadata,
294+
)
295+
else:
296+
assert self.sampler is not None
297+
output = self.sampler(
298+
logits=logits,
299+
sampling_metadata=model_input.sampling_metadata,
300+
)
293301
outputs.append(output)
294302

295303
if model_input.attn_metadata.num_prefills == 0 \

vllm_ascend/worker/model_runner.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,12 @@ def __init__(
937937
SamplingMetadataCache() \
938938
if self.parallel_config.pipeline_parallel_size == 1 else None
939939

940+
if vllm_version_is("0.8.4"):
941+
self.sampler = None
942+
else:
943+
from vllm.model_executor.layers.sampler import get_sampler
944+
self.sampler = get_sampler()
945+
940946
def get_model(self) -> nn.Module:
941947
return self.model
942948

@@ -1404,10 +1410,17 @@ def execute_model(
14041410
model_input.async_callback()
14051411

14061412
# Sample the next token.
1407-
output: SamplerOutput = self.model.sample(
1408-
logits=logits,
1409-
sampling_metadata=model_input.sampling_metadata,
1410-
)
1413+
if vllm_version_is("0.8.4"):
1414+
output = self.model.sample(
1415+
logits=logits,
1416+
sampling_metadata=model_input.sampling_metadata,
1417+
)
1418+
else:
1419+
assert self.sampler is not None
1420+
output = self.sampler(
1421+
logits=logits,
1422+
sampling_metadata=model_input.sampling_metadata,
1423+
)
14111424
if (self.observability_config is not None
14121425
and self.observability_config.collect_model_forward_time
14131426
and output is not None):

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from vllm_ascend.attention.attention import AttentionMaskBuilder
5454
from vllm_ascend.attention.attention_v1 import AscendAttentionState
5555
from vllm_ascend.platform import NPUPlatform
56+
from vllm_ascend.utils import vllm_version_is
5657

5758
if TYPE_CHECKING:
5859
import xgrammar as xgr # type: ignore[import-untyped]
@@ -290,6 +291,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
290291
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
291292
self.attn_mask_len, self.dtype)
292293

294+
if vllm_version_is("0.8.4"):
295+
self.sampler = None
296+
else:
297+
from vllm.v1.sample.sampler import Sampler
298+
self.sampler = Sampler()
299+
293300
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
294301
"""Update the cached states and the persistent batch with the scheduler
295302
output.
@@ -645,10 +652,17 @@ def execute_model(
645652

646653
# Sample the next token and get logprobs if needed.
647654
sampling_metadata = self.input_batch.sampling_metadata
648-
sampler_output = self.model.sample(
649-
logits=logits,
650-
sampling_metadata=sampling_metadata,
651-
)
655+
if vllm_version_is("0.8.4"):
656+
sampler_output = self.model.sample(
657+
logits=logits,
658+
sampling_metadata=sampling_metadata,
659+
)
660+
else:
661+
assert self.sampler is not None
662+
sampler_output = self.sampler(
663+
logits=logits,
664+
sampling_metadata=sampling_metadata,
665+
)
652666

653667
# TODO(woosuk): The following loop can be slow since it iterates over
654668
# the requests one by one. Optimize.

0 commit comments

Comments (0)