Skip to content

Commit 876de71

Browse files
committed
fix: update dummy run batch size handling
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 618309a commit 876de71

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,14 +1113,19 @@ def _process_reqs(
                 if self.dp_size > 1 else num_input_tokens)
             padded_batch_size = self.select_torchair_padded_batch_size(
                 max_num_tokens)
-            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+            # If torchair graph is enabled and in decode mode, the dummy run
+            # batch size is set to the selected graph size.
+            dummy_num_tokens = padded_batch_size
         else:
             # If torchair graph is not enabled, or if with_prefill is True, the
             # dummy run batch size is set to 1.
-            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
+            dummy_num_tokens = 1
+
+        if self.dp_size > 1:
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              dummy_num_tokens)

         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore

0 commit comments

Comments
 (0)