
Commit e1d282d

[OPEN] [MTP V1] MTP adapt torchair graph mode (#1294)
### What this PR does / why we need it?
1. Add an MTP dummy_run, and adapt the main model's dummy_run when MTP is enabled.
2. Adapt the main model's torchair graph mode when MTP is enabled.
3. Torchair graph mode for the MTP model itself will be supported in the future.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
This patch is tested by `vllm-ascend/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py`.

### Usage
Online example:

```shell
export VLLM_USE_V1=1
export VLLM_ENABLE_MC2=1
export VLLM_VERSION=0.9.1
export ASCEND_LAUNCH_BLOCKING=0
python -m vllm.entrypoints.openai.api_server \
  --model="/model_weight_path" \
  --trust-remote-code \
  --max-model-len 40000 \
  --tensor-parallel-size 4 \
  --data_parallel_size 4 \
  --enable_expert_parallel \
  --served-model-name deepseekr1 \
  --quantization ascend \
  --host 0.0.0.0 \
  --port 1234 \
  --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
  --additional-config '{"ascend_scheduler_config":{"enabled":true,"enable_chunked_prefill":false},"torchair_graph_config":{"enabled":true}}' \
  --gpu_memory_utilization 0.95
```

Offline example:

```python
lm = LLM(
    model="/model_weight_path",
    tensor_parallel_size=16,
    max_num_seqs=128,
    gpu_memory_utilization=0.95,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    enforce_eager=False,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_shared_expert": False,
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
    },
)
```

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent 4f007e8 commit e1d282d

File tree

5 files changed (+313, -67 lines)

tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 59 additions & 0 deletions
```diff
@@ -94,3 +94,62 @@ def test_mtp_correctness(
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
     del spec_llm
+
+
+def test_mtp_torchair_correctness(
+        monkeypatch: pytest.MonkeyPatch,
+        test_prompts: list[list[dict[str, Any]]],
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using mtp speculative decoding.
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=256,
+                      enforce_eager=False,
+                      additional_config={
+                          "torchair_graph_config": {
+                              "enabled": True
+                          },
+                          "ascend_scheduler_config": {
+                              "enabled": True
+                          },
+                      })
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+
+        spec_llm = LLM(model=model_name,
+                       trust_remote_code=True,
+                       enforce_eager=False,
+                       speculative_config={
+                           "method": "deepseek_mtp",
+                           "num_speculative_tokens": 1,
+                       },
+                       additional_config={
+                           "torchair_graph_config": {
+                               "enabled": True
+                           },
+                           "ascend_scheduler_config": {
+                               "enabled": True
+                           },
+                       })
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 66% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.66 * len(ref_outputs))
+        del spec_llm
```
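To exercise just this new case locally, one possible invocation is sketched below; the test node id mirrors the path mentioned in the commit message, and a configured Ascend environment with the long_term test prerequisites (model weights, devices) is assumed.

```python
# Hypothetical helper for running only the new torchair MTP correctness case.
# Assumes the vllm-ascend repo root as the working directory.
import pytest

pytest.main([
    "-s",
    "tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py::test_mtp_torchair_correctness",
])
```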

vllm_ascend/attention/mla_v1.py

Lines changed: 76 additions & 45 deletions
```diff
@@ -104,6 +104,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None


@@ -138,6 +139,7 @@ class AscendMLAMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.

     enable_dbo_across_dp: bool = False
+    is_mtp_model: bool = False

     query_lens: Optional[list[int]] = None
     # The dimension of the attention heads
```
```diff
@@ -313,48 +315,64 @@ def _get_graph_runner_block_tables(
         return graph_block_tables[:num_seqs, :max_blocks]

     def build_torchair_graph_dummy(
-            self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        is_mtp_model: bool = False,
+    ) -> AscendMLAMetadata:
         device = self.runner.device
         _, max_blocks = self.runner.graph_block_tables.shape
         block_table = torch.zeros((num_reqs, max_blocks),
                                   dtype=torch.int32,
                                   device=device)
         block_table = self._get_graph_runner_block_tables(
             num_reqs, block_table)
-        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
-        input_positions = torch.zeros(num_reqs,
+        num_tokens = num_reqs * self.runner.decode_token_per_req
+        seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
+        seq_lens_list = seq_lens.tolist()
+        input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
-        slot_mapping = torch.full((num_reqs, ),
+        slot_mapping = torch.full((num_tokens, ),
                                   PAD_SLOT_ID,
                                   dtype=torch.int32,
                                   device=device)
         query_start_loc = torch.full((num_reqs, ),
                                      -1,
                                      dtype=torch.int32,
                                      device=device)
+        if self.runner.speculative_config is not None and\
+                self.runner.speculative_config.method == 'deepseek_mtp' and not is_mtp_model:
+            attn_state = AscendAttentionState.SpecDecoding
+            num_decode_tokens = 2
+        else:
+            attn_state = AscendAttentionState.DecodeOnly
+            num_decode_tokens = 1
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
-            seq_lens_list=seq_lens.tolist(),
+            seq_lens_list=seq_lens_list,
             max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
+        )
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=1,
-            num_decode_tokens=1,
+            num_decode_tokens=num_decode_tokens,
             num_prefills=0,
             attn_mask=self.runner.attn_mask,
-            attn_state=AscendAttentionState.DecodeOnly,
+            attn_state=attn_state,
             prefill=None,
             decode=decode_metadata,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             block_tables=block_table,
+            is_mtp_model=is_mtp_model,
         )

     def build(
```
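The dummy run now sizes token-indexed buffers per decode token rather than per request. A minimal sketch of that arithmetic, assuming `decode_token_per_req` equals `1 + num_speculative_tokens` (consistent with `num_decode_tokens = 2` above, but not spelled out in this diff):

```python
# Illustrative numbers only; none of these values come from the diff itself.
num_speculative_tokens = 1                           # MTP config used in this PR's examples
decode_token_per_req = 1 + num_speculative_tokens    # assumed: target token + draft tokens

num_reqs = 8                                         # hypothetical padded request count
num_tokens = num_reqs * decode_token_per_req         # 16 slots for input_positions / slot_mapping

# Request-indexed tensors (seq_lens, block_table, query_start_loc) keep num_reqs entries,
# while token-indexed tensors (input_positions, slot_mapping) now need num_tokens entries.
assert (num_reqs, num_tokens) == (8, 16)
```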
```diff
@@ -364,8 +382,10 @@ def build(
         max_query_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         common_prefix_len: Optional[int] = None,
-        graph_pad_size: int = -1,
+        num_token_pad_size: int = -1,
+        num_reqs_pad_size: int = 0,
         enable_dbo_across_dp: bool = False,
+        is_mtp_model: bool = False,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -449,8 +469,9 @@ def build(
         )

         decode_metadata = None
-        use_torchair_graph = graph_pad_size != -1
+        use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
+            actual_seq_q_lens = None
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
```
```diff
@@ -459,41 +480,48 @@ def build(
                     AscendAttentionState.DecodeOnly,
                     AscendAttentionState.SpecDecoding
             ]:
-                num_seqs = len(seq_lens)
-                if graph_pad_size != 0:
-                    pad_value = 1
-                    padded_seq_lens = seq_lens.tolist() + [pad_value
-                                                           ] * graph_pad_size
+                if num_token_pad_size != 0:
+                    pad_value = 0
+                    padded_seq_lens = seq_lens.tolist(
+                    ) + [pad_value] * num_reqs_pad_size
                 else:
                     padded_seq_lens = seq_lens.tolist()

                 seq_lens = torch.from_numpy(
                     np.array(padded_seq_lens).astype(np.int32))
-                padding = torch.full((graph_pad_size, ),
+                seq_lens_list = padded_seq_lens
+                padding = torch.full((num_token_pad_size, ),
                                      PAD_SLOT_ID,
                                      dtype=slot_mapping.dtype,
                                      device=slot_mapping.device)
                 slot_mapping = torch.cat([slot_mapping, padding])
                 block_table_padding = torch.zeros(
-                    (graph_pad_size, ) + block_table.shape[1:],
+                    (num_reqs_pad_size, ) + block_table.shape[1:],
                     dtype=block_table.dtype,
                     device=block_table.device)
                 block_table = torch.cat([block_table, block_table_padding],
                                         dim=0)
                 block_table = self._get_graph_runner_block_tables(
-                    num_seqs + graph_pad_size, block_table)
-                padding_0 = torch.zeros(graph_pad_size,
+                    num_reqs + num_reqs_pad_size, block_table)
+                padding_0 = torch.zeros(num_token_pad_size,
                                         dtype=input_positions.dtype,
                                         device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
+                actual_seq_q_lens = query_start_loc[1:].tolist(
+                ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
+                                                  num_reqs_pad_size]
+            else:
+                seq_lens_list = seq_lens.tolist()

             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
                 block_table=block_table,
                 seq_lens=seq_lens,
-                seq_lens_list=seq_lens.tolist(),
+                seq_lens_list=seq_lens_list,
                 max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                actual_seq_q_lens=actual_seq_q_lens,
+            )

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
```
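The single `graph_pad_size` is split into two pad sizes because, with MTP, requests and tokens no longer pad by the same amount: `num_token_pad_size` pads token-indexed tensors and `num_reqs_pad_size` pads request-indexed ones. A rough sketch with invented sizes (not taken from the diff):

```python
import torch

PAD_SLOT_ID = -1  # stand-in pad value for the illustration

# Hypothetical decode batch: 3 requests x 2 query tokens (MTP, 1 draft token),
# padded to a captured graph size of 4 requests / 8 tokens.
num_reqs, num_tokens = 3, 6
num_reqs_pad_size, num_token_pad_size = 1, 2

# Token-indexed tensor: padded by num_token_pad_size.
slot_mapping = torch.arange(num_tokens, dtype=torch.int32)
slot_mapping = torch.cat([
    slot_mapping,
    torch.full((num_token_pad_size, ), PAD_SLOT_ID, dtype=torch.int32),
])

# Request-indexed tensor: padded by num_reqs_pad_size.
block_table = torch.zeros((num_reqs, 4), dtype=torch.int32)
block_table = torch.cat([
    block_table,
    torch.zeros((num_reqs_pad_size, 4), dtype=torch.int32),
])

print(slot_mapping.shape)  # torch.Size([8])    -> token-level padding
print(block_table.shape)   # torch.Size([4, 4]) -> request-level padding
```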
```diff
@@ -510,7 +538,9 @@ def build(
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            is_mtp_model=is_mtp_model,
+        )


 class AscendMLAImpl(MLAAttentionImpl):
@@ -933,31 +963,10 @@ def _forward_decode(
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
-                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
-                                     self.spec_token_num + 1, self.num_heads,
-                                     -1)
-                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
-                                 self.spec_token_num + 1, self.num_heads, -1)
-                if not self.enable_kv_nz:
-                    q_nope = q_nope.transpose(1, 2).contiguous()
-                    q_pe = q_pe.transpose(1, 2).contiguous()
-                sparse_mode = 3
-                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-            else:
-                if self.enable_kv_nz:
-                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                else:
-                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                sparse_mode = 0
-                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
+            actual_seq_lengths = None
             if self.enable_kv_nz:
                 k_nope = k_nope.view(-1, self.num_kv_heads,
                                      self.kv_lora_rank // 16, block_size, 16)
```
```diff
@@ -971,6 +980,26 @@ def _forward_decode(
                                  self.qk_rope_head_dim)
                 input_layout = "BNSD"

+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                # [bs * q_seq_len, num_heads_per_rank, dim]
+                input_layout = "TND"
+                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+                actual_seq_lengths = decode_meta.actual_seq_q_lens
+            else:
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
+
             attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
                 k_nope,
```
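In spec-decoding graph mode the queries now stay flattened in a TND (token, head, dim) layout, and `actual_seq_q_lens` tells the fused kernel where each request's query tokens end; per the diff it is built from `query_start_loc[1:]`, i.e. cumulative query lengths. A small illustration with invented shapes (the BNSD-to-TND conversion here is for intuition only, not the kernel's internal handling):

```python
import torch

# Hypothetical spec-decoding batch: 3 requests, each with 1 target token
# plus 1 MTP draft token, i.e. 2 query tokens per request.
num_reqs, q_seq_len, num_heads, head_dim = 3, 2, 16, 192
num_tokens = num_reqs * q_seq_len  # 6

# BNSD keeps an explicit [batch, heads, q_seq_len, dim] request dimension ...
q_bnsd = torch.randn(num_reqs, num_heads, q_seq_len, head_dim)

# ... while TND flattens requests and query positions into one token axis.
q_tnd = q_bnsd.transpose(1, 2).reshape(num_tokens, num_heads, head_dim)

# Cumulative query lengths mark where each request ends on that token axis.
actual_seq_q_lens = list(range(q_seq_len, num_tokens + 1, q_seq_len))
print(q_tnd.shape, actual_seq_q_lens)  # torch.Size([6, 16, 192]) [2, 4, 6]
```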
```diff
@@ -988,7 +1017,7 @@ def _forward_decode(
                 block_table=decode_meta.block_table,
                 block_size=block_size,
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
-            )
+                actual_seq_lengths=actual_seq_lengths)
         else:
             # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
             # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
@@ -1042,6 +1071,8 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
+        # mtp model is not support for graph mode yet
+        self.torchair_graph_enabled = self.torchair_graph_enabled and not attn_metadata.is_mtp_model
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
```

vllm_ascend/models/deepseek_v2.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -482,7 +482,8 @@ def forward(
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        if self.torchair_graph_enabled:
+        is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
+        if self.torchair_graph_enabled and not is_mtp_model:
             forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
```
