@@ -139,6 +139,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
+        self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         ascend_config = get_ascend_config()
@@ -156,12 +157,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs

-        self.graph_block_tables = np.zeros(
-            (self.vllm_config.scheduler_config.max_num_seqs,
-             (self.model_config.max_model_len + self.block_size - 1) //
-             self.block_size),
-            dtype=np.int32)
-
         # Model-related.
         self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
             vllm_config.parallel_config, LayerBlockType.attention)
@@ -355,11 +350,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()

-        if len(self.torchair_graph_batch_sizes) == 0:
-            # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
-            self.torchair_graph_batch_sizes = [
-                self.scheduler_config.max_num_seqs
-            ]
+        self.check_torchair_graph_batch_sizes()
+
+        self.graph_block_tables = np.zeros(
+            (self.torchair_graph_batch_sizes[-1],
+             (self.model_config.max_model_len + self.block_size - 1) //
+             self.block_size),
+            dtype=np.int32)

         torch._dynamo.cache_size.config.cache_size_limit += len(
             self.torchair_graph_batch_sizes)
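The hunk above moves the `graph_block_tables` allocation to after the batch-size check, so its first dimension tracks the largest torchair graph batch size instead of `max_num_seqs`. A minimal sketch of the resulting shape, using made-up values for `max_model_len`, `block_size`, and the bucket list (none of these numbers come from a real config):

```python
import numpy as np

# Assumed, illustrative values.
torchair_graph_batch_sizes = [1, 8, 16, 24]  # already sorted/checked
max_model_len = 4096
block_size = 128

# Rows: largest padded graph batch size; columns: ceil(max_model_len / block_size).
max_graph_batch_size = torchair_graph_batch_sizes[-1]
max_blocks_per_req = (max_model_len + block_size - 1) // block_size

graph_block_tables = np.zeros((max_graph_batch_size, max_blocks_per_req),
                              dtype=np.int32)
print(graph_block_tables.shape)  # (24, 32)
```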
@@ -1707,9 +1704,9 @@ def load_model(self) -> None:
                     m.consumed_memory / float(2**30))

     def _get_torchair_lazy_compiled_model(self, batch_size: int):
-        if batch_size < 0 or batch_size > self.max_num_reqs:
+        if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
             raise ValueError(
-                f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
+                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
             )

         compiled_model = self.torchair_compiled_models.get(
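With the block tables sized by `torchair_graph_batch_sizes[-1]`, the compiled-model lookup now validates against that same bound rather than `max_num_reqs`, presumably because the padded sizes can exceed `max_num_seqs` once they are rounded up for expert parallelism. A toy illustration of the updated guard, with a hypothetical, already padded and sorted bucket list:

```python
torchair_graph_batch_sizes = [4, 8, 16]  # illustrative stand-in

def check_graph_batch_size(batch_size: int) -> None:
    # Mirrors the updated guard: compare against the largest padded bucket,
    # not max_num_reqs.
    if batch_size < 0 or batch_size > torchair_graph_batch_sizes[-1]:
        raise ValueError(
            f"Bad graph batch size:{batch_size}! "
            f"max_graph_batch_sizes:{torchair_graph_batch_sizes[-1]}")

check_graph_batch_size(12)    # fine: 12 <= 16
# check_graph_batch_size(20)  # would raise ValueError
```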
@@ -2075,8 +2072,36 @@ def init_torchair_graph_batch_sizes(self):
             start_graph_batch_size *= 2

     def select_torchair_padded_batch_size(self, batch_size: int):
-        selected_batch_size = self.max_num_reqs
         for padded_batch_size in self.torchair_graph_batch_sizes:
-            if batch_size <= padded_batch_size < selected_batch_size:
-                selected_batch_size = padded_batch_size
-        return selected_batch_size
+            if batch_size <= padded_batch_size:
+                return padded_batch_size
+        raise ValueError(
+            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
+            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
+        )
+
+    def check_torchair_graph_batch_sizes(self):
+        if len(self.torchair_graph_batch_sizes) == 0:
+            self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+        else:
+            self.torchair_graph_batch_sizes = sorted(
+                self.torchair_graph_batch_sizes)
+            while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
+                self.torchair_graph_batch_sizes.pop()
+                if len(self.torchair_graph_batch_sizes) == 0:
+                    logger.warning(
+                        "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
+                    )
+                    self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
+            if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
+                self.torchair_graph_batch_sizes.append(self.max_num_reqs)
+
+        # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
+        tp_size = self.parallel_config.tensor_parallel_size
+        if self.parallel_config.enable_expert_parallel:
+            new_graph_batch_sizes = []
+            for graph_batch_size in self.torchair_graph_batch_sizes:
+                cur_graph_batch_size = graph_batch_size + tp_size - graph_batch_size % tp_size
+                if cur_graph_batch_size not in new_graph_batch_sizes:
+                    new_graph_batch_sizes.append(cur_graph_batch_size)
+            self.torchair_graph_batch_sizes = new_graph_batch_sizes
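Taken together, `check_torchair_graph_batch_sizes` normalizes the configured bucket list (sort, clamp to `max_num_reqs`, and with expert parallel enabled round each bucket up to a multiple of `tp_size`), while `select_torchair_padded_batch_size` returns the first bucket that can hold the current batch. A standalone sketch of the rounding and bucket selection under assumed values (`tp_size=4`, buckets `[1, 8, 16, 24]`); the helper names are illustrative and not part of the runner:

```python
def round_up_to_tp(sizes: list[int], tp_size: int) -> list[int]:
    # Same rounding as the diff: bump each size to the next multiple of tp_size
    # (note that an exact multiple is still bumped by a full tp_size).
    rounded = []
    for s in sizes:
        r = s + tp_size - s % tp_size
        if r not in rounded:
            rounded.append(r)
    return rounded

def select_padded(sizes: list[int], batch_size: int) -> int:
    # First bucket that can hold the batch; sizes are assumed sorted ascending.
    for padded in sizes:
        if batch_size <= padded:
            return padded
    raise ValueError(f"batch_size {batch_size} exceeds largest bucket {sizes[-1]}")

sizes = round_up_to_tp(sorted([1, 8, 16, 24]), tp_size=4)  # -> [4, 12, 20, 28]
print(select_padded(sizes, 10))                            # -> 12
```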