diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0c4fae1dde5..8e170727dd1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1460,6 +1460,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         is_ngram_enabled = False
         is_eagle_enabled = False
         is_medusa_enabled = False
+        is_mlp_speculator_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
@@ -1470,11 +1471,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                     is_medusa_enabled = True
                 elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                     is_eagle_enabled = True
+                elif speculative_method == "mlp_speculator":
+                    is_mlp_speculator_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
                 if speculative_model in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
-        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
+        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled or is_mlp_speculator_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
                                recommend_to_remove=False)
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index c6a97388dc1..8ef2ade57b6 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'speculative_config'):
+            config = vllm_config.speculative_config.draft_model_config.hf_config
+            self.sampling_metadata_is_required = False
+        else:
+            config = vllm_config.model_config.hf_config
+            self.sampling_metadata_is_required = True
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
@@ -182,8 +187,11 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             states = states.flatten(0, 1)
 
-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+            if self.logits_processor:
+                logits = self.logits_processor(self.head[head_index], states,
+                                               sampling_metadata)
+            else:
+                logits = self.head[head_index](states)
 
             output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py
new file mode 100644
index 00000000000..f1ac1607574
--- /dev/null
+++ b/vllm/v1/spec_decode/mlp_speculator.py
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.v1.sample.metadata import SamplingMetadata
+
+# Initialize logger
+logger = init_logger(__name__)
+
+
+class MLPSpeculatorProposer:
+    """
+    MLPSpeculator proposer class for generating token sequences
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        # Save config parameters
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        self.hidden_size = (vllm_config.speculative_config.
+                            draft_model_config.get_hidden_size())
+        self.num_speculative_tokens = (vllm_config.speculative_config.
+                                       num_speculative_tokens)
+        self.dtype = vllm_config.model_config.dtype
+
+    def propose(
+        self,
+        input_ids: torch.Tensor,
+        previous_hidden_states: torch.Tensor,
+        num_predict_tokens: int,
+        sampling_metadata: SamplingMetadata,
+    ) -> list[list[int]]:
+        # Generate blocks and compute logits
+        draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens, sampling_metadata)
+        return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
+
+    def load_model(self, target_model: nn.Module) -> None:
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
+
+    @torch.inference_mode()
+    def dummy_run(self, num_tokens: int) -> None:
+        input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device)
+        hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+        num_predict_tokens = self.num_speculative_tokens
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            self.model.generate_proposals(input_ids, hidden_states, num_predict_tokens, None)
\ No newline at end of file
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8658d7d916f..f19b5e537da 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -62,6 +62,7 @@
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.mlp_speculator import MLPSpeculatorProposer
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
@@ -187,6 +188,9 @@ def __init__(
                 self.drafter = MedusaProposer(
                     vllm_config=self.vllm_config,
                     device=self.device)  # type: ignore
+            elif self.speculative_config.method == "mlp_speculator":
+                self.drafter = MLPSpeculatorProposer(self.vllm_config,
+                                                     self.device)
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
@@ -1601,6 +1605,34 @@ def propose_draft_token_ids(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
+        elif self.speculative_config.method == "mlp_speculator":
+            assert isinstance(self.drafter, MLPSpeculatorProposer)
+            is_sample_match = sample_hidden_states.shape[0] == len(
+                sampled_token_ids)
+            # Get last token from each sequence
+            draft_input_ids = torch.tensor(
+                [tokens[-1] for tokens in sampled_token_ids],
+                device=sample_hidden_states.device)
+            if not is_sample_match:
+                # Calculate indices for hidden states
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+                indices = torch.tensor(indices, device=self.device)
+                hidden_states = sample_hidden_states[indices]
+            else:
+                hidden_states = sample_hidden_states
+            spec_token_ids = self.drafter.propose(
+                input_ids=draft_input_ids,
+                previous_hidden_states=hidden_states,
+                num_predict_tokens=self.vllm_config.speculative_config.
+                num_speculative_tokens,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
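
Usage note (not part of the patch): a minimal offline sketch of how the new "mlp_speculator" method in the V1 engine would be exercised, assuming an MLPSpeculator draft checkpoint paired with a compatible base model. The model names and the speculative_config field values below are illustrative only.

# Illustrative smoke test -- not part of the patch.
# Assumes the draft checkpoint (e.g. ibm-fms/llama-13b-accelerator) matches the base model.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",            # target (base) model
    speculative_config={
        "method": "mlp_speculator",                    # routes to MLPSpeculatorProposer in the V1 runner
        "model": "ibm-fms/llama-13b-accelerator",      # MLPSpeculator draft weights (example name)
        "num_speculative_tokens": 3,                   # draft tokens proposed per step
    },
)
outputs = llm.generate(["The future of AI is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)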