Skip to content

Commit adf436b

Browse files
authored
[0.9.1][bugfix] fix torchair_graph_batch_sizes bug (#1570)
### What this PR does / why we need it? Fix a graph-mode error that occurs when `enable_expert_parallel` is set and a torchair graph batch size is not divisible by `tp_size`. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent 129a472 commit adf436b

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def _get_graph_runner_block_tables(
290290
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
291291

292292
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
293-
assert max_batch_size >= num_seqs
293+
assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
294294

295295
if isinstance(self.runner.graph_block_tables, np.ndarray):
296296
graph_block_tables = torch.zeros((max_batch_size, max_blocks),

vllm_ascend/worker/model_runner_v1.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
139139
self.model_config = vllm_config.model_config
140140
self.cache_config = vllm_config.cache_config
141141
self.lora_config = vllm_config.lora_config
142+
self.parallel_config = vllm_config.parallel_config
142143
self.scheduler_config = vllm_config.scheduler_config
143144
self.speculative_config = vllm_config.speculative_config
144145
ascend_config = get_ascend_config()
@@ -156,12 +157,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
156157
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
157158
self.max_num_reqs = self.scheduler_config.max_num_seqs
158159

159-
self.graph_block_tables = np.zeros(
160-
(self.vllm_config.scheduler_config.max_num_seqs,
161-
(self.model_config.max_model_len + self.block_size - 1) //
162-
self.block_size),
163-
dtype=np.int32)
164-
165160
# Model-related.
166161
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
167162
vllm_config.parallel_config, LayerBlockType.attention)
@@ -355,11 +350,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
355350
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
356351
self.init_torchair_graph_batch_sizes()
357352

358-
if len(self.torchair_graph_batch_sizes) == 0:
359-
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
360-
self.torchair_graph_batch_sizes = [
361-
self.scheduler_config.max_num_seqs
362-
]
353+
self.check_torchair_graph_batch_sizes()
354+
355+
self.graph_block_tables = np.zeros(
356+
(self.torchair_graph_batch_sizes[-1],
357+
(self.model_config.max_model_len + self.block_size - 1) //
358+
self.block_size),
359+
dtype=np.int32)
363360

364361
torch._dynamo.cache_size.config.cache_size_limit += len(
365362
self.torchair_graph_batch_sizes)
@@ -1707,9 +1704,9 @@ def load_model(self) -> None:
17071704
m.consumed_memory / float(2**30))
17081705

17091706
def _get_torchair_lazy_compiled_model(self, batch_size: int):
1710-
if batch_size < 0 or batch_size > self.max_num_reqs:
1707+
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
17111708
raise ValueError(
1712-
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
1709+
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
17131710
)
17141711

17151712
compiled_model = self.torchair_compiled_models.get(
@@ -2075,8 +2072,36 @@ def init_torchair_graph_batch_sizes(self):
20752072
start_graph_batch_size *= 2
20762073

20772074
def select_torchair_padded_batch_size(self, batch_size: int):
    """Return the graph batch size that ``batch_size`` should be padded to.

    Scans ``self.torchair_graph_batch_sizes`` in list order and returns the
    first entry that is greater than or equal to ``batch_size``.

    Raises:
        ValueError: if no entry is large enough to hold ``batch_size``.
    """
    # Lazily walk the configured sizes; the first fit wins.
    fitting_sizes = (size for size in self.torchair_graph_batch_sizes
                     if size >= batch_size)
    chosen = next(fitting_sizes, None)
    if chosen is None:
        raise ValueError(
            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
        )
    return chosen
2082+
2083+
def check_torchair_graph_batch_sizes(self):
    """Validate and normalize ``self.torchair_graph_batch_sizes`` in place.

    After this call the list is sorted ascending, contains no entry larger
    than ``self.max_num_reqs`` (before expert-parallel rounding), and its
    largest pre-rounding entry equals ``self.max_num_reqs``.

    * An empty list is replaced by ``[1, self.max_num_reqs]``.
    * Entries above ``self.max_num_reqs`` are dropped; if nothing is left,
      a warning is logged and the default ``[1, self.max_num_reqs]`` is used.
    * ``self.max_num_reqs`` is appended when it is not already the maximum.
    * When ``self.parallel_config.enable_expert_parallel`` is set, every
      size is rounded up to the nearest multiple of
      ``self.parallel_config.tensor_parallel_size`` and duplicates are
      removed.
    """
    max_num_reqs = self.max_num_reqs

    if not self.torchair_graph_batch_sizes:
        self.torchair_graph_batch_sizes = [1, max_num_reqs]
    else:
        # Keep only the sizes that fit within the request limit, ascending.
        kept_sizes = sorted(size for size in self.torchair_graph_batch_sizes
                            if size <= max_num_reqs)
        if not kept_sizes:
            logger.warning(
                "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
            )
            kept_sizes = [1, max_num_reqs]
        # The largest graph must be able to hold a full batch of requests.
        if kept_sizes[-1] < max_num_reqs:
            kept_sizes.append(max_num_reqs)
        self.torchair_graph_batch_sizes = kept_sizes

    # NOTE: when enable_expert_parallel, every `graph_batch_size` must be
    # divisible by `tp_size`.
    if self.parallel_config.enable_expert_parallel:
        tp_size = self.parallel_config.tensor_parallel_size
        # Ceiling-round to a multiple of tp_size. Sizes that are already
        # divisible stay unchanged (the previous formula
        # `size + tp_size - size % tp_size` bumped them by a full tp_size,
        # e.g. 8 -> 12 for tp_size=4, or every size +1 for tp_size=1).
        rounded_sizes = ((size + tp_size - 1) // tp_size * tp_size
                         for size in self.torchair_graph_batch_sizes)
        # dict.fromkeys de-duplicates while preserving ascending order.
        self.torchair_graph_batch_sizes = list(dict.fromkeys(rounded_sizes))

0 commit comments

Comments
 (0)