@@ -628,17 +628,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(self, num_tokens: int,
+                                        with_prefill: bool) -> tuple[torch.Tensor, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
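
Note: the old implementation could only recover the element-wise maximum across ranks (all_reduce with ReduceOp.MAX), while the new all_gather keeps every rank's token count, which the padding logic below needs. A minimal CPU sketch of the gathered layout (not the driver code itself; torch.stack stands in for the collective, and the four rows are made-up ranks):

```python
import torch

# Each DP rank contributes [num_tokens, with_prefill]; the all_gather
# produces one row per rank.
per_rank = [torch.tensor([7, 0]), torch.tensor([3, 1]),
            torch.tensor([12, 0]), torch.tensor([5, 0])]
global_forward_metadata = torch.stack(per_rank)  # shape (dp_size, 2)

num_tokens_across_dp = global_forward_metadata[:, 0]  # tensor([ 7,  3, 12,  5])
with_prefill = bool(global_forward_metadata[:, 1].any())  # True: rank 1 prefills

# The old all_reduce(MAX) path only ever recovered this scalar:
max_num_tokens = int(num_tokens_across_dp.max().item())  # 12
```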
@@ -1107,9 +1105,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
@@ -1118,6 +1119,8 @@ def _process_reqs(
         if self.dp_size > 1:
             padded_batch_size = self.select_torchair_padded_batch_size(
                 max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
         else:
             padded_batch_size = self.select_torchair_padded_batch_size(
                 total_num_scheduled_tokens)
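
Note: the -1 entries replaced here appear to act as a sentinel reported by ranks with no real tokens this step; each one is rewritten in place to the padded torchair batch size so every rank advertises a usable count. A self-contained sketch of the replacement with made-up numbers:

```python
import torch

num_tokens_across_dp = torch.tensor([16, -1, 8, -1], dtype=torch.int32)
padded_batch_size = 16  # as chosen by select_torchair_padded_batch_size

# Rewrite the sentinel entries in place.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                  padded_batch_size)
print(num_tokens_across_dp)  # tensor([16, 16,  8, 16], dtype=torch.int32)
```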
@@ -1196,7 +1199,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
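
Note: forwarding the gathered tensor into set_forward_context lets the forward context build its data-parallel metadata from per-rank counts that were already exchanged above, rather than from a single padded size shared by all ranks; this assumes the installed vLLM's set_forward_context accepts the num_tokens_across_dp keyword, which this diff relies on.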
@@ -1819,6 +1823,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1873,7 +1878,8 @@ def _dummy_run(
 
         with set_forward_context(None,
                                  self.vllm_config,
-                                 num_tokens=num_tokens):
+                                 num_tokens=num_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)
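
Note: a hypothetical call site for the extended _dummy_run (not part of this diff; runner, padded_batch_size, and num_tokens_across_dp are stand-in names):

```python
# An idle DP rank executing a padded dummy step can now forward the
# gathered per-rank counts, so its dummy forward context carries the
# same DP metadata as the ranks doing real work.
runner._dummy_run(num_tokens=padded_batch_size,
                  with_prefill=False,
                  num_tokens_across_dp=num_tokens_across_dp)
```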