Commit ac2bf41

[Model] Remove model sampler (#21059)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent a931b4c commit ac2bf41
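
Each of the six model files below carried the same per-model sampling hook: a self.sampler = get_sampler() attribute created in __init__ and a sample() method that only forwarded logits to it. This commit deletes that hook and the now-unused imports, presumably because sampling happens in the engine rather than inside the model. A minimal sketch of the removed pattern follows; ExampleForCausalLM is a hypothetical stand-in for the six model classes, and the imports are the ones being deleted (so the sketch only runs against a vLLM tree from before this change):

# Sketch of the per-model sampling hook that this commit removes.
# ExampleForCausalLM is a placeholder name, not a class from the diff.
from typing import Optional

import torch

from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata


class ExampleForCausalLM:

    def __init__(self) -> None:
        # Each model used to construct its own sampler instance ...
        self.sampler = get_sampler()

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # ... and expose a sample() wrapper that simply delegated to it.
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens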

6 files changed: +0 −45 lines changed


vllm/model_executor/models/bailing_moe.py

Lines changed: 0 additions & 10 deletions
@@ -47,7 +47,6 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -485,7 +484,6 @@ def __init__(
         else:
             self.lm_head = PPMissingLayer()
 
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -512,14 +510,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(

vllm/model_executor/models/granite_speech.py

Lines changed: 0 additions & 2 deletions
@@ -36,7 +36,6 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -549,7 +548,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str):
         self.config = config
         self.quant_config = quant_config
         self.cache_config = cache_config
-        self.sampler = get_sampler()
 
         # The language model is typically a Granite LLM
         self.language_model = init_vllm_registered_model(

vllm/model_executor/models/hunyuan_v1_moe.py

Lines changed: 0 additions & 10 deletions
@@ -49,7 +49,6 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -661,7 +660,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
                                                     logit_scale)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
 
@@ -685,14 +683,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def make_empty_intermediate_tensors(
             self, batch_size: int, dtype: torch.dtype,
             device: torch.device) -> IntermediateTensors:

vllm/model_executor/models/mimo.py

Lines changed: 0 additions & 2 deletions
@@ -36,7 +36,6 @@
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
@@ -176,7 +175,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

vllm/model_executor/models/mimo_mtp.py

Lines changed: 0 additions & 11 deletions
@@ -30,7 +30,6 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -161,8 +160,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size)
 
-        self.sampler = get_sampler()
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -187,14 +184,6 @@ def compute_logits(
         return self.model.compute_logits(hidden_states, self.lm_head,
                                          sampling_metadata, spec_step_idx)
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [

vllm/model_executor/models/phi4flash.py

Lines changed: 0 additions & 10 deletions
@@ -23,7 +23,6 @@
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
@@ -641,7 +640,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size,
                                                 logits_as_input=False)
-        self.sampler = get_sampler()
 
     def forward(
         self,
@@ -709,14 +707,6 @@ def compute_logits(
             prune_hidden_states=prune_hidden_states)
         return processed_logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(
         self,
         weights: Iterable[tuple[str, torch.Tensor]],
