Commit 5caa186

feat: support compile torchair graph while warming up
Signed-off-by: boying <897013703@qq.com>
1 parent a0c3e9b commit 5caa186

4 files changed: +200 -184 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 43 additions & 2 deletions
@@ -101,6 +101,7 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    is_dummy: bool = False
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -225,7 +226,44 @@ def _get_graph_runner_block_tables(
                                max_blocks] = block_tables[:num_seqs, :
                                                           max_blocks]
 
-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def build_dummy(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+            is_dummy=False)
 
     def build(self,
               num_reqs: int,
@@ -307,7 +345,7 @@ def build(self,
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)
@@ -663,6 +701,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
+        if attn_metadata.is_dummy:
+            # Skip dummy run
+            return output
         self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
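Note: the new `build_dummy()` pads every metadata tensor to a fixed shape so that the torchair graph captured during warm-up sees the same decode-time shapes as real requests. Below is a minimal standalone sketch of that padding pattern; the `DummyDecodeMetadata` dataclass, the `build_dummy_decode_metadata` helper, and the `PAD_SLOT_ID` value are hypothetical stand-ins for illustration only, not the vllm-ascend API.

```python
from dataclasses import dataclass

import torch

PAD_SLOT_ID = -1  # hypothetical pad value; mla_v1.py imports the real constant


@dataclass
class DummyDecodeMetadata:
    """Stand-in for AscendMLADecodeMetadata: every field has a fixed, padded shape."""
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    slot_mapping: torch.Tensor
    max_seq_lens: int


def build_dummy_decode_metadata(num_reqs: int, max_blocks: int,
                                device: str = "cpu") -> DummyDecodeMetadata:
    """Mirror the shapes built by build_dummy(): one token per request, a zeroed
    block table, and a fully padded slot mapping."""
    block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32, device=device)
    seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
    input_positions = torch.zeros(num_reqs, dtype=torch.int64, device=device)
    slot_mapping = torch.full((num_reqs, ), PAD_SLOT_ID, dtype=torch.int32, device=device)
    return DummyDecodeMetadata(input_positions=input_positions,
                               block_table=block_table,
                               seq_lens=seq_lens,
                               slot_mapping=slot_mapping,
                               max_seq_lens=1)


if __name__ == "__main__":
    meta = build_dummy_decode_metadata(num_reqs=4, max_blocks=8)
    print(meta.block_table.shape, meta.seq_lens.tolist())
```

Fixed shapes are the key design point: a captured graph cannot re-specialize per request, so the warm-up metadata must already look exactly like worst-case decode metadata.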

vllm_ascend/core/scheduler.py

Lines changed: 5 additions & 139 deletions
@@ -15,20 +15,18 @@
 # This file is a part of the vllm-ascend project.
 #
 from collections import deque
-from typing import Iterable, Optional, Union
+from typing import Iterable, Union
 
 from vllm.config import VllmConfig
 from vllm.logger import logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.utils import cdiv
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.core.sched.utils import check_stop
-from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
 
 
@@ -365,41 +363,22 @@ def finish_requests(
         For example, the API server can abort a request when the client
         disconnects.
         """
-        assert RequestStatus.is_finished(finished_status)
-        if isinstance(request_ids, str):
-            request_ids = (request_ids, )
-        else:
-            request_ids = set(request_ids)
-
         for req_id in request_ids:
             request = self.requests.get(req_id)
             if request is None:
                 # Invalid request ID.
                 continue
-
             if request.status == RequestStatus.RUNNING:
-                self.running.remove(request)
                 self.scheduled_req_ids.discard(request.request_id)
-            else:
-                self.waiting.remove(request)
-            request.status = finished_status
-            self._free_request(request)
+        super().finish_requests(request_ids, finished_status)
 
     def update_from_output(
         self,
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> EngineCoreOutputs:
-        sampled_token_ids = model_runner_output.sampled_token_ids
-        spec_token_ids = model_runner_output.spec_token_ids
-        logprobs = model_runner_output.logprobs
-        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
 
-        new_running: list[Request] = []
-        outputs: list[EngineCoreOutput] = []
-        spec_decoding_stats: Optional[SpecDecodingStats] = None
-
         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
         # loop can be a performance bottleneck. We should do our best to avoid
         # expensive operations inside the loop.
@@ -408,121 +387,8 @@ def update_from_output(
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
-                new_running.append(request)
                 continue
-
-            req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[req_index]
-
-            scheduled_spec_token_ids = (
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
-            if scheduled_spec_token_ids:
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens, where is given by:
-                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
-                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
-                                       len(generated_token_ids))
-                request.num_computed_tokens -= num_tokens_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
-                    spec_decoding_stats,
-                    num_draft_tokens=len(scheduled_spec_token_ids),
-                    num_accepted_tokens=len(generated_token_ids) - 1)
-
-            cached_encoder_input_ids = (
-                self.encoder_cache_manager.get_cached_input_ids(request))
-            # OPTIMIZATION: Avoid list(set) if the set is empty.
-            if cached_encoder_input_ids:
-                for input_id in list(cached_encoder_input_ids):
-                    mm_positions = request.mm_positions[input_id]
-                    start_pos = mm_positions.offset
-                    num_tokens = mm_positions.length
-                    if start_pos + num_tokens <= request.num_computed_tokens:
-                        # The encoder output is already processed and stored
-                        # in the decoder's KV cache.
-                        self.encoder_cache_manager.free_encoder_input(
-                            request, input_id)
-
-            stopped = False
-            new_logprobs = None
-            new_token_ids = generated_token_ids
-
-            # Append generated tokens and check for stop. Note that if
-            # a request is still being prefilled, we expect the model runner
-            # to return empty token ids for the request.
-            for num_new, output_token_id in enumerate(new_token_ids, 1):
-                request.append_output_token_ids(output_token_id)
-
-                # Check for stop and update request state.
-                # This must be called before we make the EngineCoreOutput.
-                stopped = check_stop(request, self.max_model_len)
-                if stopped:
-                    self._free_request(request)
-                    del new_token_ids[num_new:]  # Trim new tokens if needed.
-                    break
-
-            # Extract sample logprobs if needed.
-            if request.sampling_params.logprobs is not None and logprobs:
-                # NOTE: once we support N tokens per step (spec decode),
-                # the outer lists can be of length > 1.
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
-
-            if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
-
-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                if request.use_structured_output:
-                    metadata = request.structured_output_request
-                    assert metadata is not None and metadata.grammar is not None
-                    # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(
-                        spec_token_ids[req_index])
-                else:
-                    request.spec_token_ids = spec_token_ids[req_index]
-
-            # Get prompt logprobs for this request.
-            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids:
-                # Add EngineCoreOutput for this Request.
-                outputs.append(
-                    EngineCoreOutput(
-                        request_id=req_id,
-                        new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
-                        new_logprobs=new_logprobs,
-                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
-                        stop_reason=request.stop_reason,
-                        events=request.take_events()))
-            else:
-                # Invariant: EngineCore returns no partial prefill outputs.
-                assert not prompt_logprobs_tensors
-
             self.scheduled_req_ids.remove(req_id)
-            if not stopped:
-                new_running.append(request)
-
-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-        engine_core_outputs = EngineCoreOutputs(
-            outputs=outputs,
-            scheduler_stats=self.make_stats(spec_decoding_stats),
-        )
-        if self.include_finished_set:
-            #TODO currently sending duplicates here, improve this
-            engine_core_outputs.finished_requests = (
-                scheduler_output.finished_req_ids | self.finished_req_ids)
 
-        return engine_core_outputs
+        return super().update_from_output(scheduler_output,
+                                          model_runner_output)
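The scheduler change above deletes the copied upstream bodies of `finish_requests()` and `update_from_output()` and defers to `super()`, keeping only the Ascend-specific `scheduled_req_ids` bookkeeping. A minimal sketch of that thin-override pattern follows; `BaseScheduler` and `AscendLikeScheduler` are illustrative names for this sketch, not the vllm classes.

```python
class BaseScheduler:
    """Illustrative upstream scheduler: owns the real request lifecycle."""

    def __init__(self) -> None:
        self.requests: dict[str, object] = {}

    def finish_requests(self, request_ids, finished_status) -> None:
        ids = [request_ids] if isinstance(request_ids, str) else list(request_ids)
        for req_id in ids:
            self.requests.pop(req_id, None)


class AscendLikeScheduler(BaseScheduler):
    """Thin override: touch only backend-local state, then defer to super()."""

    def __init__(self) -> None:
        super().__init__()
        self.scheduled_req_ids: set[str] = set()

    def finish_requests(self, request_ids, finished_status) -> None:
        ids = [request_ids] if isinstance(request_ids, str) else list(request_ids)
        for req_id in ids:
            # Drop the backend-specific bookkeeping first ...
            self.scheduled_req_ids.discard(req_id)
        # ... and let the upstream implementation handle status and freeing.
        super().finish_requests(request_ids, finished_status)


if __name__ == "__main__":
    sched = AscendLikeScheduler()
    sched.requests["r0"] = object()
    sched.scheduled_req_ids.add("r0")
    sched.finish_requests("r0", finished_status="aborted")
    assert "r0" not in sched.requests and "r0" not in sched.scheduled_req_ids
```

Delegating to the upstream implementation keeps the subclass from drifting when the base scheduler changes, which is the apparent motivation for dropping the duplicated 139 lines here.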

vllm_ascend/models/deepseek_v2.py

Lines changed: 13 additions & 4 deletions
@@ -36,9 +36,10 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -210,8 +211,12 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         # when profile runs, force experts to load balanced tokens
         # to avoid high memory consumption on a single rank.
         # TODO: need a better flag to indicate whether in profile run or not.
@@ -540,7 +545,11 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)
 
         if isinstance(
                 self.mlp,
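The deepseek_v2.py changes thread `attn_metadata` explicitly through `CustomDeepseekV2MoE.forward()` instead of always reading it from the global forward context, while keeping the context lookup as a fallback. A minimal sketch of that optional-argument-with-context-fallback pattern follows; the `_ForwardContext` stub, `get_forward_context` stub, and `MoELikeLayer` are simplified stand-ins for this sketch, not the real vllm classes.

```python
from typing import Any, Optional

import torch


class _ForwardContext:
    """Simplified stand-in for vllm's per-step forward context."""

    def __init__(self, attn_metadata: Optional[Any] = None) -> None:
        self.attn_metadata = attn_metadata


_CURRENT_CONTEXT = _ForwardContext()


def get_forward_context() -> _ForwardContext:
    # Stub: the real function returns the context set by the model runner.
    return _CURRENT_CONTEXT


class MoELikeLayer(torch.nn.Module):

    def forward(self,
                hidden_states: torch.Tensor,
                attn_metadata: Optional[Any] = None) -> torch.Tensor:
        # Explicit path: the caller passes metadata in (useful when the layer
        # is traced into a compiled graph and globals are awkward to read).
        # Fallback path: behave as before and read the forward context.
        if attn_metadata is None:
            attn_metadata = get_forward_context().attn_metadata
        # ... expert routing would consult attn_metadata here ...
        return hidden_states


if __name__ == "__main__":
    layer = MoELikeLayer()
    x = torch.randn(2, 8)
    print(layer(x, attn_metadata=None).shape)
```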
