from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv)
+from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
@@ -137,82 +136,69 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.dp_size = vllm_config.parallel_config.data_parallel_size
+        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+        self.device = device
+        self.dtype = self.model_config.dtype
+        self.sampler = Sampler()
+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+            mm_registry=self.mm_registry)
+
+        # Lazy initialization; these will be set after __init__.
+        self.kv_caches: List[torch.Tensor] = []
+        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
+        self.attn_mask = None
+        self.attn_state = None
+        self.requests: Dict[str, CachedRequestState] = {}
+        self.intermediate_tensors: Optional[IntermediateTensors] = None
+
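To get a feel for the derived sizes introduced above, here is a tiny standalone sketch; the numbers are hypothetical, not taken from any real config, and `cdiv` is reproduced as plain ceiling division (the same arithmetic the `(max_model_len + block_size - 1) // block_size` expression further down performs):

    # Hypothetical values, mirroring the attribute math above.
    def cdiv(a: int, b: int) -> int:
        return -(-a // b)  # ceiling division

    max_model_len = 32768   # model context length
    block_size = 128        # tokens per KV-cache block
    max_num_blocks_per_req = cdiv(max_model_len, block_size)
    print(max_num_blocks_per_req)  # 256 blocks per request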
        ascend_config = get_ascend_config()
        if ascend_config.ascend_scheduler_config.enabled:
            self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
        else:
            self.chunked_prefill_enabled = True
-        self.device = device

        self.is_multimodal_model = self.model_config.is_multimodal_model
-        self.block_size = vllm_config.cache_config.block_size
-
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           self.block_size)
-        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        if self.is_multimodal_model:
+            self.inputs_embeds = torch.zeros(
+                (self.max_num_tokens, self.model_config.get_hidden_size()),
+                dtype=self.dtype,
+                device=self.device)

        self.graph_block_tables = np.zeros(
-            (self.vllm_config.scheduler_config.max_num_seqs,
+            (self.max_num_reqs,
             (self.model_config.max_model_len + self.block_size - 1) //
             self.block_size),
            dtype=np.int32)

-        # Model-related.
-        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
-            vllm_config.parallel_config, LayerBlockType.attention)
-        self.hidden_size = self.model_config.get_hidden_size()
-        self.dtype = self.model_config.dtype
-        cache_config = vllm_config.cache_config
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.head_size = self.model_config.get_head_size()
+        # Set up Attention
        self.attn_backend = get_attn_backend(
-            self.head_size,
+            0,
            self.dtype,
-            self.kv_cache_dtype,
+            None,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 NPUModelRunner.")
-
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))
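One note on the `weakref.proxy(self)` handed to the metadata builder: a proxy forwards attribute access without keeping the runner alive, a common way to avoid a reference cycle between a runner and a helper that stores a back-reference. A minimal standalone sketch (class names here are illustrative, not from this codebase):

    import weakref

    class Helper:
        def __init__(self, runner):
            self.runner = runner  # weak proxy: no strong reference kept

    class Runner:
        def __init__(self):
            self.helper = Helper(weakref.proxy(self))

    r = Runner()
    print(r.helper.runner.helper is r.helper)  # True: access is forwarded to the live runner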

-        # Multi-modal data support
-        self.input_registry = INPUT_REGISTRY
-        self.mm_registry = MULTIMODAL_REGISTRY
-        self.uses_mrope = self.model_config.uses_mrope
-
-        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
-            model_config=self.model_config,
-            scheduler_config=self.scheduler_config,
-            mm_registry=self.mm_registry)
-
-        # Lazy initialization
-        # self.model: nn.Module  # Set after load_model
-        self.kv_caches: List[torch.Tensor] = []
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
-
        # Set up speculative decoding.
        self.use_aux_hidden_state_outputs = False
        self.use_spec_decode = False
        self.spec_attn_mask = None
        self.use_eagle = False
+        self.drafter = None
        if self.speculative_config:
            self.use_spec_decode = True
            self.spec_attn_mask = torch.triu(torch.ones(2048,
@@ -235,10 +221,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                 f"{self.speculative_config.method}")
            self.rejection_sampler = AscendRejectionSampler()

-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-        # Persistent batch.
-
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
@@ -251,9 +233,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)
-        # None in the first PP rank. The rest are set after load_model.
-        self.intermediate_tensors: Optional[IntermediateTensors] = None

+        self.uses_mrope = self.model_config.uses_mrope
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
@@ -276,12 +257,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                                    pin_memory=True)
            self.mrope_positions_np = self.mrope_positions_cpu.numpy()

-        if self.is_multimodal_model:
-            self.inputs_embeds = torch.zeros(
-                (self.max_num_tokens, self.hidden_size),
-                dtype=self.dtype,
-                device=self.device)
-
        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np: npt.NDArray[np.int32] = np.arange(max(
            self.max_num_reqs + 1, self.model_config.max_model_len,
@@ -305,24 +280,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                             device="cpu",
                                             pin_memory=True)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=True)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=True)
        self.seq_lens_np = self.seq_lens_cpu.numpy()
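These persistent CPU buffers all follow the same pattern: a pinned-memory `torch` tensor paired with a NumPy view of the same storage, so per-step bookkeeping can happen in NumPy and the result can be copied to the device without an extra staging buffer. A minimal sketch of the idea (size and values are illustrative; pinning is only attempted when an accelerator runtime is available):

    import torch

    max_num_reqs = 8
    pin = torch.cuda.is_available()  # pinning needs an accelerator-enabled build
    seq_lens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32,
                               device="cpu", pin_memory=pin)
    seq_lens_np = seq_lens_cpu.numpy()   # shares memory with the tensor

    seq_lens_np[:3] = [5, 7, 2]          # NumPy-side writes...
    print(seq_lens_cpu[:3])              # ...show up in the tensor: tensor([5, 7, 2], dtype=torch.int32)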

-        self.input_positions_cpu = torch.arange(0,
-                                                self.max_num_tokens,
-                                                device="cpu")
-        self.attn_mask = None
-        self.attn_state = None
        self.use_aclgraph = (self.vllm_config.compilation_config.level
                             == CompilationLevel.PIECEWISE
                             and not self.model_config.enforce_eager)
@@ -339,38 +307,27 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # Therefore, an environment variable is added here to dynamically set
        # the size of the pre-constructed mask matrix based on requirements.
        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        self.attn_mask_len = min(self.model_config.max_model_len,
-                                 int(mask_len))
+        attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
        self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-            self.attn_mask_len, self.dtype)
-
-        self.sampler = Sampler()
+            attn_mask_len, self.dtype)
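The cached mask is, conceptually, a causal mask of side `attn_mask_len` that later steps slice from rather than rebuild. The sketch below only illustrates that idea with `torch.triu`; it is not the actual `AttentionMaskBuilder` implementation, and the boolean fill convention is an assumption:

    import os
    import torch

    # Same clamping as above: the env var caps the pre-built mask size.
    mask_len = int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))
    max_model_len = 4096                      # illustrative
    attn_mask_len = min(max_model_len, mask_len)

    # Assumed convention: True above the diagonal marks disallowed (future) positions.
    full_mask = torch.triu(torch.ones(attn_mask_len, attn_mask_len, dtype=torch.bool),
                           diagonal=1)

    seq_len = 7                               # a short request just slices the cache
    mask = full_mask[:seq_len, :seq_len]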

        self.torchair_compiled_model = None  # type: ignore
        self.torchair_compiled_models = {}  # type: ignore
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
        self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
-
        if ascend_config.torchair_graph_config.graph_batch_sizes_init:
            self.init_torchair_graph_batch_sizes()
-
        if len(self.torchair_graph_batch_sizes) == 0:
            # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
-            self.torchair_graph_batch_sizes = [
-                self.scheduler_config.max_num_seqs
-            ]
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]

        torch._dynamo.cache_size.config.cache_size_limit += len(
            self.torchair_graph_batch_sizes)
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        torch._logging.set_logs(
            recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

-        self.dp_size = vllm_config.parallel_config.data_parallel_size
-        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.
@@ -1702,8 +1659,7 @@ def _dummy_run(
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
-        max_num_reqs = self.scheduler_config.max_num_seqs
-        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
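A quick worked example of the split above, with made-up numbers (the same arithmetic pulled out into a standalone function, not an API of the runner):

    def dummy_split(num_tokens: int, max_num_reqs: int) -> list[int]:
        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
        per_req = num_tokens // num_reqs
        split = [per_req] * num_reqs
        split[-1] += num_tokens % num_reqs
        return split

    print(dummy_split(10, 4))  # [2, 2, 2, 4]: 4 requests covering 10 tokens
    print(dummy_split(3, 4))   # [1, 1, 1]: fewer tokens than max requests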
@@ -1805,14 +1761,13 @@ def profile_run(self) -> None:

        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
+        min_tokens_per_req = self.max_num_tokens // self.max_num_reqs

-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
+        num_scheduled_tokens_list[
+            -1] += self.max_num_tokens % self.max_num_reqs
+        assert sum(num_scheduled_tokens_list) == self.max_num_tokens
+        assert len(num_scheduled_tokens_list) == self.max_num_reqs

        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)
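For profiling, the same split simply runs at the full budget, so memory is measured at the worst case. A short sketch with hypothetical scheduler limits:

    # Hypothetical limits, for illustration only.
    max_num_tokens, max_num_reqs = 8192, 256
    per_req = max_num_tokens // max_num_reqs                  # 32 tokens each
    split = [per_req] * max_num_reqs
    split[-1] += max_num_tokens % max_num_reqs                # remainder is 0 here
    assert sum(split) == max_num_tokens and len(split) == max_num_reqs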
@@ -1840,15 +1795,14 @@ def load_model(self) -> None:

        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)
-            if hasattr(self, "drafter"):
+            if self.drafter:
                logger.info("Loading drafter model...")
                if self.use_aux_hidden_state_outputs:
                    self.drafter.load_model(self.model)
+                    self.model.set_aux_hidden_state_layers(
+                        self.model.get_eagle3_aux_hidden_state_layers())
                else:
                    self.drafter.load_model()
-            if self.use_aux_hidden_state_outputs:
-                self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
@@ -1934,7 +1888,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
            device=self.device,
            pin_memory=True,
            vocab_size=self.model_config.get_vocab_size(),
-            block_sizes=[self.cache_config.block_size],
+            block_sizes=[self.block_size],
        )

        kv_cache_sizes = {}
@@ -2014,7 +1968,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
-        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
@@ -2026,7 +1979,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
+                    block_size=self.block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
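For intuition, the spec above is what the KV-cache manager uses to size memory per block. A rough back-of-the-envelope sketch, assuming the usual full-attention layout of one key plus one value tensor per layer (the exact accounting lives in `FullAttentionSpec`, not here, and all numbers are hypothetical):

    block_size = 128          # tokens per block
    num_kv_heads = 8
    head_size = 128
    dtype_bytes = 2           # e.g. bfloat16

    bytes_per_block = 2 * block_size * num_kv_heads * head_size * dtype_bytes
    print(bytes_per_block)    # 524288 bytes, i.e. 0.5 MiB per layer per block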
@@ -2115,6 +2068,7 @@ def _generate_draft_token_ids(
            start_idx = self.input_batch.num_tokens_no_spec[i]
            end_idx = start_idx + num_sampled_ids
            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
+            assert self.drafter is not None
            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :end_idx])
            if drafter_output is None or len(drafter_output) == 0:
@@ -2171,6 +2125,7 @@ def _generate_mtp_token_ids(
            dtype=torch.int32,
            device=self.device,
        )
+        assert self.drafter is not None
        cu_num_tokens, token_indices = self.drafter.prepare_inputs(
            attn_metadata.query_start_loc,
            num_rejected_tokens,
@@ -2179,7 +2134,7 @@ def _generate_mtp_token_ids(
        target_positions = positions[token_indices]
        target_hidden_states = hidden_states[token_indices]
        target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
+        assert self.drafter is not None
        draft_token_ids = self.drafter.propose(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
@@ -2200,7 +2155,7 @@ def init_torchair_graph_batch_sizes(self):
        # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
        start_graph_batch_size = max(start_graph_batch_size, tp_size)

-        while (start_graph_batch_size <= self.scheduler_config.max_num_seqs):
+        while (start_graph_batch_size <= self.max_num_reqs):
            self.torchair_graph_batch_sizes.append(start_graph_batch_size)
            start_graph_batch_size *= 2
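A small sketch of the batch-size schedule this loop produces, with hypothetical values: start at `tp_size` (the all2all/mc2 note above is why the start is clamped there) and double until `max_num_reqs` is exceeded.

    tp_size, max_num_reqs = 4, 64
    start = max(1, tp_size)   # the real code clamps its starting size to tp_size
    sizes = []
    while start <= max_num_reqs:
        sizes.append(start)
        start *= 2
    print(sizes)  # [4, 8, 16, 32, 64]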