From 859e2caf2fec37ffb153ba42828840dc06d51157 Mon Sep 17 00:00:00 2001
From: skylee-01 <497627264@qq.com>
Date: Wed, 9 Jul 2025 09:41:27 +0800
Subject: [PATCH 1/6] [Model] vllm v1 support mlp_speculator

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/engine/arg_utils.py                     |  3 +
 vllm/model_executor/models/mlp_speculator.py | 12 +++-
 vllm/v1/spec_decode/mlp_speculator.py        | 63 ++++++++++++++++++++
 vllm/v1/worker/gpu_model_runner.py           | 36 +++++++++++
 4 files changed, 111 insertions(+), 3 deletions(-)
 create mode 100644 vllm/v1/spec_decode/mlp_speculator.py

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0c4fae1dde5..9ae401ad3f9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1460,6 +1460,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         is_ngram_enabled = False
         is_eagle_enabled = False
         is_medusa_enabled = False
+        is_mlp_speculator_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
@@ -1470,6 +1471,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                     is_medusa_enabled = True
                 elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                     is_eagle_enabled = True
+                elif speculative_method == "mlp_speculator":
+                    is_mlp_speculator_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
                 if speculative_model in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index c6a97388dc1..732d43ecb0c 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'speculative_config'):
+            config = vllm_config.speculative_config.draft_model_config.hf_config
+            self.sampling_metadata_is_required = False
+        else:
+            config = vllm_config.model_config.hf_config
+            self.sampling_metadata_is_required = True
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
@@ -182,8 +187,9 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             states = states.flatten(0, 1)
 
-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+            logits = self.logits_processor(
+                self.head[head_index], states, sampling_metadata
+                if self.sampling_metadata_is_required else None)
 
             output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py
new file mode 100644
index 00000000000..2bc03f62f1e
--- /dev/null
+++ b/vllm/v1/spec_decode/mlp_speculator.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.v1.sample.metadata import SamplingMetadata
+
+# Initialize logger
+logger = init_logger(__name__)
+
+
+class MLPSpeculatorProposer:
+    """
+    MLPSpeculator proposer class for generating token sequences
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        # Save config parameters
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        self.hidden_size = vllm_config.speculative_config.\
+            draft_model_config.get_hidden_size(
+            )
+        self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        self.dtype = vllm_config.model_config.dtype
+
+    def propose(
+        self,
+        input_ids: torch.Tensor,
+        previous_hidden_states: torch.Tensor,
+        num_predict_tokens: int,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        # Generate blocks and compute logits
+        draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata)
+        draft_tokens = list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
+        return draft_tokens
+
+    def load_model(self, target_model: nn.Module) -> None:
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
+
+    @torch.inference_mode()
+    def dummy_run(self, num_tokens: int) -> None:
+        input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device)
+        hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+        num_predict_tokens = self.num_speculative_tokens
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            self.model.generate_proposals(input_ids, hidden_states, num_predict_tokens, None)
\ No newline at end of file
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8658d7d916f..459e18e2c13 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -62,6 +62,7 @@
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.mlp_speculator import MLPSpeculatorProposer
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
@@ -187,6 +188,9 @@ def __init__(
                 self.drafter = MedusaProposer(
                     vllm_config=self.vllm_config,
                     device=self.device)  # type: ignore
+            elif self.speculative_config.method == "mlp_speculator":
+                self.drafter = MLPSpeculatorProposer(self.vllm_config,
+                                                     self.device)
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
@@ -1601,6 +1605,38 @@ def propose_draft_token_ids(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
+        elif self.speculative_config.method == "mlp_speculator":
+            assert isinstance(self.drafter, MLPSpeculatorProposer)
+
+            is_sample_match = sample_hidden_states.shape[0] == len(
+                sampled_token_ids)
+            # Get last token from each sequence
+            draft_input_ids = torch.tensor(
+                sampled_token_ids[0] if is_sample_match else
+                [tokens[-1] for tokens in sampled_token_ids],
+                device=sample_hidden_states.device)
+
+            if is_sample_match:
+                # Calculate indices for hidden states
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+                indices = torch.tensor(indices, device=self.device)
+                hidden_states = sample_hidden_states[indices]
+            else:
+                hidden_states = sample_hidden_states
+
+            spec_token_ids = self.drafter.propose(
+                input_ids=draft_input_ids,
+                previous_hidden_states=hidden_states,
+                num_predict_tokens=self.vllm_config.speculative_config.
+                num_speculative_tokens,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.

From 2a57e20237e5e51003adc4d90ae5aa47bce21b45 Mon Sep 17 00:00:00 2001
From: skylee-01 <497627264@qq.com>
Date: Wed, 9 Jul 2025 10:45:34 +0800
Subject: [PATCH 2/6] Optimize code format

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/v1/worker/gpu_model_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 459e18e2c13..4d860744572 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -176,6 +176,7 @@ def __init__(
         # NOTE(Jiayi): currently we put the entire draft model on
         # the last PP rank. This is not ideal if there are many
         # layers in the draft model.
+        self.drafter: Union[NgramProposer, EagleProposer, MedusaProposer, MLPSpeculatorProposer, None] = None
         if self.speculative_config and get_pp_group().is_last_rank:
             if self.speculative_config.method == "ngram":
                 self.drafter = NgramProposer(self.vllm_config)

From 4711051f0e422e1f8da6757544479ed274a390ae Mon Sep 17 00:00:00 2001
From: skylee-01 <497627264@qq.com>
Date: Wed, 9 Jul 2025 12:43:14 +0800
Subject: [PATCH 3/6] Optimize code format

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/v1/worker/gpu_model_runner.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4d860744572..459e18e2c13 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -176,7 +176,6 @@ def __init__(
         # NOTE(Jiayi): currently we put the entire draft model on
         # the last PP rank. This is not ideal if there are many
         # layers in the draft model.
-        self.drafter: Union[NgramProposer, EagleProposer, MedusaProposer, MLPSpeculatorProposer, None] = None
         if self.speculative_config and get_pp_group().is_last_rank:
             if self.speculative_config.method == "ngram":
                 self.drafter = NgramProposer(self.vllm_config)

From 26f2b7f938c28c5719aec35e0113973113bbbd36 Mon Sep 17 00:00:00 2001
From: skylee-01 <497627264@qq.com>
Date: Wed, 9 Jul 2025 16:03:56 +0800
Subject: [PATCH 4/6] fix bug

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/engine/arg_utils.py              | 2 +-
 vllm/v1/spec_decode/mlp_speculator.py | 6 +++---
 vllm/v1/worker/gpu_model_runner.py    | 6 +-----
 3 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9ae401ad3f9..8e170727dd1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1477,7 +1477,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 speculative_model = self.speculative_config.get("model")
                 if speculative_model in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
-        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
+        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled or is_mlp_speculator_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
                                recommend_to_remove=False)
diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py
index 2bc03f62f1e..e6e7d78fd08 100644
--- a/vllm/v1/spec_decode/mlp_speculator.py
+++ b/vllm/v1/spec_decode/mlp_speculator.py
@@ -31,7 +31,8 @@ def __init__(
         self.hidden_size = vllm_config.speculative_config.\
             draft_model_config.get_hidden_size(
             )
-        self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        self.num_speculative_tokens = vllm_config.speculative_config.\
+            num_speculative_tokens
         self.dtype = vllm_config.model_config.dtype
 
     def propose(
@@ -43,8 +44,7 @@ def propose(
     ) -> torch.Tensor:
         # Generate blocks and compute logits
         draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata)
-        draft_tokens = list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
-        return draft_tokens
+        return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
 
     def load_model(self, target_model: nn.Module) -> None:
         self.model = get_model(vllm_config=self.vllm_config,
                                model_config=self.vllm_config.
                                speculative_config.draft_model_config)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 459e18e2c13..f19b5e537da 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1607,16 +1607,13 @@ def propose_draft_token_ids(
             )
         elif self.speculative_config.method == "mlp_speculator":
             assert isinstance(self.drafter, MLPSpeculatorProposer)
-
             is_sample_match = sample_hidden_states.shape[0] == len(
                 sampled_token_ids)
             # Get last token from each sequence
             draft_input_ids = torch.tensor(
-                sampled_token_ids[0] if is_sample_match else
                 [tokens[-1] for tokens in sampled_token_ids],
                 device=sample_hidden_states.device)
-
-            if is_sample_match:
+            if not is_sample_match:
                 # Calculate indices for hidden states
                 indices = []
                 offset = 0
@@ -1629,7 +1626,6 @@ def propose_draft_token_ids(
                 hidden_states = sample_hidden_states[indices]
             else:
                 hidden_states = sample_hidden_states
-
             spec_token_ids = self.drafter.propose(
                 input_ids=draft_input_ids,
                 previous_hidden_states=hidden_states,

From e7ad830c03227fc6e45bc670cae1d2842b37fb7e Mon Sep 17 00:00:00 2001
From: skylee-01 <497627264@qq.com>
Date: Wed, 9 Jul 2025 16:10:21 +0800
Subject: [PATCH 5/6] Optimize code format

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/v1/spec_decode/mlp_speculator.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py
index e6e7d78fd08..f1ac1607574 100644
--- a/vllm/v1/spec_decode/mlp_speculator.py
+++ b/vllm/v1/spec_decode/mlp_speculator.py
@@ -28,11 +28,10 @@ def __init__(
         self.vllm_config = vllm_config
         self.device = device
         self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-        self.hidden_size = vllm_config.speculative_config.\
-            draft_model_config.get_hidden_size(
-            )
-        self.num_speculative_tokens = vllm_config.speculative_config.\
-            num_speculative_tokens
+        self.hidden_size = (vllm_config.speculative_config.
+                            draft_model_config.get_hidden_size())
+        self.num_speculative_tokens = (vllm_config.speculative_config.
+                                       num_speculative_tokens)
         self.dtype = vllm_config.model_config.dtype
 
     def propose(
@@ -41,7 +40,7 @@ def propose(
         previous_hidden_states: torch.Tensor,
         num_predict_tokens: int,
         sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
+    ) -> list[list[int]]:
         # Generate blocks and compute logits
         draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata)
         return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))

From de28b865b7c1c59744757d3d01d313a9b9c07ae5 Mon Sep 17 00:00:00 2001
From: skylee-01 <497627264@qq.com>
Date: Wed, 9 Jul 2025 16:18:13 +0800
Subject: [PATCH 6/6] Optimize code format

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/model_executor/models/mlp_speculator.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index 732d43ecb0c..8ef2ade57b6 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -187,9 +187,11 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             states = states.flatten(0, 1)
 
-            logits = self.logits_processor(
-                self.head[head_index], states, sampling_metadata
-                if self.sampling_metadata_is_required else None)
+            if self.logits_processor:
+                logits = self.logits_processor(self.head[head_index], states,
+                                               sampling_metadata)
+            else:
+                logits = self.head[head_index](states)
 
             output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
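
Usage sketch (not part of the diffs above): with these six patches applied, the new V1 path could be exercised roughly as below. The checkpoint names are placeholders for any target model with a matching MLPSpeculator draft, and the exact LLM() keyword usage and the num_speculative_tokens value are assumptions for illustration, not taken from the patches; the explicit "method" key mirrors the check added to _is_v1_supported_oracle().

    from vllm import LLM, SamplingParams

    # Hypothetical target/draft pair; substitute a real MLPSpeculator checkpoint.
    llm = LLM(
        model="ibm-granite/granite-7b-instruct",
        speculative_config={
            "method": "mlp_speculator",
            "model": "ibm-granite/granite-7b-instruct-accelerator",
            "num_speculative_tokens": 3,
        },
    )

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)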