@@ -247,26 +247,25 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
+        assert runner.dp_size > 1, "Dummy batch execution should only be " \
+            "performed with data parallelism enabled, but got " \
+            f"dp_size={runner.dp_size}."
 
         # If torchair graph is enabled, notify the other DP ranks that this is a
         # dummy run by using '-1' as a flag for num_tokens. This will be
         # replaced with the final determined graph size before the forward pass.
-        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
-                      else 1)
-        num_tokens_across_dp = None
-        with_prefill = False
-
-        if runner.dp_size > 1:
-            num_tokens_across_dp, with_prefill = \
-                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
-            num_tokens = int(num_tokens_across_dp.max().item())
+        num_tokens_across_dp, with_prefill = \
+            runner._get_forward_metadata_across_dp(-1, False)
 
         if runner.torchair_graph_enabled and not with_prefill:
-            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
-            if num_tokens_across_dp is not None:
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  num_tokens)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            num_tokens = runner.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            num_tokens = 1
 
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          num_tokens)
         runner._dummy_run(num_tokens,
                           is_compile=False,
                           num_tokens_across_dp=num_tokens_across_dp,
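For readers of this hunk, the new control flow reduces to a simple pattern: each dummy-run rank advertises `-1` as its token count, the maximum real count across DP ranks is padded up to a supported torchair graph size, and every `-1` entry in `num_tokens_across_dp` is then overwritten in place with that padded value. The sketch below illustrates that pattern in isolation; the graph sizes, the example token counts, and the `pad_to_graph_size` helper (a stand-in for `runner.select_torchair_padded_batch_size`) are illustrative assumptions, not code from this repository.

```python
import torch

# Hypothetical set of padded graph batch sizes; the real values come from
# the torchair graph configuration.
GRAPH_BATCH_SIZES = [1, 4, 8, 16, 32]

def pad_to_graph_size(n: int) -> int:
    # Stand-in for runner.select_torchair_padded_batch_size().
    return next(size for size in GRAPH_BATCH_SIZES if size >= n)

# Token counts gathered across DP ranks; -1 marks ranks doing a dummy run.
num_tokens_across_dp = torch.tensor([7, -1, 3, -1])

# Pad the largest real batch up to a supported graph size...
num_tokens = pad_to_graph_size(int(num_tokens_across_dp.max().item()))

# ...and replace every dummy-run placeholder with that padded size, in place.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, num_tokens)

print(num_tokens)            # 8
print(num_tokens_across_dp)  # tensor([7, 8, 3, 8])
```

Note that the fill only rewrites the `-1` placeholders: ranks that gathered a real batch keep their own token counts, which matches what the unconditional `masked_fill_` in the new code does.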