Skip to content

Commit d81acdf

Browse files
committed
fix: correct handling of num_tokens for dummy run
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent b07c798 commit d81acdf

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

vllm_ascend/worker/worker_v1.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -247,26 +247,25 @@ def pin_lora(self, lora_id: int) -> bool:
247247

248248
def execute_dummy_batch(self) -> None:
249249
runner = self.model_runner
250+
assert runner.dp_size > 1, "Dummy batch execution should only be " \
251+
"performed with data parallelism enabled, but got " \
252+
f"dp_size={runner.dp_size}."
250253

251254
# If torchair graph is enabled, notify the other DP ranks that this is a
252255
# dummy run by using '-1' as a flag for num_tokens. This will be
253256
# replaced with the final determined graph size before the forward pass.
254-
num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
255-
else 1)
256-
num_tokens_across_dp = None
257-
with_prefill = False
258-
259-
if runner.dp_size > 1:
260-
num_tokens_across_dp, with_prefill = \
261-
runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
262-
num_tokens = int(num_tokens_across_dp.max().item())
257+
num_tokens_across_dp, with_prefill = \
258+
runner._get_forward_metadata_across_dp(-1, False)
263259

264260
if runner.torchair_graph_enabled and not with_prefill:
265-
num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
266-
if num_tokens_across_dp is not None:
267-
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
268-
num_tokens)
261+
max_num_tokens = int(num_tokens_across_dp.max().item())
262+
num_tokens = runner.select_torchair_padded_batch_size(
263+
max_num_tokens)
264+
else:
265+
num_tokens = 1
269266

267+
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
268+
num_tokens)
270269
runner._dummy_run(num_tokens,
271270
is_compile=False,
272271
num_tokens_across_dp=num_tokens_across_dp,

0 commit comments

Comments
 (0)