import numpy as np
import numpy.typing as npt
import torch
-import torch.distributed
import torch.nn as nn
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.platforms import current_platform
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)

from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                AscendMetadata)
+from vllm_ascend.platform import NPUPlatform

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
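# Illustrative sketch, not part of the diff: `cdiv` (kept in the trimmed
# vllm.utils import above) is ceiling division, and __init__ below uses it to
# turn max_model_len into a per-request KV-cache block count. The values here
# are assumed examples only; the real helper lives in vllm.utils.
def cdiv(a: int, b: int) -> int:
    # Ceiling division without floats: -(a // -b).
    return -(a // -b)

max_model_len = 4096   # assumed example value
block_size = 16        # assumed example value
assert cdiv(max_model_len, block_size) == 256
assert cdiv(max_model_len + 1, block_size) == 257  # a partial block still costs a whole block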


class NPUModelRunner:

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
-
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        model_config = self.model_config
-        cache_config = self.cache_config
-        scheduler_config = self.scheduler_config
-        parallel_config = self.parallel_config
-
        self.device = device
-        self.pin_memory = is_pin_memory_available()
-        self.dtype = self.model_config.dtype
-
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.is_multimodal_model = model_config.is_multimodal_model
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-        self.max_model_len = model_config.max_model_len
-        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
+        self.is_multimodal_model = self.model_config.is_multimodal_model
+        self.block_size = vllm_config.cache_config.block_size
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = self.scheduler_config.max_num_seqs

        # Model-related.
-        self.num_attn_layers = model_config.get_num_layers_by_block_type(
-            parallel_config, LayerBlockType.attention)
-        self.num_query_heads = model_config.get_num_attention_heads(
-            parallel_config)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        self.head_size = model_config.get_head_size()
-        self.hidden_size = model_config.get_hidden_size()
+        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
+            vllm_config.parallel_config, LayerBlockType.attention)
+        self.hidden_size = self.model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
-        self.uses_mrope = model_config.uses_mrope
+        self.uses_mrope = self.model_config.uses_mrope

-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
+        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
            mm_registry=self.mm_registry)
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

-        # Set up speculative decoding.
-        self.use_spec_decode = False
-
        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
+            max_model_len=self.model_config.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
+            pin_memory=True,
+            vocab_size=self.model_config.get_vocab_size(),
        )

        self.input_ids = torch.zeros(self.max_num_tokens,
@@ -165,46 +131,41 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
-                pin_memory=self.pin_memory)
+                pin_memory=True)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
+            dtype=self.model_config.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np: npt.NDArray[np.int32] = np.arange(max(
-            self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
+            self.max_num_reqs + 1, self.model_config.max_model_len,
+            self.max_num_tokens),
                                                           dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.input_ids_np = self.input_ids_cpu.numpy()
+                                         pin_memory=True)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
-                                         pin_memory=self.pin_memory)
+                                         pin_memory=True)
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
-                                            pin_memory=self.pin_memory)
+                                            pin_memory=True)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()

-        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=self.pin_memory)
-        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
-                                        pin_memory=self.pin_memory)
+                                        pin_memory=True)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.input_positions_cpu = torch.arange(0,
@@ -220,7 +181,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # Therefore, an environment variable is added here to dynamically set
        # the size of the pre-constructed mask matrix based on requirements.
        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        self.attn_mask_len = min(self.max_model_len, int(mask_len))
+        self.attn_mask_len = min(self.model_config.max_model_len,
+                                 int(mask_len))
        self.attn_mask_npu = torch.full(
            (self.attn_mask_len, self.attn_mask_len),
            NPU_PAGED_ATTENTION_MASK_VALUE,
@@ -384,8 +346,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
    def get_model(self) -> nn.Module:
        return self.model

-    def make_attention_mask(self, seq_lens, query_lens,
-                            position) -> torch.Tensor:
+    def _make_attention_mask(self, seq_lens, query_lens,
+                             position) -> torch.Tensor:
        max_seq_len = max(seq_lens, default=0)
        if max_seq_len <= self.attn_mask_len:
            return torch.index_select(self.attn_mask_npu,
@@ -475,9 +437,9 @@ def _process_reqs(
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True)

-        attn_mask = self.make_attention_mask(seq_lens=seq_lens,
-                                             query_lens=num_scheduled_tokens,
-                                             position=positions)
+        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
+                                              query_lens=num_scheduled_tokens,
+                                              position=positions)

        attn_metadata = AscendMetadata(
            seq_lens=query_lens,
@@ -653,22 +615,19 @@ def _profile_multimodal(self) -> None:
        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

    @torch.inference_mode()
-    def _dummy_run(
-        self,
-        num_tokens: int,
-    ) -> torch.Tensor:
+    def _dummy_run(self) -> torch.Tensor:
        model = self.model
        if self.is_multimodal_model:
            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
+            inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
        else:
-            input_ids = self.input_ids[:num_tokens]
+            input_ids = self.input_ids[:self.max_num_tokens]
            inputs_embeds = None

        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
+            positions = self.mrope_positions[:, :self.max_num_tokens]
        else:
-            positions = self.input_positions_cpu[:num_tokens]
+            positions = self.input_positions_cpu[:self.max_num_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
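# Illustrative sketch, not part of the diff: slicing the preallocated buffers, as
# `self.input_ids[:self.max_num_tokens]` does above, yields views that share storage
# with the persistent tensors, so the dummy run allocates no new input memory.
# The sizes below are assumed example values.
import torch

persistent_input_ids = torch.zeros(16, dtype=torch.int32)
max_num_tokens = 8
view = persistent_input_ids[:max_num_tokens]
view[0] = 7
assert persistent_input_ids[0] == 7                        # the slice aliases the buffer
assert view.data_ptr() == persistent_input_ids.data_ptr()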
@@ -680,7 +639,7 @@ def _dummy_run(
                    dtype=self.model_config.dtype,
                    device=self.device))
            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
+                k: v[:self.max_num_tokens]
                for k, v in self.intermediate_tensors.items()
            })

@@ -719,15 +678,15 @@ def profile_run(self) -> None:
        ]

        # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens)
+        hidden_states = self._dummy_run()

        if get_pp_group().is_last_rank:
            hidden_states = hidden_states[logit_indices]
            logits = self.model.compute_logits(hidden_states, None)
        else:
            logits = None

-        current_platform.synchronize()
+        NPUPlatform.synchronize()
        del hidden_states, logits, dummy_kv_caches
        self.encoder_cache.clear()
        gc.collect()
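# Illustrative sketch, not part of the diff: profile_run above synchronizes the
# device before dropping its dummy tensors so queued kernels finish and the
# memory accounting reflects the true peak. The helper below shows the same
# pattern with CUDA as a stand-in for the NPU backend; every name in it is an
# assumption for illustration, not a vllm-ascend API.
import gc

import torch

def profile_peak_memory(run_worst_case):
    if not torch.cuda.is_available():
        return None
    torch.cuda.reset_peak_memory_stats()
    outputs = run_worst_case()      # e.g. a dummy forward at the max token budget
    torch.cuda.synchronize()        # analogous to NPUPlatform.synchronize()
    peak = torch.cuda.max_memory_allocated()
    del outputs
    gc.collect()
    torch.cuda.empty_cache()
    return peak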
@@ -739,10 +698,8 @@ def load_model(self) -> None:
        self.model = get_model(vllm_config=self.vllm_config)
        if self.lora_config:
            raise ValueError("LoRA model is not supported on NPU now.")
-
-        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
-                    self.model_memory_usage / float(2**30))
+                    m.consumed_memory / float(2**30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """