@@ -29,6 +29,8 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+from torch.distributed import ReduceOp
+
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -53,6 +55,7 @@
 
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import get_dp_group
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import vllm_version_is
 
@@ -296,6 +299,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             False) and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = additional_config.get(
             "use_cached_npu_graph", False)
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -595,6 +600,22 @@ def _process_reqs(
                                   device=input_ids.device)
             input_ids = torch.cat([input_ids, padding])
             positions = torch.cat([positions, padding])
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+        if not self.has_prefilled and self.enable_torchair_graph_mode:
+            self.has_prefilled = self.has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
+        if self.dp_group:
+            while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = self.has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)
 
         # Run forward pass
         with set_forward_context(attn_metadata,
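Note on the block above: a decode-only data-parallel rank keeps issuing dummy decode steps until every rank reports that its prefill phase has finished. A minimal standalone sketch of that pattern follows; it is illustrative only (not the vllm-ascend implementation), assumes torch.distributed is already initialized with a CPU-capable backend such as gloo, and run_dummy_step / dp_group are placeholder names.

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

def spin_until_all_ranks_prefilled(run_dummy_step, dp_group):
    # Sketch of the wait loop: run a throwaway decode step so this rank's
    # collectives stay in lockstep with ranks that are still prefilling,
    # mirror the MAX all-reduce issued in the change above, then re-check
    # the shared flag with a MIN all-reduce over {0, 1}.
    done = False
    while not done:
        run_dummy_step()
        keepalive = torch.tensor([1], dtype=torch.int32, device="cpu")
        dist.all_reduce(keepalive, op=ReduceOp.MAX, group=dp_group)
        flag = torch.tensor([1], dtype=torch.int32, device="cpu")
        dist.all_reduce(flag, op=ReduceOp.MIN, group=dp_group)
        done = bool(flag.item())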
@@ -604,7 +625,7 @@ def _process_reqs(
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
                 torch._dynamo.mark_static(input_ids)
                 torch._dynamo.mark_static(positions)
                 torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -633,6 +654,15 @@ def _process_reqs(
 
         return hidden_states[sample_indices]
 
+    def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
+        tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
+        if self.dp_group:
+            torch.distributed.all_reduce(tensor,
+                                         op=ReduceOp.MIN,
+                                         group=self.dp_group)
+        aggregated_has_prefilled = bool(tensor.item())
+        return aggregated_has_prefilled
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
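The helper above is effectively a boolean AND across data-parallel ranks: each rank contributes 0 or 1 and the MIN all-reduce returns 1 only when every contribution was 1, i.e. only when every rank has left its prefill phase. A non-distributed sketch of the same reduction semantics, purely for illustration:

# MIN over values restricted to {0, 1} behaves as a logical AND across
# contributors, which is what the all-reduce above relies on.
reports = [True, True, False]          # e.g. one rank is still prefilling
agreed = bool(min(int(r) for r in reports))
assert agreed is False                 # not every rank has prefilled yet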
@@ -832,7 +862,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    ) -> torch.Tensor:
         model = self.model
         if self.is_multimodal_model:
             input_ids = None
@@ -861,10 +895,32 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
             })
 
         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(input_ids=input_ids,
-                                  positions=positions,
-                                  intermediate_tensors=intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.dummy_build(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
         return hidden_states
 
     def profile_run(self) -> None:
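Background on the graph-mode branch above: roughly speaking, torch._dynamo.mark_static marks a tensor's dimensions as static so the compiled decode graph keeps specializing on those shapes rather than switching to dynamic shapes on a later recompile, and the new attn_state argument lets a caller request a decode-only dummy pass that exercises this graph path. A toy sketch of the mark_static API follows; the Linear layer is only a stand-in for the real model and the example is illustrative, not part of this change.

import torch

# Pin the input's shape as static before running the compiled module, so
# torch.compile keeps specializing on this shape instead of marking the
# dimensions dynamic on a later recompile.
layer = torch.compile(torch.nn.Linear(16, 16))
x = torch.randn(8, 16)
torch._dynamo.mark_static(x)
y = layer(x)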