
Commit adfa9e6

fix: improve handling of max_num_tokens
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent: 98f9bc9

File tree: 1 file changed (+10, -9 lines)


vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 9 deletions
@@ -1102,17 +1102,18 @@ def _process_reqs(
 
         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
-                                                  padded_batch_size)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+        else:
+            # If torchair graph is not enabled, or if with_prefill is True, the
+            # dummy run batch size is set to 1.
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1)
 
         if self.vllm_config.model_config.use_mla:
             attn_metadata = self.attn_metadata_builder.build(  # type: ignore
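
Context for the change: in torchair graph mode every decode step must run with one of the pre-captured graph batch sizes, and under data parallelism all DP ranks have to pad to the same size, so the padding target is now derived from the global max_num_tokens when dp_size > 1 and from the local num_input_tokens otherwise. Below is a minimal, self-contained sketch of that selection-and-fill logic; the candidate size list and the standalone select_torchair_padded_batch_size helper are illustrative assumptions, not the actual vllm-ascend implementation.

    import torch

    # Assumed capture sizes; in vllm-ascend the real list comes from the
    # torchair graph configuration.
    TORCHAIR_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]

    def select_torchair_padded_batch_size(num_tokens: int) -> int:
        # Smallest pre-captured graph batch size that can hold num_tokens.
        for size in TORCHAIR_BATCH_SIZES:
            if num_tokens <= size:
                return size
        return TORCHAIR_BATCH_SIZES[-1]

    # Hypothetical example: 2 DP ranks; this rank scheduled 3 tokens, the
    # global max across ranks is 7, and -1 marks ranks that will execute a
    # dummy run this step.
    dp_size = 2
    num_input_tokens = 3
    max_num_tokens = 7
    num_tokens_across_dp = torch.tensor([7, -1])

    # The patched logic: pad against the global max under DP, the local
    # token count otherwise, then fill the dummy-run placeholders.
    max_num_tokens = max_num_tokens if dp_size > 1 else num_input_tokens
    padded_batch_size = select_torchair_padded_batch_size(max_num_tokens)
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                      padded_batch_size)

    # Extra slots this rank must pad locally to reach the graph batch size.
    graph_pad_size = padded_batch_size - num_input_tokens

    print(padded_batch_size)              # 8
    print(num_tokens_across_dp.tolist())  # [7, 8]
    print(graph_pad_size)                 # 5

The sketch also mirrors the else branch implicitly: when the torchair path is skipped, the diff simply fills the -1 placeholders with 1 so that idle ranks run a batch-size-1 dummy step.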
