
Commit cd2f14a

[MTP][V1] Adapt mtp with graph mode in v1. (#1023)

Adapts DeepSeek MTP to the TorchAir graph mode in v1.

Signed-off-by: whx-sjtu <2952154980@qq.com>

1 parent 5ac4872 commit cd2f14a

File tree: 4 files changed, +87 −24 lines

  vllm_ascend/attention/attention_v1.py
  vllm_ascend/attention/mla_v1.py
  vllm_ascend/worker/model_runner_v1.py
  vllm_ascend/worker/mtp_proposer_v1.py

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ class AscendAttentionState(Enum):
     PrefillCacheHit = 1
     DecodeOnly = 2
     ChunkedPrefill = 3
+    SpecDecoding = 4


 @dataclass
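
For reference, a minimal standalone sketch of how the new state is meant to be used; the enum values mirror the diff above (PrefillNoCache = 0 is assumed), and runs_in_graph is an illustrative helper rather than part of the patch:

from enum import Enum


class AscendAttentionState(Enum):
    # Values mirror the enum after this patch; PrefillNoCache = 0 is assumed.
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4


def runs_in_graph(torchair_graph_enabled: bool,
                  attn_state: AscendAttentionState) -> bool:
    # With this patch, both plain decode and speculative decode are eligible
    # for the TorchAir graph path (see the forward() change in mla_v1.py).
    return torchair_graph_enabled and attn_state in (
        AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding)


print(runs_in_graph(True, AscendAttentionState.SpecDecoding))   # True
print(runs_in_graph(True, AscendAttentionState.ChunkedPrefill))  # False
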

vllm_ascend/attention/mla_v1.py

Lines changed: 63 additions & 19 deletions
@@ -8,6 +8,7 @@
                                          AttentionMetadata,
                                          MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)

@@ -86,6 +87,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: Optional[torch.Tensor] = None


 @dataclass

@@ -169,6 +171,8 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:

@@ -185,16 +189,24 @@ def reorder_batch(self, input_batch: "InputBatch",

         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # for now treat 1 scheduled token as "decode" even if its not,
-            # we should update this to something like < 8 in the future but
-            # currently the TritonMLA._forward_decode only supports
-            # num_tokens = 1
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            # For torch air graph mode we treat spec decoding as decode.
+            if self.torchair_graph_enabled:
+                if num_tokens - num_spec_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens
+            # For eager mode we treat spec decoding as chunked prefill.
             else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
+                if num_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens

         # We hope that this is fairly minimal since decodes
         # should be around for a number of iterations so hopefully they are
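
A self-contained sketch of the reordering rule in the hunk above, with plain dicts standing in for vLLM's InputBatch and SchedulerOutput (those stand-ins are assumptions, not the real interfaces):

# Minimal sketch of the decode/prefill split. Dict-based inputs replace the
# real InputBatch / SchedulerOutput objects purely for illustration.
def split_decodes_and_prefills(req_ids, num_scheduled_tokens,
                               scheduled_spec_decode_tokens,
                               torchair_graph_enabled):
    decodes, prefills = [], []
    for i, req_id in enumerate(req_ids):
        num_tokens = num_scheduled_tokens[req_id]
        num_spec = len(scheduled_spec_decode_tokens.get(req_id, []))
        if torchair_graph_enabled:
            # Graph mode: one "real" token plus any number of speculative
            # tokens still counts as decode.
            is_decode = (num_tokens - num_spec == 1)
        else:
            # Eager mode: spec tokens push the request onto the
            # chunked-prefill path, so only a single scheduled token is decode.
            is_decode = (num_tokens == 1)
        (decodes if is_decode else prefills).append(i)
    return decodes, prefills


# Example: request "b" has 2 spec tokens scheduled on top of 1 real token.
print(split_decodes_and_prefills(
    ["a", "b"], {"a": 1, "b": 3}, {"b": [7, 8]}, torchair_graph_enabled=True))
# -> ([0, 1], [])
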
@@ -284,7 +296,8 @@ def build_dummy(self, num_reqs: int,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=1)
+            max_seq_lens=1,
+            attn_mask=self.runner.spec_attn_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,

@@ -332,7 +345,7 @@ def build(
             seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc

         prefill_metadata = None
         if self._num_prefills > 0:

@@ -397,7 +410,8 @@ def build(
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,

@@ -461,6 +475,11 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        if speculative_config is not None:
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0

     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)

@@ -550,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,

@@ -597,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])

@@ -670,9 +692,28 @@ def _forward_decode(
                                   dtype=q.dtype,
                                   device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
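
A shape-only sketch of the SpecDecoding reshape above, using made-up sizes (4 requests, 1 speculative token per request, 16 heads, head dim 64) to show how [num_tokens, heads, dim] becomes TorchAir's BNSD layout with q_seq_len = spec_token_num + 1:

import torch

# Made-up sizes for illustration only.
spec_token_num, num_heads, dim = 1, 16, 64
num_tokens = 4 * (spec_token_num + 1)          # 4 requests x (1 real + 1 spec)
q_nope = torch.randn(num_tokens, num_heads, dim)

# [num_tokens, H, D] -> [bs, q_seq_len, H, D] -> [bs, H, q_seq_len, D],
# matching the BNSD layout expected by the graph-mode attention call.
q_nope_graph = (q_nope.view(num_tokens // (spec_token_num + 1),
                            spec_token_num + 1, num_heads,
                            -1).transpose(1, 2).contiguous())
print(q_nope_graph.shape)  # torch.Size([4, 16, 2, 64])
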
@@ -690,7 +731,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,

@@ -732,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(

vllm_ascend/worker/model_runner_v1.py

Lines changed: 19 additions & 4 deletions
@@ -203,8 +203,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

         # Set up speculative decoding.
         self.use_spec_decode = False
+        self.spec_attn_mask = None
         if self.speculative_config:
             self.use_spec_decode = True
+            self.spec_attn_mask = torch.triu(torch.ones(2048,
+                                                        2048,
+                                                        dtype=torch.bool),
+                                             diagonal=1).to("npu")
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
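
The speculative-decoding mask added here is a plain upper-triangular boolean matrix; a small CPU-only sketch follows (the real code allocates a fixed 2048x2048 mask and moves it to the NPU device):

import torch

# CPU-only illustration of the spec-decoding attention mask; size 4 instead
# of the fixed 2048 used by the runner.
size = 4
spec_attn_mask = torch.triu(torch.ones(size, size, dtype=torch.bool),
                            diagonal=1)
print(spec_attn_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
# True marks positions that are masked out, i.e. each query token may only
# attend to itself and earlier positions.
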
@@ -779,10 +784,13 @@ def _process_reqs(
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
+            num_valid_tokens[i] = num_tokens - \
+                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)

@@ -838,11 +846,16 @@ def _process_reqs(
             out=self.slot_mapping_np[:total_num_scheduled_tokens])

         ascend_config = get_ascend_config()
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # Speculative decoding.
+        elif np.all(num_valid_tokens == 1):
+            attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
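
An illustrative, dependency-light version of the batch-state classification above; the string return values stand in for the AscendAttentionState members, and the example inputs are invented:

import numpy as np

def classify_attn_state(seq_lens, num_scheduled_tokens, num_valid_tokens):
    # Mirrors the branch order above in a standalone form.
    if np.array_equal(seq_lens, num_scheduled_tokens):
        return "PrefillNoCache"
    if np.all(num_scheduled_tokens == 1):
        return "DecodeOnly"
    # New with this patch: every request schedules exactly one non-spec token,
    # so the whole batch is a speculative-decoding step.
    if np.all(num_valid_tokens == 1):
        return "SpecDecoding"
    return "ChunkedPrefill"


# Two requests, each scheduling 1 accepted token plus 2 draft tokens.
print(classify_attn_state(np.array([10, 12]), np.array([3, 3]),
                          np.array([1, 1])))  # SpecDecoding
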
@@ -873,7 +886,9 @@ def _process_reqs(
         seq_lens = self.seq_lens[:num_reqs]
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
-        with_prefill = attn_state != AscendAttentionState.DecodeOnly
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]

         if self.dp_size > 1:
             max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(

@@ -883,14 +898,14 @@ def _process_reqs(
         # Add graph_pad_size here
         if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
                                            and not with_prefill):
-            batch_size = len(seq_lens)
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
-                    batch_size)
-            graph_pad_size = padded_batch_size - batch_size
+                    total_num_scheduled_tokens)
+            graph_pad_size = padded_batch_size - total_num_scheduled_tokens
+
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         if self.vllm_config.model_config.use_mla:
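
The padding change above derives the graph pad from total_num_scheduled_tokens instead of the request count, since a spec-decoding batch carries draft tokens too. A sketch under the assumption that select_torchair_padded_batch_size rounds up to the next configured bucket (the bucket list here is made up):

# Hypothetical round-up rule standing in for select_torchair_padded_batch_size.
def select_torchair_padded_batch_size(num_tokens, buckets=(8, 16, 32, 64)):
    for bucket in buckets:
        if num_tokens <= bucket:
            return bucket
    return buckets[-1]


# With speculative decoding the decode batch carries spec tokens as well, so
# the pad is derived from total_num_scheduled_tokens, not len(seq_lens).
total_num_scheduled_tokens = 12   # e.g. 6 requests x (1 real + 1 spec token)
padded = select_torchair_padded_batch_size(total_num_scheduled_tokens)
graph_pad_size = padded - total_num_scheduled_tokens
print(padded, graph_pad_size)     # 16 4
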

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 4 additions & 1 deletion
@@ -4,7 +4,8 @@
                          set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading, set_default_torch_dtype)
 from vllm.v1.sample.metadata import SamplingMetadata

 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata

@@ -199,6 +200,8 @@ def load_model(self) -> None:
                 loader.get_all_weights(
                     self.vllm_config.speculative_config.draft_model_config,
                     self.model))
+        process_weights_after_loading(self.model, draft_model_config,
+                                      target_device)


 # TODO Using torch instead of triton may result in poor performance
