
Commit 721b02d

XWFAlone, mengwei805, and JC-ut0 committed

[1/N][UT][v1 MTP] add basic v1 mtp features

Co-authored-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
1 parent 17f05b1 commit 721b02d

File tree: 9 files changed, +927 −14 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 6 additions & 3 deletions

@@ -48,14 +48,14 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-arm64-npu-1, linux-arm64-npu-4]
-        vllm_verison: [main, v0.8.5.post1]
+        vllm_version: [main, v0.8.5.post1]
     concurrency:
       group: >
         ${{
         matrix.os == 'linux-arm64-npu-4'
         && github.event.pull_request.number
         && format('pr-{0}-limit-npu-4', github.event.pull_request.number)
-        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_verison, github.event.pull_request.number)
+        || format('job-{0}-{1}-{2}', matrix.os, matrix.vllm_version, github.event.pull_request.number)
         }}
       cancel-in-progress: false
     name: vLLM Ascend test
@@ -92,7 +92,7 @@ jobs:
         uses: actions/checkout@v4
         with:
          repository: vllm-project/vllm
-          ref: ${{ matrix.vllm_verison }}
+          ref: ${{ matrix.vllm_version }}
          path: ./vllm-empty

      - name: Install vllm-project/vllm from source
@@ -121,6 +121,9 @@ jobs:
           pytest -sv tests/ops
           pytest -sv tests/compile
         fi
+        if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]] && [[ "${{ matrix.vllm_version }}" == "main" ]]; then
+          pytest -sv tests/singlecard/spec_decode/e2e/test_v1_mtp_correctness.py
+        fi

      - name: Run vllm-project/vllm-ascend test on V0 engine
        env:
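
The CI step added above runs the new MTP end-to-end test only on the single-card runner (linux-arm64-npu-1) against vLLM main. For reference, a local run on an Ascend machine with vllm and vllm-ascend installed could use the same test selection through pytest's Python API; this snippet is illustrative and not part of the commit:

import pytest

# Equivalent to the `pytest -sv tests/singlecard/spec_decode/e2e/test_v1_mtp_correctness.py`
# step in the workflow above: -s streams the test's print output, -v is verbose.
pytest.main(["-sv", "tests/singlecard/spec_decode/e2e/test_v1_mtp_correctness.py"])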

tests/singlecard/spec_decode/e2e/test_v1_mtp_correctness.py (new file)

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import os
+import random
+from typing import Any
+
+import pytest
+from vllm import LLM, SamplingParams
+
+os.environ['VLLM_USE_MODELSCOPE'] = 'True'
+
+
+@pytest.fixture
+def test_prompts():
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 10
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
+
+
+@pytest.fixture
+def model_name():
+    return "wemaster/deepseek_mtp_main_random_bf16"
+
+
+def test_mtp_correctness(
+        monkeypatch: pytest.MonkeyPatch,
+        test_prompts: list[list[dict[str, Any]]],
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Compare the outputs of an original LLM and a speculative LLM;
+    they should be the same when using MTP speculative decoding.
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        ref_llm = LLM(model=model_name, max_model_len=256)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+
+        spec_llm = LLM(model=model_name,
+                       trust_remote_code=True,
+                       speculative_config={
+                           "method": "mtp",
+                           "num_speculative_tokens": 1,
+                       },
+                       max_model_len=256,
+                       enforce_eager=False)
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 66% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.66 * len(ref_outputs))
+        del spec_llm
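
The test compares greedy outputs from a reference LLM and an MTP speculative LLM via LLM.chat, accepting the run if at least 66% of prompts match exactly. The same speculative_config can also be exercised in a plain offline-generation script; the sketch below reuses the model name and environment variables from the test, while the prompt and token budget are illustrative and not part of the commit:

# Minimal offline sketch of V1 MTP speculative decoding on the same random-weight
# DeepSeek MTP checkpoint the test uses; assumes an Ascend environment with
# vllm and vllm-ascend installed.
import os

from vllm import LLM, SamplingParams

# Set before engine construction, matching what the test does via monkeypatch.
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_USE_V1"] = "1"

llm = LLM(model="wemaster/deepseek_mtp_main_random_bf16",
          trust_remote_code=True,
          max_model_len=256,
          speculative_config={
              "method": "mtp",
              "num_speculative_tokens": 1,
          })
outputs = llm.generate(["please repeat the word 'hello' 10 times."],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)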

vllm_ascend/attention/mla_v1.py

Lines changed: 40 additions & 9 deletions

@@ -13,6 +13,7 @@
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -57,6 +58,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -90,6 +92,9 @@ class AscendMLAMetadata:

     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor

     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -231,6 +236,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -245,6 +251,7 @@ def build(self,
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
                                                                                             num_reqs]
@@ -258,6 +265,8 @@ def build(self,
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -268,6 +277,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )

         decode_metadata = None
@@ -324,6 +334,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )


@@ -373,6 +386,12 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: below padding should be removed after the kernel is ready
+        # npu_flash_attention only works on a 128-divisible head_dim, so we pad it to the target size here
+        # and slice the final result to guarantee its functionality.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
@@ -470,11 +489,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)

     def _forward_prefill(
         self,
@@ -514,7 +531,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -523,17 +540,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"

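The padding introduced in _forward_prefill works around the kernel constraint noted in the TODO: npu_flash_attention only accepts head dims divisible by 128, so query, key, and value are zero-padded on the last dimension up to padding_head_dim and the attention output is sliced back to v_head_dim afterwards. The following CPU-only sketch walks through that arithmetic; the head-dim values are hypothetical examples, not taken from any model config:

# Illustration of the round-up-to-128 padding and the final slice; runs on CPU
# with plain torch and does not call any NPU kernel.
import torch
import torch.nn.functional as F

qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim        # 192

# Same formula as self.padding_head_dim above: round up to the next multiple of 128.
padding_head_dim = ((qk_head_dim - 1) // 128 + 1) * 128  # 192 -> 256

num_tokens, num_heads = 4, 8
query = torch.randn(num_tokens, num_heads, qk_head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad the last dimension so all inputs share the 128-divisible head dim.
pad_query = F.pad(query, [0, padding_head_dim - qk_head_dim], value=0)
pad_value = F.pad(value, [0, padding_head_dim - v_head_dim], value=0)
assert pad_query.shape[-1] == pad_value.shape[-1] == padding_head_dim

# After attention, slice the padded output back to v_head_dim, mirroring
# attn_output[:, :, :self.v_head_dim] in the diff above.
attn_output = torch.empty(num_tokens, num_heads, padding_head_dim)
attn_output = attn_output[:, :, :v_head_dim]
assert attn_output.shape[-1] == v_head_dim
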
vllm_ascend/ops/attention.py

Lines changed: 13 additions & 2 deletions

@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
         device="npu",
         dtype=value.dtype,
     )
+    num_query = torch.sum(q_mask).item()
+    num_add_query = num_query - query.size(0)
+    # with MTP, q_mask can cover more rows than the query provides; pad query to match
+    if num_add_query != 0:
+        add_query_size = query.size()
+        add_query_size = list(add_query_size)
+        add_query_size[0] = num_add_query
+        pad_tensor = torch.zeros(add_query_size,
+                                 dtype=query.dtype,
+                                 device=query.device)
+        query = torch.cat([query, pad_tensor], dim=0)
     pad_q[q_mask] = query
     pad_k[kv_c_mask] = key[kv_c_mask]
     pad_v[kv_c_mask] = value[kv_c_mask]
@@ -247,8 +258,8 @@ def vanilla_chunked_prefill_mla(

     attn_output = (attn_output[q_mask].view([-1, num_heads,
                                              v_head_dim]).to(output.dtype))
-    output = output.view_as(attn_output)
-    output.copy_(attn_output)
+    output = output.view([-1, num_heads, v_head_dim])
+    output.copy_(attn_output[:query.size(0) - num_add_query])
     return attn_output
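
The change to vanilla_chunked_prefill_mla handles MTP speculative tokens making q_mask select more positions than query has rows: the boolean-mask assignment pad_q[q_mask] = query only works when the right-hand side has exactly q_mask.sum() rows, so the query is zero-padded first and the extra rows are dropped again when copying into output. A small standalone sketch of that shape constraint, using simplified hypothetical shapes (a 1-D mask instead of the real batch/sequence layout):

# Why the zero-padding is needed before `pad_q[q_mask] = query`:
# the masked assignment requires query.size(0) == q_mask.sum().
import torch

num_heads, head_dim = 2, 4
query = torch.randn(3, num_heads, head_dim)             # 3 "real" query rows
q_mask = torch.tensor([True, True, True, True, False])  # mask selects 4 slots

num_query = int(torch.sum(q_mask).item())
num_add_query = num_query - query.size(0)               # 1 missing row
if num_add_query != 0:
    pad_tensor = torch.zeros(num_add_query, num_heads, head_dim,
                             dtype=query.dtype, device=query.device)
    query = torch.cat([query, pad_tensor], dim=0)

pad_q = torch.zeros(q_mask.numel(), num_heads, head_dim, dtype=query.dtype)
pad_q[q_mask] = query                                   # shapes now line up
assert pad_q[q_mask].shape[0] == num_query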

vllm_ascend/patch/platform/patch_main/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -14,3 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import vllm_ascend.patch.platform.patch_main.patch_arg_utils
+import vllm_ascend.patch.platform.patch_main.patch_config
