[Model] vllm v1 support mlp_speculator #20655


Open · wants to merge 6 commits into base: main
5 changes: 4 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1460,6 +1460,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         is_ngram_enabled = False
         is_eagle_enabled = False
         is_medusa_enabled = False
+        is_mlp_speculator_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
@@ -1470,11 +1471,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 is_medusa_enabled = True
             elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                 is_eagle_enabled = True
+            elif speculative_method == "mlp_speculator":
+                is_mlp_speculator_enabled = True
         else:
             speculative_model = self.speculative_config.get("model")
             if speculative_model in ("ngram", "[ngram]"):
                 is_ngram_enabled = True
-        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
+        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled or is_mlp_speculator_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
                                recommend_to_remove=False)
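For context, enabling this new path from the Python API would look roughly like the sketch below. This is a minimal sketch, not taken from the PR: the model names are illustrative, and the speculative_config keys are inferred from the dict keys the oracle reads above ("method", "model").

# Hypothetical usage sketch; model names and config keys are assumptions
# inferred from the "mlp_speculator" method string handled above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-granite/granite-7b-instruct",  # illustrative target model
    speculative_config={
        "method": "mlp_speculator",
        # Illustrative draft model trained as an MLP speculator:
        "model": "ibm-granite/granite-7b-instruct-accelerator",
        "num_speculative_tokens": 4,
    },
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)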
14 changes: 11 additions & 3 deletions vllm/model_executor/models/mlp_speculator.py
@@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'speculative_config'):
+            config = vllm_config.speculative_config.draft_model_config.hf_config
+            self.sampling_metadata_is_required = False
+        else:
+            config = vllm_config.model_config.hf_config
+            self.sampling_metadata_is_required = True
Comment on lines +73 to +78 · Contributor (critical):

The logic differentiating the v0 and v1 paths is incomplete. On the v1 path (hasattr(vllm_config, 'speculative_config')), self.sampler and self.logits_processor are still initialized with v0 components at the end of __init__ (lines 147-149), which will crash at runtime due to type mismatches with SamplingMetadata.

To fix this, move the initialization of sampler and logits_processor inside this if/else block. On the v1 path, self.sampler should be an instance of vllm.v1.sample.sampler.Sampler and self.logits_processor should be None.
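A minimal sketch of that split, assuming the v0 branch keeps the existing LogitsProcessor/get_sampler construction from lines 147-149 (this illustrates the reviewer's suggestion; it is not code in the PR):

if hasattr(vllm_config, 'speculative_config'):
    # v1 path: logits come straight from the head, sampling uses the
    # v1 sampler (which takes SamplingMetadata from vllm.v1).
    from vllm.v1.sample.sampler import Sampler as V1Sampler
    self.sampler = V1Sampler()
    self.logits_processor = None
else:
    # v0 path: keep the existing components (as at lines 147-149).
    self.logits_processor = LogitsProcessor(config.vocab_size,
                                            config.vocab_size, 1.0)
    self.sampler = get_sampler()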

         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
@@ -182,8 +187,11 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             states = states.flatten(0, 1)

-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+            if self.logits_processor:
+                logits = self.logits_processor(self.head[head_index], states,
+                                               sampling_metadata)
+            else:
+                logits = self.head[head_index](states)

             output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
62 changes: 62 additions & 0 deletions vllm/v1/spec_decode/mlp_speculator.py
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.sample.metadata import SamplingMetadata

# Initialize logger
logger = init_logger(__name__)


class MLPSpeculatorProposer:
    """
    MLPSpeculator proposer class for generating token sequences
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Save config parameters
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.hidden_size = (vllm_config.speculative_config.
                            draft_model_config.get_hidden_size())
        self.num_speculative_tokens = (vllm_config.speculative_config.
                                       num_speculative_tokens)
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        # Generate blocks and compute logits
        draft_tokens = self.model.generate_proposals(
            input_ids, previous_hidden_states, num_predict_tokens,
            sampling_metadata)
        return list(
            map(lambda x: x[0],
                zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))

    def load_model(self, target_model: nn.Module) -> None:
        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=self.vllm_config.
                               speculative_config.draft_model_config)

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device)
        hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size),
                                    dtype=self.dtype,
                                    device=self.device)
        num_predict_tokens = self.num_speculative_tokens
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model.generate_proposals(input_ids, hidden_states,
                                          num_predict_tokens, None)
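The return expression in propose() above is dense. Below is a standalone toy run of the same zip/map pattern, assuming each head's sampled_token_ids has shape [batch, 1]; under that assumption, x[0] keeps only the first head's entry per sequence, which reviewers may want to double-check:

# Toy reproduction of the propose() return expression (invented values).
# Stand-in for [i.sampled_token_ids.tolist() for i in draft_tokens],
# assuming shape [batch, 1] per head: 3 heads, 2 sequences.
per_head = [
    [[101], [201]],  # head 0: seq 0, seq 1
    [[102], [202]],  # head 1
    [[103], [203]],  # head 2
]

# zip(*per_head) regroups the per-head entries by sequence:
# [([101], [102], [103]), ([201], [202], [203])]
grouped = list(zip(*per_head))

# map(lambda x: x[0], ...) keeps only the head-0 entry of each group,
# yielding one single-token list per sequence rather than the full draft.
result = list(map(lambda x: x[0], grouped))
print(result)  # [[101], [201]]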
32 changes: 32 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -62,6 +62,7 @@
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.mlp_speculator import MLPSpeculatorProposer
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
@@ -187,6 +188,9 @@
                 self.drafter = MedusaProposer(
                     vllm_config=self.vllm_config,
                     device=self.device)  # type: ignore
+            elif self.speculative_config.method == "mlp_speculator":
+                self.drafter = MLPSpeculatorProposer(self.vllm_config,
+                                                     self.device)
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")

Check failure on line 192 in vllm/v1/worker/gpu_model_runner.py · GitHub Actions / pre-commit (reported 10 times):
Incompatible types in assignment (expression has type "MLPSpeculatorProposer", variable has type "NgramProposer")  [assignment]
@@ -1601,6 +1605,34 @@
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
+        elif self.speculative_config.method == "mlp_speculator":
+            assert isinstance(self.drafter, MLPSpeculatorProposer)
+            is_sample_match = sample_hidden_states.shape[0] == len(
+                sampled_token_ids)
+            # Get last token from each sequence
+            draft_input_ids = torch.tensor(
+                [tokens[-1] for tokens in sampled_token_ids],
+                device=sample_hidden_states.device)
+            if not is_sample_match:
+                # Calculate indices for hidden states
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+                indices = torch.tensor(indices, device=self.device)
+                hidden_states = sample_hidden_states[indices]
+            else:
+                hidden_states = sample_hidden_states
+            spec_token_ids = self.drafter.propose(
+                input_ids=draft_input_ids,
+                previous_hidden_states=hidden_states,
+                num_predict_tokens=self.vllm_config.speculative_config.
+                num_speculative_tokens,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
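To make the hidden-state alignment in the mlp_speculator branch above concrete, here is a standalone toy walkthrough. It assumes sample_hidden_states holds one row per sampled position, i.e. num_draft + 1 rows per sequence; the values are invented:

import torch

# Two sequences, each scheduled with 2 draft tokens, so the target model
# sampled num_draft + 1 = 3 positions per sequence: 6 rows in total.
sample_hidden_states = torch.arange(6.0).unsqueeze(1)  # rows 0..5
num_draft_tokens = [2, 2]

# Seq 0 accepted only the bonus token (1 id sampled); seq 1 accepted both
# drafts plus the bonus token (3 ids sampled).
sampled_token_ids = [[11], [21, 22, 23]]

indices = []
offset = 0
for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
    # The last accepted token's hidden state sits at offset + len - 1.
    indices.append(offset + len(tokens) - 1)
    offset += num_draft + 1

print(indices)                        # [0, 5]
print(sample_hidden_states[indices])  # hidden states fed to the drafter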