[Model] vllm v1 support mlp_speculator #20655
base: main
Changes from all commits
859e2ca
2a57e20
4711051
26f2b7f
e7ad830
de28b86
```diff
@@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'speculative_config'):
+            config = vllm_config.speculative_config.draft_model_config.hf_config
+            self.sampling_metadata_is_required = False
+        else:
+            config = vllm_config.model_config.hf_config
+            self.sampling_metadata_is_required = True
```
Comment on lines +73 to +78
The logic to differentiate between v0 and v1 paths is incomplete. For the v1 path (…) To fix this, the initialization of …
```diff
 
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
```
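For readers following the v0/v1 split above, here is a minimal standalone sketch (not the PR's code) of how the two construction paths could be resolved. It assumes, as the diff does, that a v1-style VllmConfig exposes the draft model's hf_config through speculative_config while the v0 path reads it from model_config; the helper name resolve_draft_config is purely illustrative.

```python
from typing import Any, Tuple


def resolve_draft_config(vllm_config: Any) -> Tuple[Any, bool]:
    """Return (hf_config, sampling_metadata_is_required) for the speculator."""
    # Hypothetical helper, mirroring the branch in the diff above.
    spec_cfg = getattr(vllm_config, "speculative_config", None)
    if spec_cfg is not None:
        # v1 engine: the draft model config hangs off speculative_config,
        # and the speculator does not need engine sampling metadata.
        return spec_cfg.draft_model_config.hf_config, False
    # v0 engine: the speculator itself is loaded as the main model.
    return vllm_config.model_config.hf_config, True
```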
```diff
@@ -182,8 +187,11 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             states = states.flatten(0, 1)
 
-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+            if self.logits_processor:
+                logits = self.logits_processor(self.head[head_index], states,
+                                               sampling_metadata)
+            else:
+                logits = self.head[head_index](states)
 
             output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
```
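The second hunk makes the logits computation conditional. As a rough illustration, a sketch under the assumption that the v1 path leaves self.logits_processor unset or falsy, the same pattern can be written as a free function; the names below are stand-ins, not from the PR beyond what the diff shows.

```python
import torch
import torch.nn as nn


def compute_head_logits(head: nn.Module,
                        states: torch.Tensor,
                        logits_processor=None,
                        sampling_metadata=None) -> torch.Tensor:
    # v0-style path: a LogitsProcessor gathers/scales the logits using the
    # sampling metadata supplied by the engine.
    if logits_processor is not None:
        return logits_processor(head, states, sampling_metadata)
    # v1-style path: the raw projection of the hidden states is used as logits.
    return head(states)
```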
@@ -0,0 +1,62 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.sample.metadata import SamplingMetadata

# Initialize logger
logger = init_logger(__name__)


class MLPSpeculatorProposer:
    """
    MLPSpeculator proposer class for generating token sequences
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Save config parameters
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.hidden_size = (vllm_config.speculative_config.
                            draft_model_config.get_hidden_size())
        self.num_speculative_tokens = (vllm_config.speculative_config.
                                       num_speculative_tokens)
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        # Generate blocks and compute logits
        draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens, sampling_metadata)
        return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))

    def load_model(self, target_model: nn.Module) -> None:
        self.model = get_model(vllm_config=self.vllm_config,
                               model_config=self.vllm_config.
                               speculative_config.draft_model_config)

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device)
        hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size),
                                    dtype=self.dtype,
                                    device=self.device)
        num_predict_tokens = self.num_speculative_tokens
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model.generate_proposals(input_ids, hidden_states,
                                          num_predict_tokens, None)
```
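To show how the new proposer could fit together end to end, here is a hedged usage sketch. It is not part of the PR, and the runner-side names (run_mlp_speculator_step, last_token_ids, hidden_states, target_model) are hypothetical stand-ins for whatever the v1 model runner actually holds; it also assumes MLPSpeculatorProposer from the new module above is in scope.

```python
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.sample.metadata import SamplingMetadata


def run_mlp_speculator_step(
    vllm_config: VllmConfig,
    device: torch.device,
    target_model: nn.Module,
    last_token_ids: torch.Tensor,    # [num_seqs, 1] latest sampled token ids
    hidden_states: torch.Tensor,     # [num_seqs, hidden_size] target-model states
    sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
    # Hypothetical wiring: construct the proposer, load the draft weights,
    # run a warm-up pass, then ask for speculative tokens.
    proposer = MLPSpeculatorProposer(vllm_config, device)
    proposer.load_model(target_model)
    proposer.dummy_run(num_tokens=1)
    return proposer.propose(
        input_ids=last_token_ids,
        previous_hidden_states=hidden_states,
        num_predict_tokens=vllm_config.speculative_config.num_speculative_tokens,
        sampling_metadata=sampling_metadata,
    )
```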