
Commit 351f923

XWFAlone, mengwei805, and JC-ut0 committed

[1/N][UT][v1 MTP] add basic v1 mtp features

Co-authored-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
Signed-off-by: XWFAlone <xuewenfei2@huawei.com>

1 parent e2a0c19, commit 351f923

File tree

8 files changed: +484 -21 lines changed


.github/workflows/vllm_ascend_test_long_term.yaml

Lines changed: 4 additions & 1 deletion
@@ -92,7 +92,10 @@ jobs:
 
       - name: Run vllm-project/vllm-ascend long term test
         run: |
+          if [[ "${{ matrix.vllm_version }}" == "main" ]]; then
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+          fi
           # spec decode test
           VLLM_USE_MODELSCOPE=true pytest -sv tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
           VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
-          pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
+          pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
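For a local run, the new test can be invoked directly with the command used in the step above: VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py. The ModelScope flag points vLLM's model download at ModelScope instead of Hugging Face, and the matrix.vllm_version guard runs the new v1 test only against vLLM main.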

tests/long_term/spec_decode/e2e/conftest.py

Lines changed: 3 additions & 3 deletions
@@ -26,9 +26,9 @@
 from vllm import SamplingParams
 from vllm.sequence import PromptLogprobs, SampleLogprobs
 
-from ....model_utils import (TokensTextLogprobs,
-                             TokensTextLogprobsPromptLogprobs,
-                             check_logprobs_close, check_outputs_equal)
+from tests.model_utils import (TokensTextLogprobs,
+                               TokensTextLogprobsPromptLogprobs,
+                               check_logprobs_close, check_outputs_equal)
 
 PROMPTS = [
     "Hello, my name is",
tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import random
+from typing import Any
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def test_prompts():
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 10
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
+
+
+@pytest.fixture
+def model_name():
+    return "wemaster/deepseek_mtp_main_random_bf16"
+
+
+def test_mtp_correctness(
+        monkeypatch: pytest.MonkeyPatch,
+        test_prompts: list[list[dict[str, Any]]],
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Compare the outputs of an original LLM and a speculative LLM:
+    they should be the same when using MTP speculative decoding.
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        ref_llm = LLM(model=model_name, max_model_len=256)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+
+        spec_llm = LLM(model=model_name,
+                       trust_remote_code=True,
+                       speculative_config={
+                           "method": "deepseek_mtp",
+                           "num_speculative_tokens": 1,
+                       },
+                       max_model_len=256,
+                       enforce_eager=False)
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 66% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.66 * len(ref_outputs))
+        del spec_llm
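Distilled from the test above, the following is a minimal sketch of the API surface this commit exercises: the v1 engine plus a speculative_config that drafts with the checkpoint's own MTP head. Model name, token budget, and settings are copied from the fixtures; this is an illustration, not a tuned configuration.

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"  # the test forces the v1 engine via monkeypatch

llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",  # random-weight MTP test model from the fixture
    trust_remote_code=True,
    speculative_config={
        "method": "deepseek_mtp",      # draft tokens come from the model's MTP head
        "num_speculative_tokens": 1,   # one draft token per step, as in the test
    },
    max_model_len=256,
)

# Same batched chat call the test uses; outputs[i].outputs[0].text holds the reply.
outputs = llm.chat(
    [[{"role": "user", "content": "please repeat the word 'test' 10 times."}]],
    SamplingParams(temperature=0, max_tokens=64, ignore_eos=False),
)
print(outputs[0].outputs[0].text)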

vllm_ascend/attention/mla_v1.py

Lines changed: 30 additions & 6 deletions
@@ -17,13 +17,26 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import vllm_version_is
-from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
+@dataclass
+class CommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: torch.Tensor
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: torch.Tensor
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+
+
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True
@@ -58,6 +71,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -91,6 +105,9 @@ class AscendMLAMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -131,7 +148,7 @@ class AscendMLAMetadataBuilder:
 
     # _attn_mask_builder = None
     def __init__(self,
-                 runner: "NPUModelRunner",
+                 runner,
                  metadata_cls: Optional[AscendMLAMetadata] = None):
         self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
             if metadata_cls is not None else AscendMLAMetadata  # type: ignore
@@ -231,6 +248,7 @@ def build(self,
               num_reqs: int,
               num_actual_tokens: int,
               max_query_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
              common_prefix_len: Optional[int] = None,
              graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
@@ -243,10 +261,8 @@
             block_table = (self.runner.input_batch.block_table.
                            get_device_tensor()[:num_reqs])
         else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
@@ -258,13 +274,17 @@
             seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
+        query_start_loc = None
 
         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            query_start_loc = common_attn_metadata.query_start_loc
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
 
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -275,6 +295,7 @@
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )
 
         decode_metadata = None
@@ -331,6 +352,9 @@
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )
 
 
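The new CommonAttentionMetadata carries query_start_loc and seq_lens into the MLA metadata builder, and the prefill branch re-bases query_start_loc so the first prefill request starts at offset zero. A self-contained toy sketch of that cumulative-offset convention (standalone torch with made-up lengths, not the runner code):

import torch

# Toy batch: 2 decode requests (1 new token each) followed by 2 prefill
# requests with 3 and 5 newly scheduled tokens.
query_lens = torch.tensor([1, 1, 3, 5])

# query_start_loc has shape (batch_size + 1,): cumulative offsets into the
# flattened query tensor, so request i owns query[start[i]:start[i + 1]].
query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.int64)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
print(query_start_loc)           # tensor([ 0,  1,  2,  5, 10])

# The prefill slice drops the decode entries and shifts so the first prefill
# request starts at 0, mirroring
# query_start_loc[reqs_start:] - query_start_loc[reqs_start] in build().
reqs_start = 2                   # _num_decodes in the builder
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
print(prefill_query_start_loc)   # tensor([0, 3, 8])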

vllm_ascend/ops/attention.py

Lines changed: 13 additions & 2 deletions
@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
         device="npu",
         dtype=value.dtype,
     )
+    num_query = torch.sum(q_mask).item()
+    num_add_query = num_query - query.size(0)
+    # with MTP, q_mask can select more positions than query has rows
+    if num_add_query != 0:
+        add_query_size = query.size()
+        add_query_size = list(add_query_size)
+        add_query_size[0] = num_add_query
+        pad_tensor = torch.zeros(add_query_size,
+                                 dtype=query.dtype,
+                                 device=query.device)
+        query = torch.cat([query, pad_tensor], dim=0)
     pad_q[q_mask] = query
     pad_k[kv_c_mask] = key[kv_c_mask]
     pad_v[kv_c_mask] = value[kv_c_mask]
@@ -247,8 +258,8 @@
 
     attn_output = (attn_output[q_mask].view([-1, num_heads,
                                              v_head_dim]).to(output.dtype))
-    output = output.view_as(attn_output)
-    output.copy_(attn_output)
+    output = output.view([-1, num_heads, v_head_dim])
+    output.copy_(attn_output[:query.size(0) - num_add_query])
     return attn_output
 
 
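The padding added above handles the case where MTP's extra speculative tokens make q_mask select more flattened positions than the query tensor has rows: the query is zero-padded to the mask size before being scattered into pad_q, and the trailing padded rows are dropped again when results are copied into output. A toy sketch of that shape bookkeeping (CPU tensors, made-up sizes, simplified to a 1-D mask; not the NPU kernel path):

import torch

num_heads, head_dim = 2, 4

# 6 masked query slots in the padded layout, but only 5 real query rows.
q_mask = torch.tensor([True, True, True, True, True, True])
query = torch.randn(5, num_heads, head_dim)
pad_q = torch.zeros(q_mask.numel(), num_heads, head_dim)

num_query = torch.sum(q_mask).item()        # 6 masked positions
num_add_query = num_query - query.size(0)   # 1 row short
if num_add_query != 0:
    add_query_size = list(query.size())
    add_query_size[0] = num_add_query
    pad_tensor = torch.zeros(add_query_size, dtype=query.dtype)
    query = torch.cat([query, pad_tensor], dim=0)

pad_q[q_mask] = query                       # shapes now agree: 6 rows each

# Only the rows backed by real queries are copied back out, mirroring
# attn_output[:query.size(0) - num_add_query] in the change above.
real_rows = query.size(0) - num_add_query   # 5
print(pad_q.shape, real_rows)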

vllm_ascend/patch/platform/patch_main/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,4 +13,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+#
