Commit 7420aca

fix: improve handling of max_num_tokens
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent f071cc2 commit 7420aca

File tree

1 file changed: +10 -9 lines changed


vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 9 deletions
@@ -1110,17 +1110,18 @@ def _process_reqs(

         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  padded_batch_size)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+        else:
+            # If torchair graph is not enabled, or if with_prefill is True, the
+            # dummy run batch size is set to 1.
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)

         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
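This commit collapses the old two-way DP branch into a single path: the padded batch size is now always derived from one max_num_tokens value (the cross-DP maximum when dp_size > 1, otherwise this rank's own num_input_tokens), the -1 placeholders in num_tokens_across_dp are filled on every path, and ranks that run without the torchair graph fall back to a dummy batch size of 1. The sketch below is a minimal, self-contained illustration of that control flow; the select_padded_batch_size helper and its bucket list are hypothetical stand-ins for select_torchair_padded_batch_size in vllm_ascend/worker/model_runner_v1.py, not the real implementation.

import torch

# Hypothetical graph-size buckets for illustration; the real values come
# from the torchair graph configuration, not from this sketch.
BATCH_BUCKETS = [1, 2, 4, 8, 16, 32, 64, 128]

def select_padded_batch_size(num_tokens: int) -> int:
    """Round num_tokens up to the nearest graph-size bucket."""
    for bucket in BATCH_BUCKETS:
        if bucket >= num_tokens:
            return bucket
    return BATCH_BUCKETS[-1]

def pad_for_torchair_graph(
    torchair_graph_enabled: bool,
    with_prefill: bool,
    dp_size: int,
    max_num_tokens: int,            # max token count across all DP ranks
    num_input_tokens: int,          # this rank's own input size
    total_num_scheduled_tokens: int,
    num_tokens_across_dp: torch.Tensor,  # -1 marks ranks not yet filled in
) -> int:
    """Mirror the new branch: fill placeholders, return graph_pad_size."""
    if torchair_graph_enabled and not with_prefill:
        # With DP > 1, pad to the global maximum so every rank selects the
        # same graph bucket; otherwise pad to this rank's own input size.
        max_num_tokens = (max_num_tokens
                          if dp_size > 1 else num_input_tokens)
        padded_batch_size = select_padded_batch_size(max_num_tokens)
        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                          padded_batch_size)
        return padded_batch_size - total_num_scheduled_tokens
    # Torchair graph disabled, or prefill in the batch: dummy-run ranks
    # use a batch size of 1, and graph padding does not apply.
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
    return 0

# Usage: 4 DP ranks, two of which are idle placeholders (-1).
tokens = torch.tensor([5, 7, -1, -1])
pad = pad_for_torchair_graph(True, False, 4, 7, 5, 5, tokens)
print(tokens)  # tensor([5, 7, 8, 8]) -> idle ranks padded to the 8-bucket
print(pad)     # 3 = padded_batch_size (8) - total_num_scheduled_tokens (5)

In the example, the two idle ranks are padded to the same bucket as the busiest rank, which presumably is what keeps every DP rank replaying an identically sized torchair graph rather than falling outside the captured shape.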
