
Commit 7d3ca5a

XWFAlone and mengwei805 committed
[1/N][UT][v1 MTP] add basic v1 mtp features
Co-authored-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: x30059026 <xuewenfei2@huawei.com>
1 parent a8730e7 commit 7d3ca5a

File tree: 8 files changed, +781 -0 lines changed

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
from __future__ import annotations

import os
import random
from typing import Any

import pytest
from vllm import LLM, SamplingParams

os.environ['VLLM_USE_MODELSCOPE'] = 'True'


@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
    num_prompts = 100
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    return "wemaster/deepseek_mtp_main_random_bf16"


def test_mtp_correctness(
        monkeypatch: pytest.MonkeyPatch,
        test_prompts: list[list[dict[str, Any]]],
        sampling_config: SamplingParams,
        model_name: str,
):
    '''
    The outputs of the original LLM and the speculative LLM should be
    the same when using mtp speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=256)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

        spec_llm = LLM(model=model_name,
                       trust_remote_code=True,
                       speculative_config={
                           "method": "mtp",
                           "num_speculative_tokens": 1,
                       },
                       max_model_len=256,
                       enforce_eager=False)
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            print(f"ref rst is {ref_output.outputs[0].text}")
            print(f"mtp rst is {spec_output.outputs[0].text}")
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 66% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs))
        del spec_llm
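
For quick manual checks outside pytest, the same MTP setup can be driven directly; this is a minimal sketch mirroring the fixtures above (model name, environment variables, and engine arguments are taken from the test, the prompt is illustrative only):

    import os
    from vllm import LLM, SamplingParams

    os.environ['VLLM_USE_MODELSCOPE'] = 'True'  # model is hosted on ModelScope
    os.environ['VLLM_USE_V1'] = '1'             # the test sets this via monkeypatch

    llm = LLM(model="wemaster/deepseek_mtp_main_random_bf16",
              trust_remote_code=True,
              speculative_config={
                  "method": "mtp",
                  "num_speculative_tokens": 1,
              },
              max_model_len=256)
    outputs = llm.chat(
        [[{"role": "user", "content": "please repeat the word 'test' 10 times."}]],
        SamplingParams(temperature=0, max_tokens=10))
    print(outputs[0].outputs[0].text)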

vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 0 deletions
@@ -495,6 +495,9 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
 
+        # TODO: it is unclear why PrefillCacheHit appears after turning on mtp
+        if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+            attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache

vllm_ascend/patch/__init__.py

Lines changed: 32 additions & 0 deletions
@@ -85,6 +85,26 @@
 # Future Plan:
 # Remove those patch when vllm merged them
 #
+# ** File: platform/patch_common/patch_arg_utils.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle()`
+#    Why:
+#       A new patch is needed to adapt the mtp function to v1.
+#    How:
+#       Add verification related to the mtp function.
+#    Future Plan:
+#       Delete the patch to follow the version plan.
+#
+# ** File: platform/patch_common/patch_config.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.config.SpeculativeConfig.__post_init__()`
+#    Why:
+#       A new patch is needed to adapt the mtp function to v1.
+#    How:
+#       Add verification related to the mtp function.
+#    Future Plan:
+#       Delete the patch to follow the version plan.
+#
 #
 # * Worker Patch:
 # ===============
@@ -158,4 +178,16 @@
 # - https://github.com/vllm-project/vllm-ascend/pull/395
 # Future Plan:
 # Revert it when the related pr is merged in vllm and vllm-ascend.
+#
+# ** File: worker/patch_common/patch_v1_mtp_proposer.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.spec_decode`
+#    Why:
+#       vllm has not yet merged its v1 implementation of mtp, so to adapt the
+#       mtp function to v1 we implement mtp_proposer separately.
+#    How:
+#       Add verification related to the mtp function.
+#    Future Plan:
+#       Once vllm merges the v1 mtp implementation, only the parts that differ
+#       will be modified. Delete the patch to follow the version plan.
 #
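
The three new registry entries above all rely on the same import-time monkey-patch mechanism. A minimal sketch of that pattern, for illustration only (the real replacement bodies live in patch_arg_utils.py and patch_config.py; note that patch_config.py, shown later in this commit, replaces __post_init__ wholesale rather than wrapping it):

    from vllm.config import SpeculativeConfig

    _orig_post_init = SpeculativeConfig.__post_init__  # kept only for reference


    def _patched_post_init(self):
        # mtp-aware verification would be inserted here, before or instead of
        # the original logic
        _orig_post_init(self)


    # Rebinding the method on the class is the whole "patch": every
    # SpeculativeConfig constructed after this module is imported runs the
    # patched __post_init__.
    SpeculativeConfig.__post_init__ = _patched_post_init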

vllm_ascend/patch/platform/patch_common/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,4 +15,6 @@
 # limitations under the License.
 #
 
+import vllm_ascend.patch.platform.patch_common.patch_arg_utils  # noqa
+import vllm_ascend.patch.platform.patch_common.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
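
As a usage note, these modules are imported only for their side effects: importing the package once is enough to apply all platform patches, e.g. (illustrative):

    # Applies patch_arg_utils, patch_config and patch_distributed as
    # module-level side effects; nothing else is needed.
    import vllm_ascend.patch.platform.patch_common  # noqa: F401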
Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union

import vllm.envs as envs
from transformers import PretrainedConfig
from vllm.config import ModelConfig, SpeculativeConfig

if TYPE_CHECKING:
    from _typeshed import DataclassInstance

    ConfigType = type[DataclassInstance]
else:
    ConfigType = type

ConfigT = TypeVar("ConfigT", bound=ConfigType)

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward", "transcription"]

RunnerType = Literal["generate", "pooling", "draft", "transcription"]

HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
                                             PretrainedConfig]]


def __post_init__(self):

    # Note: "method" is a new parameter that helps to extend the
    # configuration of non-model-based proposers, and the "model" parameter
    # will be used to set the draft model, eagle head, or additional weight
    # when needed. If users do not specify "method", the speculative method
    # will be detected automatically if possible. If the speculative method
    # cannot be detected, it will be considered as the "draft_model" by
    # default.

    if self.model is None and self.num_speculative_tokens is not None:
        # TODO(Shangming): Refactor mtp configuration logic when supporting
        # mtp acceleration for more models besides deepseek_v3
        if self.target_model_config and \
                (self.target_model_config.hf_text_config.model_type
                 == "deepseek_v3" or
                 self.target_model_config.hf_text_config.model_type
                 == "mimo"):
            # use the draft model from the same model:
            self.model = self.target_model_config.model
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        else:
            raise ValueError("num_speculative_tokens was provided without "
                             "speculative model.")

    # Automatically configure the method for ngram when "model" is used
    # instead of "method"
    if self.method is None and (self.model is not None
                                and self.model in ("ngram", "[ngram]")):
        self.method = "ngram"

    if self.method in ("ngram", "[ngram]"):
        # Unified to "ngram" internally
        self.method = "ngram"
        # Set default values if not provided
        if (self.prompt_lookup_min is None
                and self.prompt_lookup_max is None):
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            assert self.prompt_lookup_max is not None
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            assert self.prompt_lookup_min is not None
            self.prompt_lookup_max = self.prompt_lookup_min

        # Validate values
        if self.prompt_lookup_min < 1:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
        if self.prompt_lookup_max < 1:
            raise ValueError(
                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}")

        # TODO: currently we still need to extract vocab_size from the target
        # model config; in the future, we may try to refactor it out and set
        # the draft-related config to None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
            self.draft_model_config = ModelConfig(
                model=self.model,
                task="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_seq_len_to_capture=self.target_model_config.
                max_seq_len_to_capture,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

            # Automatically detect the method
            if self.method in ('eagle', 'eagle3'):
                pass
            elif "eagle-" in self.draft_model_config.model.lower() or \
                    "eagle3-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            elif self.draft_model_config.hf_config.model_type == "deepseek_mtp":
                self.method = 'mtp'
            else:
                self.method = "draft_model"

            # Replace hf_config for EAGLE draft_model
            if self.method in ("eagle", "eagle3"):
                if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                    raise ValueError(
                        "Chunked prefill and EAGLE are not compatible "
                        "when using V0.")

                from vllm.platforms import current_platform
                from vllm.transformers_utils.configs.eagle import EAGLEConfig
                if isinstance(self.draft_model_config.hf_config,
                              EAGLEConfig) or current_platform.is_neuron():
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config, method=self.method)
                    self.draft_model_config.hf_config = eagle_config

            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

            n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
                                None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}")

            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config
                )

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))

    if self.acceptance_method == "typical_acceptance_sampler":
        if self.posterior_threshold is None:
            self.posterior_threshold = 0.09
        if self.posterior_alpha is None:
            self.posterior_alpha = 0.3

    self._verify_args()


SpeculativeConfig.__post_init__ = __post_init__
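
To make the num_speculative_tokens / n_predict rule above concrete, here is a small standalone sketch; the function name and the example values are illustrative, not vllm code:

    def check_spec_tokens(num_speculative_tokens: int, n_predict: int) -> None:
        # Mirrors the divisibility rule enforced in __post_init__ above:
        # requesting more tokens than the draft's n_predict is allowed only in
        # whole multiples, so the MTP module is reused an integer number of times.
        if num_speculative_tokens > n_predict and \
                num_speculative_tokens % n_predict != 0:
            raise ValueError(f"num_speculative_tokens:{num_speculative_tokens}"
                             f" must be divisible by n_predict={n_predict}")


    check_spec_tokens(1, 1)  # e.g. a single MTP module with one speculative token
    check_spec_tokens(4, 2)  # ok: the MTP module is reused twice
    # check_spec_tokens(3, 2) would raise: 3 > 2 and 3 % 2 != 0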
