Commit ca884ef

[Misc] Clean up useless code for LLM initialization (#1373)
This PR cleans up unused code in the LLM setup path to make the code clearer:
1. remove unused `self.xxx` attributes
2. replace `set_random_seed` with `seed_everything`
3. remove `set_custom_all_reduce`, which is only used for CUDA

This is purely a code cleanup; no code logic is changed.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 0060886 commit ca884ef
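
As context for item 2 of the commit message, here is a minimal sketch of what a `seed_everything`-style helper typically does. This is an illustrative assumption, not the actual vLLM/vllm-ascend implementation; the real helper may live elsewhere and may also seed device-specific (e.g. NPU) generators.

import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    # Seed every RNG the runner touches so results are reproducible.
    # Sketch only: the real helper's name, location and extra steps may differ.
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy RNG
    torch.manual_seed(seed)  # PyTorch CPU (and current-device) RNG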

File tree

2 files changed: 67 additions & 121 deletions


vllm_ascend/worker/model_runner_v1.py

Lines changed: 58 additions & 103 deletions
@@ -49,8 +49,7 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv)
+from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -137,82 +136,69 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.lora_config = vllm_config.lora_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.dp_size = vllm_config.parallel_config.data_parallel_size
+        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+        self.device = device
+        self.dtype = self.model_config.dtype
+        self.sampler = Sampler()
+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+            mm_registry=self.mm_registry)
+
+        # Lazy initialization, these will be set after __init__
+        self.kv_caches: List[torch.Tensor] = []
+        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
+        self.attn_mask = None
+        self.attn_state = None
+        self.requests: Dict[str, CachedRequestState] = {}
+        self.intermediate_tensors: Optional[IntermediateTensors] = None
+
         ascend_config = get_ascend_config()
         if ascend_config.ascend_scheduler_config.enabled:
             self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
         else:
             self.chunked_prefill_enabled = True
-        self.device = device

         self.is_multimodal_model = self.model_config.is_multimodal_model
-        self.block_size = vllm_config.cache_config.block_size
-
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           self.block_size)
-        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        if self.is_multimodal_model:
+            self.inputs_embeds = torch.zeros(
+                (self.max_num_tokens, self.model_config.get_hidden_size()),
+                dtype=self.dtype,
+                device=self.device)

         self.graph_block_tables = np.zeros(
-            (self.vllm_config.scheduler_config.max_num_seqs,
+            (self.max_num_reqs,
              (self.model_config.max_model_len + self.block_size - 1) //
              self.block_size),
             dtype=np.int32)

-        # Model-related.
-        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
-            vllm_config.parallel_config, LayerBlockType.attention)
-        self.hidden_size = self.model_config.get_hidden_size()
-        self.dtype = self.model_config.dtype
-        cache_config = vllm_config.cache_config
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
-
-        self.head_size = self.model_config.get_head_size()
+        # Set up Attention
         self.attn_backend = get_attn_backend(
-            self.head_size,
+            0,
             self.dtype,
-            self.kv_cache_dtype,
+            None,
             self.block_size,
             self.model_config.is_attention_free,
             use_mla=self.model_config.use_mla,
         )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 NPUModelRunner.")
-
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))

-        # Multi-modal data support
-        self.input_registry = INPUT_REGISTRY
-        self.mm_registry = MULTIMODAL_REGISTRY
-        self.uses_mrope = self.model_config.uses_mrope
-
-        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
-            model_config=self.model_config,
-            scheduler_config=self.scheduler_config,
-            mm_registry=self.mm_registry)
-
-        # Lazy initialization
-        # self.model: nn.Module # Set after load_model
-        self.kv_caches: List[torch.Tensor] = []
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
-
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
         self.use_spec_decode = False
         self.spec_attn_mask = None
         self.use_eagle = False
+        self.drafter = None
         if self.speculative_config:
             self.use_spec_decode = True
             self.spec_attn_mask = torch.triu(torch.ones(2048,
@@ -235,10 +221,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                  f"{self.speculative_config.method}")
             self.rejection_sampler = AscendRejectionSampler()

-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-        # Persistent batch.
-
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -251,9 +233,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.seq_lens = torch.zeros(self.max_num_reqs,
                                     dtype=torch.int32,
                                     device=self.device)
-        # None in the first PP rank. The rest are set after load_model.
-        self.intermediate_tensors: Optional[IntermediateTensors] = None

+        self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
             # NOTE: `mrope_positions` is implemented with one additional dummy
@@ -276,12 +257,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True)
             self.mrope_positions_np = self.mrope_positions_cpu.numpy()

-        if self.is_multimodal_model:
-            self.inputs_embeds = torch.zeros(
-                (self.max_num_tokens, self.hidden_size),
-                dtype=self.dtype,
-                device=self.device)
-
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np: npt.NDArray[np.int32] = np.arange(max(
             self.max_num_reqs + 1, self.model_config.max_model_len,
@@ -305,24 +280,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                             device="cpu",
                                             pin_memory=True)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-
         self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=True)
         self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-
         self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=True)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

-        self.input_positions_cpu = torch.arange(0,
-                                                self.max_num_tokens,
-                                                device="cpu")
-        self.attn_mask = None
-        self.attn_state = None
         self.use_aclgraph = (self.vllm_config.compilation_config.level
                              == CompilationLevel.PIECEWISE
                              and not self.model_config.enforce_eager)
@@ -339,38 +307,27 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Therefore, an environment variable is added here to dynamically set
         # the size of the pre-constructed mask matrix based on requirements.
         mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
-        self.attn_mask_len = min(self.model_config.max_model_len,
-                                 int(mask_len))
+        attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
         self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-            self.attn_mask_len, self.dtype)
-
-        self.sampler = Sampler()
+            attn_mask_len, self.dtype)

         self.torchair_compiled_model = None  # type: ignore
         self.torchair_compiled_models = {}  # type: ignore
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
         self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
-
         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()
-
         if len(self.torchair_graph_batch_sizes) == 0:
             # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
-            self.torchair_graph_batch_sizes = [
-                self.scheduler_config.max_num_seqs
-            ]
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]

         torch._dynamo.cache_size.config.cache_size_limit += len(
             self.torchair_graph_batch_sizes)
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

-        self.dp_size = vllm_config.parallel_config.data_parallel_size
-        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -1702,8 +1659,7 @@ def _dummy_run(
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
         assert num_tokens <= self.scheduler_config.max_num_batched_tokens
-        max_num_reqs = self.scheduler_config.max_num_seqs
-        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
         min_tokens_per_req = num_tokens // num_reqs
         num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
         num_scheduled_tokens_list[-1] += num_tokens % num_reqs
@@ -1805,14 +1761,13 @@ def profile_run(self) -> None:

         # For profile, have maximum num_reqs and that collectively have
         # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
+        min_tokens_per_req = self.max_num_tokens // self.max_num_reqs

-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
+        num_scheduled_tokens_list[
+            -1] += self.max_num_tokens % self.max_num_reqs
+        assert sum(num_scheduled_tokens_list) == self.max_num_tokens
+        assert len(num_scheduled_tokens_list) == self.max_num_reqs

         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
@@ -1840,15 +1795,14 @@ def load_model(self) -> None:

         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
-            if hasattr(self, "drafter"):
+            if self.drafter:
                 logger.info("Loading drafter model...")
                 if self.use_aux_hidden_state_outputs:
                     self.drafter.load_model(self.model)
+                    self.model.set_aux_hidden_state_layers(
+                        self.model.get_eagle3_aux_hidden_state_layers())
                 else:
                     self.drafter.load_model()
-            if self.use_aux_hidden_state_outputs:
-                self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
             if self.lora_config:
                 self.model = self.load_lora_model(self.model,
                                                   self.model_config,
@@ -1934,7 +1888,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             device=self.device,
             pin_memory=True,
             vocab_size=self.model_config.get_vocab_size(),
-            block_sizes=[self.cache_config.block_size],
+            block_sizes=[self.block_size],
         )

         kv_cache_sizes = {}
@@ -2014,7 +1968,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """

         forward_ctx = self.vllm_config.compilation_config.static_forward_context
-        block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         for layer_name, attn_module in forward_ctx.items():
@@ -2026,7 +1979,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
+                    block_size=self.block_size,
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
@@ -2115,6 +2068,7 @@ def _generate_draft_token_ids(
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
+            assert self.drafter is not None
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx])
             if drafter_output is None or len(drafter_output) == 0:
@@ -2171,6 +2125,7 @@ def _generate_mtp_token_ids(
             dtype=torch.int32,
             device=self.device,
         )
+        assert self.drafter is not None
         cu_num_tokens, token_indices = self.drafter.prepare_inputs(
             attn_metadata.query_start_loc,
             num_rejected_tokens,
@@ -2179,7 +2134,7 @@
         target_positions = positions[token_indices]
         target_hidden_states = hidden_states[token_indices]
         target_slot_mapping = attn_metadata.slot_mapping[token_indices]
-
+        assert self.drafter is not None
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
@@ -2200,7 +2155,7 @@ def init_torchair_graph_batch_sizes(self):
         # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
         start_graph_batch_size = max(start_graph_batch_size, tp_size)

-        while (start_graph_batch_size <= self.scheduler_config.max_num_seqs):
+        while (start_graph_batch_size <= self.max_num_reqs):
            self.torchair_graph_batch_sizes.append(start_graph_batch_size)
            start_graph_batch_size *= 2
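
The drafter-related hunks above replace the old `hasattr(self, "drafter")` probe with an explicit `self.drafter = None` default plus `assert self.drafter is not None` guards before each use. A minimal, generic sketch of that pattern follows; the class and method names here are hypothetical stand-ins, not the actual runner code.

from typing import List, Optional


class Drafter:
    # Hypothetical stand-in for a speculative-decoding draft model.
    def propose(self, token_ids: List[int]) -> List[int]:
        return token_ids[-1:]  # trivially re-propose the last token


class Runner:
    def __init__(self, use_spec_decode: bool) -> None:
        # Always define the attribute so later code can test it directly
        # instead of probing with hasattr(self, "drafter").
        self.drafter: Optional[Drafter] = None
        if use_spec_decode:
            self.drafter = Drafter()

    def generate_draft(self, token_ids: List[int]) -> List[int]:
        # Callers only reach this path when spec decode is enabled; the
        # assert documents that invariant and narrows Optional[Drafter]
        # for type checkers, mirroring the asserts added in the diff.
        assert self.drafter is not None
        return self.drafter.propose(token_ids)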
