@@ -277,26 +277,25 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
+        assert runner.dp_size > 1, "Dummy batch execution should only be " \
+            "performed with data parallelism enabled, but got " \
+            f"dp_size={runner.dp_size}."
 
         # If torchair graph is enabled, notify the other DP ranks that this is a
         # dummy run by using '-1' as a flag for num_tokens. This will be
         # replaced with the final determined graph size before the forward pass.
-        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
-                      else 1)
-        num_tokens_across_dp = None
-        with_prefill = False
-
-        if runner.dp_size > 1:
-            num_tokens_across_dp, with_prefill = \
-                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
-            num_tokens = int(num_tokens_across_dp.max().item())
+        num_tokens_across_dp, with_prefill = \
+            runner._get_forward_metadata_across_dp(-1, False)
 
         if runner.torchair_graph_enabled and not with_prefill:
-            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
-            if num_tokens_across_dp is not None:
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  num_tokens)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            num_tokens = runner.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            num_tokens = 1
 
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          num_tokens)
         runner._dummy_run(num_tokens,
                           is_compile=False,
                           num_tokens_across_dp=num_tokens_across_dp,
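
For reference, here is a minimal, self-contained sketch (plain PyTorch, outside vLLM) of the "-1 flag" convention the patched code relies on: dummy-run ranks report -1 as their token count, the largest real count is padded up to a graph capture size, and masked_fill_ then overwrites the flags with that final size. The per-rank counts and the padded_sizes list are made-up illustration values, and the next()-based rounding is only a stand-in for select_torchair_padded_batch_size, not its actual implementation.

import torch

# Hypothetical counts gathered across a dp_size=4 group: ranks 1 and 3 are
# performing a dummy run and signalled that with -1.
num_tokens_across_dp = torch.tensor([96, -1, 48, -1])

# Stand-in for runner.select_torchair_padded_batch_size(): round the largest
# real count up to the next capture size (sizes assumed for illustration).
padded_sizes = [32, 64, 128, 256]
max_num_tokens = int(num_tokens_across_dp.max().item())
num_tokens = next(s for s in padded_sizes if s >= max_num_tokens)

# Replace the -1 dummy-run flags with the final padded size, mirroring the
# masked_fill_ call above, so every DP rank executes the same graph size.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, num_tokens)
assert num_tokens == 128
assert num_tokens_across_dp.tolist() == [96, 128, 48, 128]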