[Model] vllm v1 support mlp_speculator #21276

Open · wants to merge 8 commits into base: main
6 changes: 2 additions & 4 deletions tests/models/registry.py
@@ -463,10 +463,8 @@ def check_available_online(
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
# Temporarily disabled.
# TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
"MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
trust_remote_code=True),
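With the `MLPSpeculatorPreTrainedModel` entry re-enabled above, the target/draft pairing can be exercised end to end through the public `LLM` API. A minimal sketch, assuming the dict-style `speculative_config` argument of recent vLLM releases (the exact keys may vary between versions):

```python
from vllm import LLM, SamplingParams

# Hedged sketch: the target/draft pair mirrors the registry entry above;
# the speculative_config keys follow recent vLLM releases and may differ.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_config={
        "model": "ibm-ai-platform/llama-160m-accelerator",
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(
    ["Speculative decoding works by"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```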
14 changes: 5 additions & 9 deletions tests/models/test_registry.py
@@ -72,15 +72,11 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):


@create_new_process_for_each_test()
@pytest.mark.parametrize(
"model_arch,is_pp,init_cuda",
[
# TODO(woosuk): Re-enable this once the MLP Speculator is supported
# in V1.
# ("MLPSpeculatorPreTrainedModel", False, False),
("DeepseekV2ForCausalLM", True, False),
("Qwen2VLForConditionalGeneration", True, True),
])
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
("MLPSpeculatorPreTrainedModel", False, False),
("DeepseekV2ForCausalLM", True, False),
("Qwen2VLForConditionalGeneration", True, True),
])
def test_registry_is_pp(model_arch, is_pp, init_cuda):
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp

15 changes: 12 additions & 3 deletions vllm/model_executor/models/mlp_speculator.py
@@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
if hasattr(vllm_config, 'speculative_config'):
config = vllm_config.speculative_config.draft_model_config.hf_config
self.sampling_metadata_is_required = False
else:
config = vllm_config.model_config.hf_config
self.sampling_metadata_is_required = True
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
@@ -182,8 +187,12 @@ def generate_proposals(
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)

logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
if self.logits_processor:
logits = self.logits_processor(
self.head[head_index], states, sampling_metadata
if self.sampling_metadata_is_required else None)
else:
logits = self.head[head_index](states)

output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids
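The new branch above lets `generate_proposals` run without a `LogitsProcessor`: on the V1 path no `SamplingMetadata` is threaded through, so the head's projection is applied to the flattened hidden states directly. A standalone illustration of that fallback using plain `torch` stand-ins (the layer sizes here are arbitrary, not vLLM's):

```python
import torch
import torch.nn as nn

# Stand-ins for self.head[head_index] and the flattened hidden states.
head = nn.Linear(64, 1000, bias=False)
states = torch.randn(4, 64)

logits_processor = None  # V1 proposer path: no processor, no sampling metadata
if logits_processor:
    logits = logits_processor(head, states, None)
else:
    logits = head(states)  # plain projection straight to vocab-sized logits

print(logits.shape)  # torch.Size([4, 1000])
```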
4 changes: 1 addition & 3 deletions vllm/model_executor/models/registry.py
@@ -247,9 +247,7 @@
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_TRANSFORMERS_MODELS = {
68 changes: 68 additions & 0 deletions vllm/v1/spec_decode/mlp_speculator.py
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.sample.metadata import SamplingMetadata

# Initialize logger
logger = init_logger(__name__)


class MLPSpeculatorProposer:
"""
MLPSpeculator proposer class for generating token sequences
"""

def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
# Save config parameters
self.vllm_config = vllm_config
self.device = device
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.hidden_size = (vllm_config.speculative_config.draft_model_config.
get_hidden_size())
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.dtype = vllm_config.model_config.dtype

def propose(
self,
input_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# Generate blocks and compute logits
draft_tokens = self.model.generate_proposals(input_ids,
previous_hidden_states,
num_predict_tokens,
sampling_metadata)
return list(
map(lambda x: x[0],
zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))

def load_model(self, target_model: nn.Module) -> None:
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)

@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device)
hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size),
dtype=self.dtype,
device=self.device)
num_predict_tokens = self.num_speculative_tokens
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
self.model.generate_proposals(input_ids, hidden_states,
num_predict_tokens, None)
33 changes: 33 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -64,6 +64,7 @@
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.mlp_speculator import MLPSpeculatorProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -188,6 +189,10 @@ def __init__(
self.drafter = MedusaProposer(
vllm_config=self.vllm_config,
device=self.device) # type: ignore
elif self.speculative_config.method == "mlp_speculator":
self.drafter = MLPSpeculatorProposer(
self.vllm_config, # type: ignore
self.device)
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
@@ -1638,6 +1643,34 @@ def propose_draft_token_ids(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
elif self.speculative_config.method == "mlp_speculator":
assert isinstance(self.drafter, MLPSpeculatorProposer)
is_sample_match = sample_hidden_states.shape[0] == len(
sampled_token_ids)
# Get last token from each sequence
draft_input_ids = torch.tensor(
[tokens[-1] for tokens in sampled_token_ids],
device=sample_hidden_states.device)
if not is_sample_match:
# Calculate indices for hidden states
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
else:
hidden_states = sample_hidden_states
spec_token_ids = self.drafter.propose(
input_ids=draft_input_ids,
previous_hidden_states=hidden_states,
num_predict_tokens=self.vllm_config.speculative_config.
num_speculative_tokens,
sampling_metadata=sampling_metadata,
)
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
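When the previous step also ran speculative decoding, `sample_hidden_states` contains `num_draft + 1` rows per request while `sampled_token_ids` lists only the accepted tokens, so the branch above recovers the row belonging to each request's last accepted token. A small worked example of that index arithmetic, with made-up values:

```python
# Two requests, each with 3 draft tokens proposed on the previous step, so
# sample_hidden_states holds 3 + 1 = 4 rows per request, laid out back to back.
# Request 0 accepted 2 tokens; request 1 accepted all 4.
num_draft_tokens = [3, 3]
sampled_token_ids = [[11, 12], [21, 22, 23, 24]]  # made-up token ids

indices = []
offset = 0
for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
    indices.append(offset + len(tokens) - 1)  # row of the last accepted token
    offset += num_draft + 1                   # rows reserved for this request

print(indices)  # [1, 7]
```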