
Commit c7f6584

[V1] clean up V1 code (#505)
Clean up V1 code: (1) remove unused code; (2) format the code to be clearer.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent f6af1d2 commit c7f6584

File tree

2 files changed (+114, -168 lines)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 42 additions & 85 deletions
@@ -24,7 +24,6 @@
 import numpy as np
 import numpy.typing as npt
 import torch
-import torch.distributed
 import torch.nn as nn
 from vllm.attention import AttentionType
 from vllm.attention.layer import Attention
@@ -36,11 +35,9 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -50,6 +47,7 @@
 
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                  AscendMetadata)
+from vllm_ascend.platform import NPUPlatform
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -60,81 +58,49 @@
 class NPUModelRunner:
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
-
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        model_config = self.model_config
-        cache_config = self.cache_config
-        scheduler_config = self.scheduler_config
-        parallel_config = self.parallel_config
-
         self.device = device
-        self.pin_memory = is_pin_memory_available()
-        self.dtype = self.model_config.dtype
-
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.is_multimodal_model = model_config.is_multimodal_model
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-        self.max_model_len = model_config.max_model_len
-        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
+        self.is_multimodal_model = self.model_config.is_multimodal_model
+        self.block_size = vllm_config.cache_config.block_size
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
 
         # Model-related.
-        self.num_attn_layers = model_config.get_num_layers_by_block_type(
-            parallel_config, LayerBlockType.attention)
-        self.num_query_heads = model_config.get_num_attention_heads(
-            parallel_config)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
-        self.head_size = model_config.get_head_size()
-        self.hidden_size = model_config.get_hidden_size()
+        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
+            vllm_config.parallel_config, LayerBlockType.attention)
+        self.hidden_size = self.model_config.get_hidden_size()
 
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
         self.mm_registry = MULTIMODAL_REGISTRY
-        self.uses_mrope = model_config.uses_mrope
+        self.uses_mrope = self.model_config.uses_mrope
 
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
+        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
             mm_registry=self.mm_registry)
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
         self.kv_caches: List[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
 
-        # Set up speculative decoding.
-        self.use_spec_decode = False
-
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
+            max_model_len=self.model_config.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=model_config.get_vocab_size(),
+            pin_memory=True,
+            vocab_size=self.model_config.get_vocab_size(),
         )
 
         self.input_ids = torch.zeros(self.max_num_tokens,
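Note on the derived sizing above: max_num_blocks_per_req is a ceiling division of the model's maximum sequence length by the KV-cache block size, via cdiv from vllm.utils. A minimal sketch of that arithmetic, assuming the usual integer ceiling-division definition (the numbers below are purely illustrative):

def cdiv(a: int, b: int) -> int:
    # Ceiling division using floor division of negatives; equals math.ceil(a / b) for ints.
    return -(a // -b)

# Hypothetical example: a 4096-token context split into 128-token KV-cache blocks.
max_model_len, block_size = 4096, 128
print(cdiv(max_model_len, block_size))  # 32 blocks per request
print(cdiv(4097, block_size))           # 33 -- a partial block still needs a full slot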
@@ -165,46 +131,41 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 (3, self.max_num_tokens + 1),
                 dtype=torch.int64,
                 device="cpu",
-                pin_memory=self.pin_memory)
+                pin_memory=True)
 
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
+            dtype=self.model_config.dtype,
             device=self.device)
 
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np: npt.NDArray[np.int32] = np.arange(max(
-            self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
+            self.max_num_reqs + 1, self.model_config.max_model_len,
+            self.max_num_tokens),
                                                           dtype=np.int32)
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
         # not make any assumptions about the values in these tensors.
         self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                          dtype=torch.int32,
                                          device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.input_ids_np = self.input_ids_cpu.numpy()
+                                         pin_memory=True)
         self.positions_cpu = torch.zeros(self.max_num_tokens,
                                          dtype=torch.int64,
                                          device="cpu",
-                                         pin_memory=self.pin_memory)
+                                         pin_memory=True)
         self.positions_np = self.positions_cpu.numpy()
 
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                             dtype=torch.int32,
                                             device="cpu",
-                                            pin_memory=self.pin_memory)
+                                            pin_memory=True)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=self.pin_memory)
-        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
         self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                         dtype=torch.int32,
                                         device="cpu",
-                                        pin_memory=self.pin_memory)
+                                        pin_memory=True)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
         self.input_positions_cpu = torch.arange(0,
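The persistent CPU buffers in this hunk are now allocated with pin_memory=True unconditionally, replacing the stored is_pin_memory_available() result. A stand-alone sketch of the pattern these buffers support, page-locked host memory feeding asynchronous host-to-device copies; the size, device probing, and guard below are assumptions made for the example, not the runner's code:

import torch

max_num_tokens = 1024
use_npu = hasattr(torch, "npu") and torch.npu.is_available()
has_accel = use_npu or torch.cuda.is_available()
device = "npu" if use_npu else ("cuda" if torch.cuda.is_available() else "cpu")

# Page-locked (pinned) host staging buffer; pinning is only useful when an
# accelerator is present, which is roughly what is_pin_memory_available() checked.
positions_cpu = torch.zeros(max_num_tokens, dtype=torch.int64,
                            device="cpu", pin_memory=has_accel)
positions_np = positions_cpu.numpy()   # zero-copy NumPy view for fast host-side writes
positions_np[:8] = range(8)

# With pinned memory the host-to-device copy can overlap other device work.
positions = positions_cpu[:8].to(device, non_blocking=True)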
@@ -220,7 +181,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Therefore, an environment variable is added here to dynamically set
         # the size of the pre-constructed mask matrix based on requirements.
         mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        self.attn_mask_len = min(self.max_model_len, int(mask_len))
+        self.attn_mask_len = min(self.model_config.max_model_len,
+                                 int(mask_len))
         self.attn_mask_npu = torch.full(
             (self.attn_mask_len, self.attn_mask_len),
             NPU_PAGED_ATTENTION_MASK_VALUE,
@@ -384,8 +346,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def get_model(self) -> nn.Module:
         return self.model
 
-    def make_attention_mask(self, seq_lens, query_lens,
-                            position) -> torch.Tensor:
+    def _make_attention_mask(self, seq_lens, query_lens,
+                             position) -> torch.Tensor:
         max_seq_len = max(seq_lens, default=0)
         if max_seq_len <= self.attn_mask_len:
             return torch.index_select(self.attn_mask_npu,
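The helper becomes the private _make_attention_mask; when the longest sequence fits inside the pre-built mask it simply selects the rows matching the current positions. A minimal stand-alone sketch of that row-selection idea (the mask value, sizes, and names below are illustrative assumptions, not the vllm-ascend constants):

import torch

MASK_VALUE = -10000.0
attn_mask_len = 8

# Pre-built mask: row i carries MASK_VALUE in every column j > i, so a token at
# position i cannot attend to later positions.
attn_mask = torch.triu(
    torch.full((attn_mask_len, attn_mask_len), MASK_VALUE), diagonal=1)

# For the tokens scheduled this step, pick the mask rows for their positions.
position = torch.tensor([2, 3, 4])          # absolute positions of the new tokens
step_mask = torch.index_select(attn_mask, 0, position)
print(step_mask.shape)                      # torch.Size([3, 8])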
@@ -475,9 +437,9 @@ def _process_reqs(
         slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
             self.device, non_blocking=True)
 
-        attn_mask = self.make_attention_mask(seq_lens=seq_lens,
-                                             query_lens=num_scheduled_tokens,
-                                             position=positions)
+        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
+                                              query_lens=num_scheduled_tokens,
+                                              position=positions)
 
         attn_metadata = AscendMetadata(
             seq_lens=query_lens,
@@ -653,22 +615,19 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
     @torch.inference_mode()
-    def _dummy_run(
-        self,
-        num_tokens: int,
-    ) -> torch.Tensor:
+    def _dummy_run(self) -> torch.Tensor:
         model = self.model
         if self.is_multimodal_model:
             input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
+            inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
         else:
-            input_ids = self.input_ids[:num_tokens]
+            input_ids = self.input_ids[:self.max_num_tokens]
             inputs_embeds = None
 
         if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
+            positions = self.mrope_positions[:, :self.max_num_tokens]
         else:
-            positions = self.input_positions_cpu[:num_tokens]
+            positions = self.input_positions_cpu[:self.max_num_tokens]
 
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
@@ -680,7 +639,7 @@ def _dummy_run(
                         dtype=self.model_config.dtype,
                         device=self.device))
             intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
+                k: v[:self.max_num_tokens]
                 for k, v in self.intermediate_tensors.items()
             })
 
@@ -719,15 +678,15 @@ def profile_run(self) -> None:
         ]
 
         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens)
+        hidden_states = self._dummy_run()
 
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
         else:
             logits = None
 
-        current_platform.synchronize()
+        NPUPlatform.synchronize()
         del hidden_states, logits, dummy_kv_caches
         self.encoder_cache.clear()
         gc.collect()
@@ -739,10 +698,8 @@ def load_model(self) -> None:
             self.model = get_model(vllm_config=self.vllm_config)
             if self.lora_config:
                 raise ValueError("LoRA model is not supported on NPU now.")
-
-        self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
-                    self.model_memory_usage / float(2**30))
+                    m.consumed_memory / float(2**30))
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
