Skip to content

Commit 14117c8

Browse files
committed
fix: update dummy run batch size handling
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent adfa9e6 commit 14117c8

File tree

1 file changed: 8 additions, 3 deletions

vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,14 +1106,19 @@ def _process_reqs(
11061106
if self.dp_size > 1 else num_input_tokens)
11071107
padded_batch_size = self.select_torchair_padded_batch_size(
11081108
max_num_tokens)
1109-
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
1110-
padded_batch_size)
11111109
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
11121110
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
1111+
# If torchair graph is enabled and in decode mode, the dummy run
1112+
# batch size is set to the selected graph size.
1113+
dummy_num_tokens = padded_batch_size
11131114
else:
11141115
# If torchair graph is not enabled, or if with_prefill is True, the
11151116
# dummy run batch size is set to 1.
1116-
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
1117+
dummy_num_tokens = 1
1118+
1119+
if self.dp_size > 1:
1120+
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
1121+
dummy_num_tokens)
11171122

11181123
if self.vllm_config.model_config.use_mla:
11191124
attn_metadata = self.attn_metadata_builder.build( # type: ignore

0 commit comments

Comments (0)