Commit 1cffb52

fix: correct handling of num_tokens for dummy run
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 7f1da12 commit 1cffb52

File tree

1 file changed: +12 −13 lines


vllm_ascend/worker/worker_v1.py

Lines changed: 12 additions & 13 deletions
@@ -277,26 +277,25 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
+        assert runner.dp_size > 1, "Dummy batch execution should only be " \
+            "performed with data parallelism enabled, but got " \
+            f"dp_size={runner.dp_size}."
 
         # If torchair graph is enabled, notify the other DP ranks that this is a
         # dummy run by using '-1' as a flag for num_tokens. This will be
         # replaced with the final determined graph size before the forward pass.
-        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
-                      else 1)
-        num_tokens_across_dp = None
-        with_prefill = False
-
-        if runner.dp_size > 1:
-            num_tokens_across_dp, with_prefill = \
-                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
-            num_tokens = int(num_tokens_across_dp.max().item())
+        num_tokens_across_dp, with_prefill = \
+            runner._get_forward_metadata_across_dp(-1, False)
 
         if runner.torchair_graph_enabled and not with_prefill:
-            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
-            if num_tokens_across_dp is not None:
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  num_tokens)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            num_tokens = runner.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            num_tokens = 1
 
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          num_tokens)
         runner._dummy_run(num_tokens,
                           is_compile=False,
                           num_tokens_across_dp=num_tokens_across_dp,
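
The rewritten flow relies on an in-band sentinel: a rank executing a dummy batch advertises -1 as its token count, every rank derives the same torchair-padded size from the global maximum, and masked_fill_ then overwrites the flags so all DP ranks launch an identically shaped forward pass. (The pre-fix code also read with_prefill before it was assigned, and only exchanged metadata when dp_size > 1; the fix asserts dp_size > 1 up front and always performs the exchange.) Below is a minimal, self-contained sketch of the resolution step only; the capture-size table and select_padded_batch_size helper are illustrative assumptions, not vllm-ascend APIs.

import torch

# Hypothetical torchair graph capture sizes (illustration only).
GRAPH_CAPTURE_SIZES = [1, 4, 8, 16]

def select_padded_batch_size(num_tokens: int) -> int:
    # Round up to the nearest captured graph batch size.
    for size in GRAPH_CAPTURE_SIZES:
        if num_tokens <= size:
            return size
    return GRAPH_CAPTURE_SIZES[-1]

# Suppose dp_size == 4 and ranks 1 and 3 execute dummy batches (report -1).
num_tokens_across_dp = torch.tensor([6, -1, 3, -1])

# All ranks agree on one padded size derived from the global maximum.
num_tokens = select_padded_batch_size(int(num_tokens_across_dp.max().item()))

# Replace the -1 flags so dummy ranks run the same padded shape.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, num_tokens)

print(num_tokens)            # 8
print(num_tokens_across_dp)  # tensor([6, 8, 3, 8])

In eager mode (torchair graph disabled, or a prefill in flight), the else branch simply uses num_tokens = 1, since no shared graph size needs to be agreed on.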
