Commit 27f8f0f

[1/N][UT][v1 MTP] add basic v1 mtp features
1 parent 01e3d59 commit 27f8f0f

7 files changed: +520 −18 lines changed
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
from __future__ import annotations

import os
import random
from typing import Any

import pytest
from vllm import LLM, SamplingParams

os.environ['VLLM_USE_MODELSCOPE'] = 'True'


@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
    num_prompts = 10
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)


@pytest.fixture
def model_name():
    return "wemaster/deepseek_mtp_main_random_bf16"


def test_mtp_correctness(
        monkeypatch: pytest.MonkeyPatch,
        test_prompts: list[list[dict[str, Any]]],
        sampling_config: SamplingParams,
        model_name: str,
):
    '''
    Compare the outputs of the original LLM and the speculative LLM;
    they should be the same when using MTP speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=256)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

        spec_llm = LLM(model=model_name,
                       trust_remote_code=True,
                       speculative_config={
                           "method": "mtp",
                           "num_speculative_tokens": 1,
                       },
                       max_model_len=256,
                       enforce_eager=False)
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 66% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs))
        del spec_llm
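Note on the acceptance heuristic above: with the 10 prompts produced by the test_prompts fixture and greedy sampling (temperature=0), the assertion requires strictly more than int(0.66 * 10) = 6 exact matches, i.e. at least 7 of the 10 speculative outputs must be identical to the reference. A minimal sketch of that arithmetic, assuming the fixture's default prompt count:

num_prompts = 10                      # from the test_prompts fixture above
threshold = int(0.66 * num_prompts)   # int(6.6) == 6
# `matches > threshold` therefore passes only when matches >= 7
assert threshold == 6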

vllm_ascend/attention/mla_v1.py

Lines changed: 46 additions & 14 deletions
@@ -13,16 +13,20 @@
                                       LinearBase, RowParallelLinear,
                                       UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.utils import vllm_major_version_is, vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch

+if vllm_major_version_is("0.9.0"):
+    from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+

 class AscendMLABackend(AttentionBackend):

@@ -58,6 +62,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -91,6 +96,9 @@ class AscendMLAMetadata:

     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor

     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -232,6 +240,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -243,15 +252,14 @@
             block_table = (self.runner.input_batch.block_table.
                            get_device_tensor()[:num_reqs])
         else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
                                                                                             num_reqs]
@@ -265,6 +273,8 @@
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -275,6 +285,7 @@
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )

         decode_metadata = None
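An illustrative sketch of the re-basing done by prefill_query_start_loc in the two hunks above, with made-up request and token counts (not taken from the diff): subtracting query_start_loc[reqs_start] turns the batch-wide cumulative token offsets into offsets local to the prefill requests.

import torch

# Assumed layout: 2 decode requests (1 token each) followed by 2 prefill
# requests of 4 and 5 tokens; query_start_loc is the cumulative prefix sum.
query_start_loc = torch.tensor([0, 1, 2, 6, 11])
reqs_start = 2  # index of the first prefill request

prefill_query_start_loc = query_start_loc[
    reqs_start:] - query_start_loc[reqs_start]
assert prefill_query_start_loc.tolist() == [0, 4, 9]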
@@ -331,6 +342,9 @@
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )

@@ -380,6 +394,12 @@
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: the padding below should be removed once the kernel is ready.
+        # npu_flash_attention only works when head_dim is divisible by 128, so
+        # we pad to the target size here and slice the final result afterwards.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
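A standalone sketch of the padding_head_dim formula above, using DeepSeek-style MLA head dims (qk_nope_head_dim=128, qk_rope_head_dim=64) purely as assumed example values: the expression rounds the combined query/key head dim up to the next multiple of 128 and leaves exact multiples unchanged.

qk_nope_head_dim = 128  # assumed example values, not taken from the diff
qk_rope_head_dim = 64

padding_head_dim = (
    (qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
assert padding_head_dim == 256              # 192 rounded up to 256
assert ((128 - 1) // 128 + 1) * 128 == 128  # exact multiples stay unchanged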
@@ -477,11 +497,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)

     def _forward_prefill(
         self,
@@ -521,7 +539,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -530,17 +548,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"

vllm_ascend/ops/attention.py

Lines changed: 13 additions & 2 deletions
@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
         device="npu",
         dtype=value.dtype,
     )
+    num_query = torch.sum(q_mask).item()
+    num_add_query = num_query - query.size(0)
+    # MTP will come in here
+    if num_add_query != 0:
+        add_query_size = query.size()
+        add_query_size = list(add_query_size)
+        add_query_size[0] = num_add_query
+        pad_tensor = torch.zeros(add_query_size,
+                                 dtype=query.dtype,
+                                 device=query.device)
+        query = torch.cat([query, pad_tensor], dim=0)
     pad_q[q_mask] = query
     pad_k[kv_c_mask] = key[kv_c_mask]
     pad_v[kv_c_mask] = value[kv_c_mask]
@@ -247,8 +258,8 @@

     attn_output = (attn_output[q_mask].view([-1, num_heads,
                                              v_head_dim]).to(output.dtype))
-    output = output.view_as(attn_output)
-    output.copy_(attn_output)
+    output = output.view([-1, num_heads, v_head_dim])
+    output.copy_(attn_output[:query.size(0) - num_add_query])
     return attn_output

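A self-contained sketch, with invented sizes, of why the zero-row padding above is needed: when the boolean q_mask selects more slots than there are real query rows, for example because MTP contributes extra speculative positions, the masked assignment pad_q[q_mask] = query only lines up if query is first padded to the same row count.

import torch

query = torch.randn(4, 8)                        # 4 real query rows
q_mask = torch.tensor([True] * 6 + [False] * 2)  # but 6 masked slots

num_query = int(torch.sum(q_mask).item())
num_add_query = num_query - query.size(0)        # 2 missing rows
if num_add_query != 0:
    pad_tensor = torch.zeros(num_add_query, query.size(1),
                             dtype=query.dtype, device=query.device)
    query = torch.cat([query, pad_tensor], dim=0)

pad_q = torch.zeros(8, 8)
pad_q[q_mask] = query                            # shapes now match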

vllm_ascend/patch/platform/patch_main/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,4 +13,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+#
