@@ -614,17 +614,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(self, num_tokens: int,
+                                        with_prefill: bool) -> tuple[int, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
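
Note on the hunk above: the helper now all-gathers each rank's (num_tokens, with_prefill) pair instead of all-reducing a single max, so callers keep the per-rank token counts. Below is a minimal sketch of that pattern written with plain torch.distributed; the group handling and function name are illustrative assumptions, not vllm-ascend's actual get_dp_group() API.

# Illustrative sketch only: gather per-rank (num_tokens, with_prefill) pairs so
# every DP rank sees the global picture. Names here are assumptions.
import torch
import torch.distributed as dist

def gather_forward_metadata(num_tokens: int, with_prefill: bool,
                            group) -> tuple[torch.Tensor, bool]:
    world_size = dist.get_world_size(group)
    local = torch.tensor([num_tokens, int(with_prefill)], dtype=torch.int32)
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=group)  # one [2]-tensor per DP rank
    stacked = torch.stack(gathered)                # shape: [dp_size, 2]
    num_tokens_across_dp = stacked[:, 0]           # per-rank scheduled-token counts
    any_prefill = bool(stacked[:, 1].any())        # True if any rank is prefilling
    return num_tokens_across_dp, any_prefill
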
@@ -1093,9 +1091,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
@@ -1104,6 +1105,8 @@ def _process_reqs(
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
+                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                                  padded_batch_size)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     total_num_scheduled_tokens)
@@ -1182,7 +1185,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
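
For reference, the masked_fill_ line added above swaps any -1 placeholder entries in the gathered per-rank tensor for the padded graph batch size; the origin of the -1 sentinel lies outside this diff. A tiny, self-contained illustration with invented values:

# Toy illustration of the masked_fill_ step; the values below are made up.
import torch

num_tokens_across_dp = torch.tensor([7, -1, 3, -1], dtype=torch.int32)
padded_batch_size = 8
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, padded_batch_size)
print(num_tokens_across_dp)  # tensor([7, 8, 3, 8], dtype=torch.int32)
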
@@ -1775,6 +1779,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1829,7 +1834,8 @@ def _dummy_run(
 
         with set_forward_context(None,
                                  self.vllm_config,
-                                 num_tokens=num_tokens):
+                                 num_tokens=num_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)