
Commit 4a614a1

[Bugfix] Round graph batch size up to tp size when expert parallel is enabled
Signed-off-by: liziyu <liziyu16@huawei.com>
1 parent 3ea2410 commit 4a614a1

1 file changed: +6 −4 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 4 deletions
@@ -811,8 +811,8 @@ def _process_reqs(
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
-        if (self.use_aclgraph and
-                total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
+        if (self.use_aclgraph and total_num_scheduled_tokens
+                <= self.aclgraph_batch_sizes[-1]):
             # Add padding to the batch size.
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 total_num_scheduled_tokens)
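
The guard only takes the padding path when the scheduled token count fits within the largest captured ACL-graph batch size, so the round-up inside pad_for_cudagraph never has to fall back. A minimal standalone sketch of that round-up-to-captured-size behavior (illustrative names; not the actual vllm_config implementation):

    import bisect

    def pad_to_captured_size(num_tokens: int, captured_sizes: list[int]) -> int:
        # captured_sizes is assumed sorted ascending, e.g. [1, 2, 4, 8].
        # Return the smallest captured batch size that can hold num_tokens.
        idx = bisect.bisect_left(captured_sizes, num_tokens)
        return captured_sizes[idx]

    # Mirrors the guarded call above: the caller has already checked that
    # num_tokens <= captured_sizes[-1], so the lookup cannot run off the end.
    assert pad_to_captured_size(3, [1, 2, 4, 8]) == 4
    assert pad_to_captured_size(8, [1, 2, 4, 8]) == 8
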
@@ -2101,7 +2101,9 @@ def check_torchair_graph_batch_sizes(self):
         if self.parallel_config.enable_expert_parallel:
             new_graph_batch_sizes = []
             for graph_batch_size in self.torchair_graph_batch_sizes:
-                cur_graph_batch_size = graph_batch_size + tp_size - graph_batch_size % tp_size
-                if cur_graph_batch_size not in new_graph_batch_sizes:
+                cur_graph_batch_size = (graph_batch_size + tp_size -
+                                        1) // tp_size * tp_size
+                if cur_graph_batch_size not in new_graph_batch_sizes and \
+                        cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                     new_graph_batch_sizes.append(cur_graph_batch_size)
             self.torchair_graph_batch_sizes = new_graph_batch_sizes
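
The substantive fix is the round-up expression. With the old formula, a graph batch size that was already a multiple of tp_size was still bumped to the next multiple; the ceiling-division form leaves aligned sizes unchanged, and the added condition drops any rounded size that would exceed max_num_batched_tokens. A small self-contained sketch of the two formulas (illustrative, not repo code):

    def round_up_new(batch_size: int, tp_size: int) -> int:
        # Ceiling division, then scale back up: exact multiples stay put.
        return (batch_size + tp_size - 1) // tp_size * tp_size

    def round_up_old(batch_size: int, tp_size: int) -> int:
        # Previous expression: over-rounds values already aligned to tp_size.
        return batch_size + tp_size - batch_size % tp_size

    tp = 4
    assert round_up_new(7, tp) == 8 and round_up_old(7, tp) == 8  # unaligned sizes agree
    assert round_up_new(8, tp) == 8    # aligned size is preserved
    assert round_up_old(8, tp) == 12   # old formula over-padded aligned sizes
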
