 import numpy.typing as npt
 import torch
 import torch.nn as nn
+from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+    get_dp_group
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import vllm_version_is
@@ -355,6 +358,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             False) and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = additional_config.get(
             "use_cached_npu_graph", False)
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -659,6 +664,22 @@ def _process_reqs(
                                   device=input_ids.device)
             input_ids = torch.cat([input_ids, padding])
             positions = torch.cat([positions, padding])
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+        if not self.has_prefilled and self.enable_torchair_graph_mode:
+            self.has_prefilled = self.has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
+        if self.dp_group:
+            while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = self.has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)

         # Run forward pass
         with set_forward_context(attn_metadata,
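The hunk above makes a decode-only data-parallel rank spin on 1-token dummy runs until every rank reports that its prefill has finished, so no rank enters the compiled decode graph while others are still mid-prefill. A minimal single-process sketch of that control flow, with hypothetical stand-ins for `self._dummy_run` and the new `has_prefilled_all_rank` helper:

```python
from typing import Callable


def wait_for_all_rank_prefill(
    is_decode_only: bool,
    has_prefilled_all_rank: Callable[[bool], bool],  # stand-in for the new helper
    dummy_run: Callable[[int], None],                # stand-in for self._dummy_run(1)
) -> bool:
    # Mirror of the loop above: a decode-only rank keeps issuing dummy runs so
    # it stays inside the collectives until every rank has prefilled.
    has_prefilled = has_prefilled_all_rank(is_decode_only)
    while not has_prefilled and is_decode_only:
        dummy_run(1)
        has_prefilled = has_prefilled_all_rank(is_decode_only)
    return has_prefilled


if __name__ == "__main__":
    # Single-process stand-ins: pretend the slowest rank finishes after two polls.
    polls = iter([False, False, True])
    print(wait_for_all_rank_prefill(True, lambda _: next(polls), lambda n: None))
```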
@@ -668,7 +689,7 @@ def _process_reqs(
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
                 torch._dynamo.mark_static(input_ids)
                 torch._dynamo.mark_static(positions)
                 torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -796,6 +817,15 @@ def _calc_spec_decode_metadata(
         )
         return metadata

+    def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
+        tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
+        if self.dp_group:
+            torch.distributed.all_reduce(tensor,
+                                         op=ReduceOp.MIN,
+                                         group=self.dp_group)
+        aggregated_has_prefilled = bool(tensor.item())
+        return aggregated_has_prefilled
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
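For reference, a standalone sketch (not part of the change) of why the `ReduceOp.MIN` all-reduce over 0/1 flags in `has_prefilled_all_rank` acts as a logical AND across data-parallel ranks: the reduced value is 1 only when every rank has finished its prefill. The gloo backend, port, and two-rank setup below are illustrative assumptions:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 pretends it is still prefilling; rank 1 is already decode-only.
    has_prefilled = rank != 0
    flag = torch.tensor([int(has_prefilled)], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)

    # Every rank sees 0 here, so no rank enters the graph-mode decode path yet.
    print(f"rank {rank}: all-rank has_prefilled = {bool(flag.item())}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```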
@@ -1063,7 +1093,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
@@ -1107,11 +1141,32 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
             })

         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.dummy_build(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
         return hidden_states

     def profile_run(self) -> None:
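The graph-mode branch above relies on `torch._dynamo.mark_static` so the compiled decode graph is not re-traced on every dummy run. A tiny self-contained illustration of that pattern; the toy function, shapes, and the default inductor CPU backend are assumptions for the example only, not the torchair path used on NPU:

```python
import torch


@torch.compile(dynamic=False)
def decode_step(x: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a decode forward pass that reads from a cache.
    return x + cache[: x.shape[0]]


x = torch.randn(4)
cache = torch.zeros(16)
# Marking the inputs static tells dynamo to specialize on their shapes
# instead of treating them as dynamic and re-tracing later.
torch._dynamo.mark_static(x)
torch._dynamo.mark_static(cache)
print(decode_step(x, cache).shape)
```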