@@ -29,6 +29,7 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -59,6 +60,8 @@
 
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+    get_dp_group
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import vllm_version_is
@@ -328,6 +331,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             False) and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = additional_config.get(
             "use_cached_npu_graph", False)
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -635,6 +640,22 @@ def _process_reqs(
                                   device=input_ids.device)
             input_ids = torch.cat([input_ids, padding])
             positions = torch.cat([positions, padding])
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+        if not self.has_prefilled and self.enable_torchair_graph_mode:
+            self.has_prefilled = self.has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
+        if self.dp_group:
+            while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = self.has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)
 
         # Run forward pass
         with set_forward_context(attn_metadata,
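With torchair graph mode and data parallelism enabled, a rank whose current batch is already decode-only spins here: it issues `_dummy_run(1)` plus a matching all-reduce and rechecks `has_prefilled_all_rank` until every DP rank reports the decode-only state, so the group's collectives stay aligned while other ranks are still prefilling. Below is a minimal standalone simulation of that handshake, not the vllm-ascend code: two gloo CPU ranks, hypothetical names, and the dummy forward pass reduced to a comment.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29561"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 1 pretends to need two more prefill steps than rank 0.
    prefill_steps_left = 0 if rank == 0 else 2
    has_prefilled_all = False
    steps = 0
    while not has_prefilled_all:
        local_decode_only = prefill_steps_left == 0
        if not local_decode_only:
            prefill_steps_left -= 1  # this rank is still prefilling
        # A real runner would issue a dummy forward pass here so graph-mode
        # collectives stay matched; the sync below is the part being modeled.
        flag = torch.tensor([int(local_decode_only)], dtype=torch.int32)
        dist.all_reduce(flag, op=ReduceOp.MIN)  # 1 only if *all* ranks are decode-only
        has_prefilled_all = bool(flag.item())
        steps += 1
    print(f"rank {rank}: all ranks reached decode-only after {steps} steps")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```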
@@ -644,7 +665,7 @@ def _process_reqs(
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
                 torch._dynamo.mark_static(input_ids)
                 torch._dynamo.mark_static(positions)
                 torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -772,6 +793,15 @@ def _calc_spec_decode_metadata(
         )
         return metadata
 
+    def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
+        tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
+        if self.dp_group:
+            torch.distributed.all_reduce(tensor,
+                                         op=ReduceOp.MIN,
+                                         group=self.dp_group)
+        aggregated_has_prefilled = bool(tensor.item())
+        return aggregated_has_prefilled
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
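`has_prefilled_all_rank` encodes the local flag as an int32 and reduces it with `ReduceOp.MIN`, so the aggregated value is 1 only when every rank in the DP group has prefilled; over 0/1 flags, MIN acts as a logical AND and MAX as a logical OR. A tiny illustration of that property, simulating three hypothetical ranks by stacking their contributions instead of running a real process group:

```python
import torch

# Three hypothetical ranks' has_prefilled flags, encoded as int32 exactly as in
# has_prefilled_all_rank; stacking them stands in for an all-reduce's inputs.
per_rank_flags = torch.tensor([[1], [0], [1]], dtype=torch.int32)

min_result = per_rank_flags.min(dim=0).values  # what ReduceOp.MIN leaves on every rank
max_result = per_rank_flags.max(dim=0).values  # what ReduceOp.MAX leaves on every rank

print(bool(min_result.item()))  # False: not every rank has finished prefill
print(bool(max_result.item()))  # True: at least one rank has
```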
@@ -1039,7 +1069,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
@@ -1083,11 +1117,34 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
             })
 
         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.dummy_build(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(
+                    attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv,
+                                      tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
         return hidden_states
 
     def profile_run(self) -> None:
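In the new decode-only branch of `_dummy_run`, the runner builds dummy decode attention metadata, marks the recurring inputs and KV-cache buffers static for Dynamo, and drives the compiled model so graph-mode ranks can take a no-op step. A minimal, self-contained sketch of that `mark_static` plus `torch.compile` pattern; the toy `decode_step` function and shapes are assumptions, not the vllm-ascend model:

```python
import torch
import torch._dynamo


def decode_step(hidden: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    # Stand-in for a decode-only forward pass that reads a persistent KV buffer.
    return hidden + kv_cache.mean(dim=0)


compiled_decode = torch.compile(decode_step, dynamic=False)

kv_cache = torch.zeros(4, 8)         # persistent buffer reused on every step
torch._dynamo.mark_static(kv_cache)  # promise Dynamo this tensor keeps its shape

for _ in range(3):
    hidden = torch.randn(1, 8)
    out = compiled_decode(hidden, kv_cache)

print(out.shape)  # torch.Size([1, 8]); later steps reuse the cached graph
```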