Speculative decoding and MTP optimization #1435

Draft: wants to merge 30 commits into base: deepseek_r1
Commits (30)
e792c2a  Expand DeepSeek MTP code to support k > n_predict (#13626)  (benchislett, Feb 27, 2025)
4cf2965  add torch profiler for the LLM engine (#1114)  (yangulei, Apr 24, 2025)
3a3d25e  Draft version for MTP perf optimization  (YuJiankang, Jun 4, 2025)
9daa0e9  mydebug  (inkcherry, Jun 4, 2025)
c10b26c  my debug2  (inkcherry, Jun 5, 2025)
b90b323  fix hidden1  (inkcherry, Jun 5, 2025)
c7a42a9  wa for crash  (inkcherry, Jun 6, 2025)
1446058  update metadata more previous  (inkcherry, Jun 6, 2025)
83fd06c  update  (inkcherry, Jun 6, 2025)
a3f5818  update  (inkcherry, Jun 6, 2025)
dc015fd  accuracy pass  (inkcherry, Jun 6, 2025)
122b1c1  wa for text accuracy  (inkcherry, Jun 9, 2025)
4491171  delete test files  (inkcherry, Jun 9, 2025)
a3986eb  remove some debug code  (inkcherry, Jun 9, 2025)
2d7251f  refine the code logic and remove some print  (YuJiankang, Jun 10, 2025)
2cbd5a9  add the batch size for the fake token  (YuJiankang, Jun 10, 2025)
11e697b  update to prepare for bs > 1  (YuJiankang, Jun 10, 2025)
dd803ab  eos support  (inkcherry, Jun 12, 2025)
e455f70  eos stop support  (inkcherry, Jun 12, 2025)
aabc738  token_max_len stop support  (inkcherry, Jun 12, 2025)
ce084e0  re-implement the code logic and enable the support for bs > 1  (YuJiankang, Jun 12, 2025)
194eaf7  Merge branch 'spdecode_and_mpt_optimization' into HEAD  (inkcherry, Jun 12, 2025)
f92a446  fix for batch_size=1  (inkcherry, Jun 13, 2025)
79694c0  my debug  (inkcherry, Jun 13, 2025)
645f211  optimization perf  (inkcherry, Jun 13, 2025)
20caa19  update1  (inkcherry, Jun 13, 2025)
daac15d  merge  (inkcherry, Jun 16, 2025)
e37cecd  clean up  (inkcherry, Jun 16, 2025)
d43a676  clean up  (inkcherry, Jun 16, 2025)
d97c7dd  # This is a combination of 2 commits.  (YuJiankang, Jun 13, 2025)
10 changes: 9 additions & 1 deletion vllm/engine/llm_engine.py
@@ -1147,6 +1147,14 @@ def _process_model_outputs(self,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
# WA for MTP opt
# TODO: sync the final text.
token_ids = request_output.outputs[0].token_ids
if token_ids[-2:] == [10, 12]:
token_ids = token_ids[:-2]
elif token_ids[-1] == 10:
token_ids = token_ids[:-1]
request_output.outputs[0].text = self.tokenizer.tokenizer.decode(token_ids, skip_special_tokens=True)
ctx.request_outputs.append(request_output)

# When we process a single request, we skip it for the next time,
@@ -1500,7 +1508,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()

return ctx.request_outputs

def _has_remaining_steps(
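For reference, a minimal standalone sketch of the workaround added to _process_model_outputs above: the MTP optimization appears to pad each decode step with fake tokens (ids 10 and 12 in this draft), so the final text is re-decoded after stripping them. The helper name and the HF-style decode call are illustrative, not part of the diff.

from typing import List

FAKE_TOKEN_IDS = (10, 12)  # sentinel ids used by this draft; assumed, not final

def strip_fake_tokens(token_ids: List[int]) -> List[int]:
    """Drop trailing fake tokens appended for MTP batch padding."""
    ids = list(token_ids)
    if ids[-2:] == list(FAKE_TOKEN_IDS):
        return ids[:-2]
    if ids and ids[-1] == FAKE_TOKEN_IDS[0]:
        return ids[:-1]
    return ids

# e.g. text = tokenizer.decode(strip_fake_tokens(output.token_ids),
#                              skip_special_tokens=True)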
2 changes: 1 addition & 1 deletion vllm/engine/output_processor/stop_checker.py
@@ -47,7 +47,7 @@ def maybe_stop_sequence(

# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
and (seq.get_last_token_id() == seq.eos_token_id or seq.get_last_n_token_id(3) == seq.eos_token_id)):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
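The one-line change above widens the EOS check: with MTP the accepted EOS can already be followed by up to two speculative or fake tokens, so the checker also probes the token three positions from the end via get_last_n_token_id (added to vllm/sequence.py below). A hedged sketch of the predicate with illustrative token ids (2 as EOS, 10/12 as padding):

from typing import List, Optional

def hits_eos(token_ids: List[int], eos_token_id: int,
             ignore_eos: bool = False) -> bool:
    """Stop on EOS even when it is buried behind trailing fake tokens."""
    if ignore_eos or not token_ids:
        return False
    third_last: Optional[int] = token_ids[-3] if len(token_ids) >= 3 else None
    return token_ids[-1] == eos_token_id or third_last == eos_token_id

assert hits_eos([5, 9, 2, 10, 12], eos_token_id=2)  # EOS two tokens back still stops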
@@ -7,8 +7,10 @@
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.utils import combine_shards

try:
from compressed_tensors.utils import combine_shards
except ImportError:  # tolerate compressed_tensors builds without combine_shards
pass
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear)
4 changes: 3 additions & 1 deletion vllm/model_executor/models/deepseek_mtp.py
@@ -17,6 +17,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
import habana_frameworks.torch as htorch

from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
@@ -83,9 +84,10 @@ def forward(
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)


hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
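For context, a self-contained sketch of the projection step in the forward pass above: the token embedding and the previous hidden state are each RMS-normalized, concatenated on the last dimension, and projected back to the hidden size before entering the speculative decoder block. The plain nn.Linear/nn.RMSNorm layers (PyTorch 2.4+) and the hidden size are stand-ins; the real module uses vLLM's parallel layers and the model config.

import torch
from torch import nn

class MTPProjection(nn.Module):
    """Toy stand-in for the enorm/hnorm/eh_proj path of the MTP layer."""

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size)
        self.hnorm = nn.RMSNorm(hidden_size)
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, inputs_embeds: torch.Tensor,
                previous_hidden_states: torch.Tensor) -> torch.Tensor:
        fused = torch.cat([self.enorm(inputs_embeds),
                           self.hnorm(previous_hidden_states)], dim=-1)
        return self.eh_proj(fused)

out = MTPProjection()(torch.randn(2, 4, 1024), torch.randn(2, 4, 1024))  # -> (2, 4, 1024)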
12 changes: 7 additions & 5 deletions vllm/model_executor/models/deepseek_v2.py
@@ -27,6 +27,7 @@
import torch
from torch import nn
from transformers import PretrainedConfig
import habana_frameworks.torch as htorch

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
@@ -264,6 +265,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
# O projection.

self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
@@ -300,6 +302,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:

if is_hpu:
# need reshape from tensor(x0, y0) to tensor(x1) for hpu
_batch_size = positions.shape[0]
@@ -402,19 +405,16 @@ def __init__(
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim

self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank

self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size

self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
@@ -503,6 +503,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:

if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
@@ -592,6 +593,7 @@ def forward(
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@@ -630,7 +632,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
@@ -667,15 +668,16 @@ def forward(
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)

residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
kvcaches = None if kv_caches is None else kv_caches[i - self.start_layer]

hidden_states, residual = layer(positions, hidden_states,
kvcaches,
attn_metadata, residual)
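The loop change above makes the per-layer cache lookup tolerate a missing kv_caches list. A minimal sketch of that guard, assuming (the diff does not say) that None shows up on profiling or warm-up runs before the cache tensors are allocated:

from typing import List, Optional
import torch

def layer_kv_cache(kv_caches: Optional[List[torch.Tensor]],
                   layer_idx: int, start_layer: int) -> Optional[torch.Tensor]:
    """Return this layer's cache tensor, or None when no caches exist yet."""
    if kv_caches is None:
        return None
    return kv_caches[layer_idx - start_layer]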
23 changes: 20 additions & 3 deletions vllm/sequence.py
@@ -316,8 +316,12 @@ def get_num_computed_tokens(self) -> int:
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# assert self._num_computed_tokens <= self.get_len(), (
# self._num_computed_tokens, self.get_len())

# if self._num_computed_tokens > self.get_len():
# self._num_computed_tokens-=1
# c=0
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
@@ -351,6 +355,13 @@ def get_last_token_id(self) -> int:
return self._prompt_token_ids[-1]
return self._output_token_ids[-1]

def get_last_n_token_id(self, n: int) -> Optional[int]:
"""Return the token id n positions from the end, or None if not enough output."""
if self.get_output_len() < n:
return None
if not self._output_token_ids:
return self._prompt_token_ids[-n]
return self._output_token_ids[-n]

def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.prompt_token_ids

@@ -572,6 +583,8 @@ def get_prompt_token_ids(self) -> Tuple[int, ...]:
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()

def get_last_n_token_id(self, n: int) -> Optional[int]:
return self.data.get_last_n_token_id(n)

def get_output_token_ids(self) -> Tuple[int, ...]:
return self.data.get_output_token_ids()

@@ -1220,7 +1233,7 @@ class HiddenStates(msgspec.Struct, array_like=True,

def __post_init__(self):
if self.seq_group_metadata_list is not None:
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
# assert len(self.seq_group_metadata_list) == len(self.hidden_states)
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

@property
@@ -1322,6 +1335,10 @@ class ExecuteModelRequest(
# Dummy batch
is_dummy_batch: bool = False


expand: Optional[Callable[[], Tuple[Any, Any]]] = None
hack_indices_of_seq_with_bonus_tokens: Optional[List[int]] = None
expand_req: Optional["ExecuteModelRequest"] = None

@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
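The three new ExecuteModelRequest fields carry a deferred batch-expansion handshake: expand is a zero-argument callable the target worker can invoke once accepted tokens are known, while hack_indices_of_seq_with_bonus_tokens and expand_req hand the resulting bonus-token indices and expanded request back to the proposer. A self-contained sketch of that handshake under those assumptions (only the field names come from the diff):

from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple

@dataclass
class Request:
    """Cut-down stand-in for ExecuteModelRequest, keeping only the new fields."""
    seq_groups: List[Any] = field(default_factory=list)
    expand: Optional[Callable[[], Tuple[Any, List[int]]]] = None
    hack_indices_of_seq_with_bonus_tokens: Optional[List[int]] = None
    expand_req: Optional["Request"] = None

def worker_side_step(req: Request) -> None:
    # Inside the target worker: run the deferred expansion once accepted tokens
    # have been applied, then stash the results for the proposer to pick up.
    if req.expand is not None:
        expanded, bonus_indices = req.expand()
        req.expand_req = expanded
        req.hack_indices_of_seq_with_bonus_tokens = bonus_indices

req = Request()
req.expand = lambda: (Request(seq_groups=["sg0", "sg0_bonus"]), [1])
worker_side_step(req)
assert req.hack_indices_of_seq_with_bonus_tokens == [1]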
3 changes: 2 additions & 1 deletion vllm/spec_decode/batch_expansion.py
@@ -43,6 +43,7 @@ def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
accepted_token_id: Optional[torch.Tensor] = None,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.

@@ -80,7 +81,7 @@

target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
seq_group_metadata_list=target_seq_group_metadata_list), accepted_token_id=accepted_token_id)
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]

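The scorer change is plumbing: accepted_token_id defaults to None so existing callers are unaffected, and it is forwarded to the target worker together with the cloned request. A minimal sketch of that pattern (class and method names are illustrative):

from typing import Optional
import torch

class ScorerShim:
    """Forward an optional MTP acceptance tensor without breaking older callers."""

    def __init__(self, scorer_worker):
        self._scorer_worker = scorer_worker

    def score(self, execute_model_req,
              accepted_token_id: Optional[torch.Tensor] = None):
        kwargs = {}
        if accepted_token_id is not None:
            kwargs["accepted_token_id"] = accepted_token_id
        return self._scorer_worker.execute_model(
            execute_model_req=execute_model_req, **kwargs)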
3 changes: 3 additions & 0 deletions vllm/spec_decode/interfaces.py
@@ -75,6 +75,7 @@ def get_spec_proposals(
# If set, this contains all sequence IDs that were assigned
# bonus tokens in their last forward pass.
seq_ids_with_bonus_token_in_last_step: Set[int],
accepted_token_id: Optional[torch.Tensor] = None,
) -> SpeculativeProposals:
raise NotImplementedError

@@ -94,5 +95,7 @@ def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
accepted_token_id: Optional[torch.Tensor] = None,
) -> SpeculativeScores:
raise NotImplementedError
71 changes: 58 additions & 13 deletions vllm/spec_decode/multi_step_worker.py
@@ -2,7 +2,7 @@

import copy
import weakref
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import torch

@@ -61,6 +61,7 @@ def sampler_output(
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
accepted_token_id: Optional[torch.Tensor] = None,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
@@ -69,14 +70,51 @@

For multi step worker, this indicator shall be True.
"""
rank = torch.distributed.get_rank()

self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)


if accepted_token_id is not None:
def bind_expand_fn_to_request(execute_model_req, accepted_token_id, seq_ids_with_bonus_token_in_last_step, expand_fn):
def expand():
if accepted_token_id is not None:
accepted_token_id_ = accepted_token_id.cpu()
for seq_index, sg in enumerate(execute_model_req.seq_group_metadata_list):
seq_data_iter = sg.seq_data.values()
last_token_id = accepted_token_id_[seq_index][-1]
token1 = accepted_token_id_[seq_index][0]
if last_token_id == -1:
for seq_id in sg.seq_data:
seq_ids_with_bonus_token_in_last_step.discard(seq_id)
token1 = accepted_token_id_[seq_index][0]
for seq in seq_data_iter:
seq.output_token_ids = seq.output_token_ids[:-2] + (token1,)
#seq._new_appended_tokens = seq._new_appended_tokens[:-3] + [token1]
seq._num_computed_tokens -= 1
else:
token2 = accepted_token_id_[seq_index][1]
for seq in seq_data_iter:
seq.output_token_ids = seq.output_token_ids[:-2] + (token1, token2)
#seq._new_appended_tokens = seq._new_appended_tokens[:-3] + [token1, token2]
return expand_fn(execute_model_req, seq_ids_with_bonus_token_in_last_step)

execute_model_req.expand = expand
bind_expand_fn_to_request(
execute_model_req,
accepted_token_id,
seq_ids_with_bonus_token_in_last_step,
self._expand_execute_model_request,
)
expanded_request = execute_model_req
else:
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)

# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
@@ -99,21 +137,27 @@ def sampler_output(
self.worker.model_runner.return_hidden_states = True
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
execute_model_req=expanded_request, accepted_token_id=accepted_token_id)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._maybe_update_previous_hidden_states(
model_output, expanded_request)

if execute_model_req.hack_indices_of_seq_with_bonus_tokens is not None:
indices_of_seq_with_bonus_tokens = execute_model_req.hack_indices_of_seq_with_bonus_tokens
expanded_request = execute_model_req.expand_req
execute_model_req.hack_indices_of_seq_with_bonus_tokens = None
execute_model_req.expand_req = None
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)

# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
# if model_outputs[0].sampled_token_ids[0][0].item()==2501:

filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
@@ -136,7 +180,7 @@ def _expand_execute_model_request(
execute_model_req: ExecuteModelRequest,
seq_with_bonus_token_in_last_step: set,
) -> Tuple[ExecuteModelRequest, List[int]]:
"""
"""`
Expands the execute model request based on sequences with bonus
tokens.

@@ -238,12 +282,13 @@ def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
seq_ids_with_bonus_token_in_last_step: set,
accepted_token_id: Optional[torch.Tensor] = None,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
execute_model_req, seq_ids_with_bonus_token_in_last_step, accepted_token_id)

@staticmethod
def _append_new_tokens(
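The closure bound in sampler_output above rewrites each sequence's tail from the scorer's accepted_token_id before the usual bonus-token expansion runs: a row ending in -1 means only the first accepted token survives (and the sequence is dropped from the bonus set), otherwise the last two placeholder tokens are replaced by the two accepted ones. A simplified sketch of that tail rewrite on plain tuples; the -1 convention and the two-token window come from the diff, the rest is illustrative (the real closure also adjusts _num_computed_tokens and then calls the normal expansion):

from typing import List, Set, Tuple
import torch

def rewrite_tails(output_token_ids: List[Tuple[int, ...]],
                  accepted_token_id: torch.Tensor,
                  bonus_seq_ids: Set[int]) -> List[Tuple[int, ...]]:
    """Apply per-sequence accepted tokens to the last two placeholder positions."""
    accepted = accepted_token_id.cpu()
    rewritten = []
    for seq_id, tokens in enumerate(output_token_ids):
        first = int(accepted[seq_id][0])
        if int(accepted[seq_id][-1]) == -1:      # second draft token rejected
            bonus_seq_ids.discard(seq_id)
            rewritten.append(tokens[:-2] + (first,))
        else:                                    # both tokens accepted
            second = int(accepted[seq_id][1])
            rewritten.append(tokens[:-2] + (first, second))
    return rewritten

assert rewrite_tails([(5, 9, 10, 12)], torch.tensor([[7, -1]]), {0}) == [(5, 9, 7)]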