@@ -621,17 +621,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(self, num_tokens: int,
+                                        with_prefill: bool) -> tuple[int, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
@@ -1100,9 +1098,12 @@ def _process_reqs(
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]
 
+        num_tokens_across_dp = None
        if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
            extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
            extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
@@ -1111,6 +1112,8 @@ def _process_reqs(
            if self.dp_size > 1:
                padded_batch_size = self.select_torchair_padded_batch_size(
                    max_num_tokens)
+                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                                  padded_batch_size)
            else:
                padded_batch_size = self.select_torchair_padded_batch_size(
                    total_num_scheduled_tokens)
@@ -1189,7 +1192,8 @@ def _process_reqs(
        # Run forward pass
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
            with ProfileExecuteDuration().capture_async("forward"):
                model_kwargs = {}
                if self.torchair_graph_enabled:
@@ -1806,6 +1810,7 @@ def _dummy_run(
        is_compile: bool = False,
        with_prefill: bool = True,
        skip_attn: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
    ) -> torch.Tensor:
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
@@ -1860,7 +1865,8 @@ def _dummy_run(
 
        with set_forward_context(None,
                                 self.vllm_config,
-                                 num_tokens=num_tokens):
+                                 num_tokens=num_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
            if self.torchair_graph_enabled and not with_prefill:
                attn_metadata = self.attn_metadata_builder.build_dummy(
                    num_reqs=num_tokens, num_actual_tokens=1)
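
For context, the following is a minimal, self-contained sketch of the all_gather exchange that the revised _get_forward_metadata_across_dp performs: each DP rank contributes [num_tokens, with_prefill] and receives the full per-rank table, so every rank can see its peers' token counts instead of only an all-reduced maximum. The sketch uses plain torch.distributed with a single-process gloo group; the helper name and group setup are illustrative only and not the vllm-ascend code path (which goes through get_dp_group() on NPU).

import torch
import torch.distributed as dist


def forward_metadata_across_dp(num_tokens: int, with_prefill: bool,
                               group) -> tuple[torch.Tensor, bool]:
    # Each rank contributes its own [num_tokens, with_prefill] pair.
    local = torch.tensor([num_tokens, int(with_prefill)], dtype=torch.int32)
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=group)
    table = torch.stack(gathered)          # shape [world_size, 2]
    num_tokens_across_dp = table[:, 0]     # per-rank token counts
    any_prefill = bool(table[:, 1].any())  # True if any rank runs prefill
    return num_tokens_across_dp, any_prefill


if __name__ == "__main__":
    # Single-process gloo group, only so the sketch runs standalone.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29507",
                            world_size=1, rank=0)
    print(forward_metadata_across_dp(128, True, dist.group.WORLD))
    dist.destroy_process_group()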