Commit 618309a

fix: improve handling of max_num_tokens

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent: fedfde3


vllm_ascend/worker/model_runner_v1.py (10 additions, 9 deletions)
```diff
@@ -1109,17 +1109,18 @@ def _process_reqs(

         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  padded_batch_size)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+        else:
+            # If torchair graph is not enabled, or if with_prefill is True, the
+            # dummy run batch size is set to 1.
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)

         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
```
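The change collapses the two `dp_size` branches: the padded batch size is now always selected from `max_num_tokens` (falling back to `num_input_tokens` when there is no data parallelism), and placeholder entries in `num_tokens_across_dp` (the entries equal to `-1`) are always filled, either with the padded batch size in graph mode or with 1 for the eager/prefill dummy run. Below is a minimal, self-contained sketch of that logic; the bucket list `GRAPH_BATCH_SIZES`, the helper names, and the example values are assumptions for illustration, not the actual vllm-ascend implementation.

```python
# Sketch of the padding logic in the diff above. Bucket sizes, helper
# names, and example values are hypothetical, not the real implementation.
import torch

# Hypothetical stand-in for select_torchair_padded_batch_size: pick the
# smallest pre-captured torchair graph batch size that fits the request.
GRAPH_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]


def select_padded_batch_size(num_tokens: int) -> int:
    for size in GRAPH_BATCH_SIZES:
        if size >= num_tokens:
            return size
    return GRAPH_BATCH_SIZES[-1]


def pad_for_graph(num_tokens_across_dp: torch.Tensor,
                  max_num_tokens: int,
                  num_input_tokens: int,
                  total_num_scheduled_tokens: int,
                  dp_size: int,
                  graph_enabled: bool,
                  with_prefill: bool) -> tuple[torch.Tensor, int]:
    """Simplified mirror of the decode-path padding in _process_reqs."""
    if graph_enabled and not with_prefill:
        # With data parallelism, every rank must replay a graph captured
        # for the same batch size, so pad to the max across ranks.
        num_tokens = max_num_tokens if dp_size > 1 else num_input_tokens
        padded = select_padded_batch_size(num_tokens)
        # Ranks with no real batch are marked -1; give them the padded
        # size so collective ops see a uniform shape.
        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, padded)
        graph_pad_size = padded - total_num_scheduled_tokens
    else:
        # Eager mode (or prefill): idle ranks only need a size-1 dummy run.
        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
        graph_pad_size = 0  # assumption: no graph padding outside graph mode
    return num_tokens_across_dp, graph_pad_size


# Example: dp_size=4, two idle ranks (-1); the busiest rank has 6 tokens.
counts = torch.tensor([6, 3, -1, -1])
counts, pad = pad_for_graph(counts, max_num_tokens=6, num_input_tokens=6,
                            total_num_scheduled_tokens=6, dp_size=4,
                            graph_enabled=True, with_prefill=False)
print(counts)  # tensor([6, 3, 8, 8]) -- idle ranks padded to bucket size 8
print(pad)     # 2 extra padding tokens for this rank
```

The new `else` branch makes the design point explicit: every DP rank must participate in collectives even when it has nothing scheduled, so idle ranks still get a dummy batch, sized to the shared graph bucket in graph mode and to 1 otherwise.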
