From 53679a8a627a5b29b1691add2c7aae291bb47e78 Mon Sep 17 00:00:00 2001 From: boying <897013703@qq.com> Date: Tue, 13 May 2025 16:37:42 +0000 Subject: [PATCH] feat: support compile torchair graph while warming up Signed-off-by: boying <897013703@qq.com> --- .github/workflows/vllm_ascend_test.yaml | 6 +- tests/singlecard/test_scheduler.py | 66 ++++----- vllm_ascend/attention/mla_v1.py | 41 +++++- vllm_ascend/core/scheduler.py | 164 ++------------------- vllm_ascend/envs.py | 2 + vllm_ascend/models/deepseek_v2.py | 17 ++- vllm_ascend/worker/model_runner_v1.py | 180 ++++++++++++++++++------ 7 files changed, 242 insertions(+), 234 deletions(-) diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index d3622dce1..725a570e9 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -108,8 +108,7 @@ jobs: run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py - # AscendScheduler doesn't work, fix it later - # pytest -sv tests/singlecard/tets_schedule.py + pytest -sv tests/singlecard/test_scheduler.py # guided decoding doesn't work, fix it later # pytest -sv tests/singlecard/test_guided_decoding.py.py pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py @@ -124,8 +123,7 @@ jobs: run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py - # AscendScheduler doesn't work, fix it later - # pytest -sv tests/singlecard/tets_schedule.py + pytest -sv tests/singlecard/test_scheduler.py # guided decoding doesn't work, fix it later # pytest -sv tests/singlecard/test_guided_decoding.py.py pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py diff --git a/tests/singlecard/test_scheduler.py b/tests/singlecard/test_scheduler.py index 6eddd4f2c..e22be2e54 100644 --- a/tests/singlecard/test_scheduler.py +++ b/tests/singlecard/test_scheduler.py @@ -31,6 +31,7 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm_ascend.core.scheduler import AscendScheduler +from vllm_ascend.utils import vllm_version_is EOS_TOKEN_ID = 50256 @@ -83,11 +84,10 @@ def create_scheduler( cache_dtype="auto", **kwargs_cache, ) - vllm_config = VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - ) + vllm_config = VllmConfig(scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config) + kv_cache_config = KVCacheConfig( num_blocks=10000, # A large number of blocks to hold all requests tensors={}, @@ -98,10 +98,7 @@ def create_scheduler( ) cache_config.num_gpu_blocks = 10000 return AscendScheduler( - scheduler_config, - model_config, - cache_config, - lora_config=None, + vllm_config, kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), @@ -126,17 +123,27 @@ def create_requests(num_requests: int, else: mm_position = None mm_inputs = None - request = Request( - request_id=f"{i}", - prompt=None, - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, 
- arrival_time=0, - ) + if vllm_version_is("0.9.0"): + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + arrival_time=0, + eos_token_id=EOS_TOKEN_ID, + ) + else: + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + ) requests.append(request) return requests @@ -225,12 +232,9 @@ def test_stop_via_update_from_output(): requests[0].request_id: 1, requests[1].request_id: 2 }, + scheduled_spec_decode_tokens={}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [], - requests[1].request_id: [10] - }, num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_input_ids=[], @@ -275,12 +279,9 @@ def test_stop_via_update_from_output(): requests[0].request_id: 3, requests[1].request_id: 2 }, + scheduled_spec_decode_tokens={}, total_num_scheduled_tokens=5, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 42], - requests[1].request_id: [13] - }, num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_input_ids=[], @@ -323,12 +324,9 @@ def test_stop_via_update_from_output(): requests[0].request_id: 3, requests[1].request_id: 1 }, + scheduled_spec_decode_tokens={}, total_num_scheduled_tokens=4, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [10, 11], - requests[1].request_id: [] - }, num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_input_ids=[], @@ -369,11 +367,9 @@ def test_stop_via_update_from_output(): scheduled_new_reqs=[], scheduled_cached_reqs=[], num_scheduled_tokens={requests[0].request_id: 3}, + scheduled_spec_decode_tokens={}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [EOS_TOKEN_ID, 10] - }, num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_input_ids=[], diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 9f918c034..7dfb8a27f 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -241,7 +241,44 @@ def _get_graph_runner_block_tables( max_blocks] = block_tables[:num_seqs, : max_blocks] - return graph_block_tables + return graph_block_tables[:num_seqs, :max_blocks] + + def build_dummy(self, num_reqs: int, + num_actual_tokens: int) -> AscendMLAMetadata: + device = self.runner.device + _, max_blocks = self.runner.graph_block_tables.shape + block_table = torch.zeros((num_reqs, max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) + input_positions = torch.zeros(num_reqs, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_reqs, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_seq_lens=1) + return self.metadata_cls( # type: ignore + num_input_tokens=num_actual_tokens, + num_actual_tokens=num_actual_tokens, + slot_mapping=slot_mapping, + 
head_dim=self.runner.model_config.get_head_size(), + num_decodes=1, + num_decode_tokens=1, + num_prefills=0, + attn_mask=self.runner.attn_mask, + attn_state=AscendAttentionState.DecodeOnly, + prefill=None, + decode=decode_metadata, + ) def build(self, num_reqs: int, @@ -324,7 +361,7 @@ def build(self, block_table = torch.cat([block_table, block_table_padding], dim=0) block_table = self._get_graph_runner_block_tables( - num_seqs, block_table) + num_seqs + graph_pad_size, block_table) padding_0 = torch.zeros(graph_pad_size, dtype=input_positions.dtype, device=input_positions.device) diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 122f7b965..0135d953f 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # from collections import deque -from typing import Iterable, Optional, Union +from typing import Iterable, Union from vllm.config import VllmConfig from vllm.logger import logger @@ -23,12 +23,10 @@ from vllm.utils import cdiv from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.core.sched.utils import check_stop -from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.engine import EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager @@ -130,14 +128,15 @@ def skip_cur_request(): assert num_new_tokens > 0 watermark = getattr(self.scheduler_config, "watermark", 0.01) - if not self._check_watermark_for_prefill( - request, num_new_tokens, computed_blocks, watermark): + if not self._check_watermark_for_prefill(request, num_new_tokens, + computed_blocks.blocks, + watermark): # Scheduling would exceed watermark, skip. skip_cur_request() continue new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, num_new_tokens, new_computed_blocks=computed_blocks) if new_blocks is None: # The request cannot be scheduled. break @@ -155,9 +154,8 @@ def skip_cur_request(): if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.get_block_ids(request.request_id)) # Update request info. num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens @@ -215,9 +213,8 @@ def skip_cur_request(): # Schedule the request. 
scheduled_running_reqs.append(request) self.scheduled_req_ids.add(request.request_id) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] + req_to_new_block_ids[request.request_id] = ( + new_blocks.get_block_ids()) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -326,7 +323,8 @@ def _check_watermark_for_prefill(self, len(computed_blocks) * self.block_size) num_required_blocks = cdiv(num_new_tokens + num_computed_tokens, self.block_size) - req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id] + req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[ + request.request_id] num_new_blocks = (num_required_blocks - len(req_blocks) - len(computed_blocks)) num_evictable_computed_blocks = sum(1 for blk in computed_blocks @@ -365,41 +363,22 @@ def finish_requests( For example, the API server can abort a request when the client disconnects. """ - assert RequestStatus.is_finished(finished_status) - if isinstance(request_ids, str): - request_ids = (request_ids, ) - else: - request_ids = set(request_ids) - for req_id in request_ids: request = self.requests.get(req_id) if request is None: # Invalid request ID. continue - if request.status == RequestStatus.RUNNING: - self.running.remove(request) self.scheduled_req_ids.discard(request.request_id) - else: - self.waiting.remove(request) - request.status = finished_status - self._free_request(request) + super().finish_requests(request_ids, finished_status) def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ) -> EngineCoreOutputs: - sampled_token_ids = model_runner_output.sampled_token_ids - spec_token_ids = model_runner_output.spec_token_ids - logprobs = model_runner_output.logprobs - prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens - new_running: list[Request] = [] - outputs: list[EngineCoreOutput] = [] - spec_decoding_stats: Optional[SpecDecodingStats] = None - # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid # expensive operations inside the loop. @@ -408,121 +387,8 @@ def update_from_output( num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) if num_tokens_scheduled == 0: # The request was not scheduled in this step. - new_running.append(request) continue - - req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[req_index] - - scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) - if scheduled_spec_token_ids: - # num_computed_tokens represents the number of tokens - # processed in the current step, considering scheduled - # tokens and rejections. If some tokens are rejected, - # num_computed_tokens is decreased by the number of rejected - # tokens, where is given by: - # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - - len(generated_token_ids)) - request.num_computed_tokens -= num_tokens_rejected - spec_decoding_stats = self.make_spec_decoding_stats( - spec_decoding_stats, - num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) - - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) - # OPTIMIZATION: Avoid list(set) if the set is empty. 
- if cached_encoder_input_ids: - for input_id in list(cached_encoder_input_ids): - mm_positions = request.mm_positions[input_id] - start_pos = mm_positions.offset - num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. - self.encoder_cache_manager.free_encoder_input( - request, input_id) - - stopped = False - new_logprobs = None - new_token_ids = generated_token_ids - - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner - # to return empty token ids for the request. - for num_new, output_token_id in enumerate(new_token_ids, 1): - request.append_output_token_ids(output_token_id) - - # Check for stop and update request state. - # This must be called before we make the EngineCoreOutput. - stopped = check_stop(request, self.max_model_len) - if stopped: - self._free_request(request) - del new_token_ids[num_new:] # Trim new tokens if needed. - break - - # Extract sample logprobs if needed. - if request.sampling_params.logprobs is not None and logprobs: - # NOTE: once we support N tokens per step (spec decode), - # the outer lists can be of length > 1. - new_logprobs = logprobs.slice(req_index, req_index + 1) - - if new_token_ids and request.use_structured_output: - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # check above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) - - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - if request.use_structured_output: - metadata = request.structured_output_request - assert metadata is not None and metadata.grammar is not None - # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] - - # Get prompt logprobs for this request. - prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids: - # Add EngineCoreOutput for this Request. - outputs.append( - EngineCoreOutput( - request_id=req_id, - new_token_ids=new_token_ids, - finish_reason=request.get_finished_reason(), - new_logprobs=new_logprobs, - new_prompt_logprobs_tensors=prompt_logprobs_tensors, - stop_reason=request.stop_reason, - events=request.take_events())) - else: - # Invariant: EngineCore returns no partial prefill outputs. - assert not prompt_logprobs_tensors - self.scheduled_req_ids.remove(req_id) - if not stopped: - new_running.append(request) - - # Return the cached request data to the queue so they can be reused. - for req_data in scheduler_output.scheduled_cached_reqs: - # NOTE(rob): since we free stopped reqs above, adding stopped reqs - # to _cached_reqs_data will cause a memory leak. 
- if req_data.req_id not in self.finished_req_ids: - self._cached_reqs_data[req_data.req_id].append(req_data) - - self.running = new_running - engine_core_outputs = EngineCoreOutputs( - outputs=outputs, - scheduler_stats=self.make_stats(spec_decoding_stats), - ) - if self.include_finished_set: - #TODO currently sending duplicates here, improve this - engine_core_outputs.finished_requests = ( - scheduler_output.finished_req_ids | self.finished_req_ids) - return engine_core_outputs + return super().update_from_output(scheduler_output, + model_runner_output) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8e1cc1c16..a4e9450b8 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -66,6 +66,8 @@ lambda: os.getenv("C_COMPILER", None), "VLLM_VERSION": lambda: os.getenv("VLLM_VERSION", None), + "VLLM_ASCEND_TRACE_RECOMPILES": + lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), } # end-env-vars-definition diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 264a798ee..38a8053ec 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -36,9 +36,10 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) -from vllm.distributed import (get_dp_group, get_pp_group, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -211,8 +212,12 @@ def __init__( self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. # TODO: need a better flag to indicate whether in profile run or not. 
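Reviewer-facing sketch (not part of the diff): the hunk above lets CustomDeepseekV2MoE.forward accept the attention metadata as an explicit argument, and the hunk below passes it down from the decoder layer, presumably so the torchair-traced decode path does not have to pull it from the ambient forward context at trace time. A minimal sketch of that pattern with stand-in names (nothing here is real vLLM API):

    from typing import Any, Optional

    _ambient_attn_metadata: Optional[Any] = None   # stand-in for get_forward_context().attn_metadata

    def moe_forward(hidden_states: Any, attn_metadata: Optional[Any] = None) -> Any:
        # Graph mode threads the metadata explicitly; eager mode falls back to
        # the forward context, preserving the previous behaviour.
        if attn_metadata is None:
            attn_metadata = _ambient_attn_metadata
        # ... expert routing would consult attn_metadata here ...
        return hidden_states
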
@@ -547,7 +552,11 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, CustomDeepseekV2MoE): + hidden_states = self.mlp(hidden_states, attn_metadata) + else: + hidden_states = self.mlp(hidden_states) if isinstance( self.mlp, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 06988b251..a0bc21259 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -28,10 +28,12 @@ import numpy as np import numpy.typing as npt import torch +import torch._dynamo.cache_size import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY @@ -70,7 +72,9 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") -import vllm.envs as envs +import vllm.envs as envs_vllm + +import vllm_ascend.envs as envs_ascend @dataclass @@ -321,6 +325,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.sampler = Sampler() self.enable_torchair_graph_mode = False self.use_cached_npu_graph = False + self.torchair_graph_batch_sizes = [] additional_config = vllm_config.additional_config if additional_config: self.enable_torchair_graph_mode = additional_config.get( @@ -328,6 +333,32 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): False) and self.vllm_config.model_config.use_mla self.use_cached_npu_graph = additional_config.get( "use_cached_npu_graph", False) + self.torchair_graph_batch_sizes = additional_config.get( + "torchair_graph_batch_sizes", []) + if not isinstance(self.torchair_graph_batch_sizes, list): + logger.warning("torchair_graph_batch_sizes must be list[int]") + self.torchair_graph_batch_sizes = [] + if len(self.torchair_graph_batch_sizes + ) == 0 and additional_config.get( + "torchair_graph_batch_sizes_init", False): + self.init_torchair_graph_batch_sizes() + + if len(self.torchair_graph_batch_sizes) == 0: + #If MC2 is enabled, torchair_graph_batch_size should pad to tp_size + if envs_ascend.VLLM_ENABLE_MC2: + self.torchair_graph_batch_sizes = [ + self.scheduler_config.max_num_seqs + ] + else: + self.torchair_graph_batch_sizes = [ + 1, self.scheduler_config.max_num_seqs + ] + + torch._dynamo.cache_size.config.cache_size_limit += len( + self.torchair_graph_batch_sizes) + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._logging.set_logs( + recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -618,7 +649,10 @@ def _process_reqs( query_start_loc=query_start_loc, seq_lens=seq_lens) # Add graph_pad_size here if self.enable_torchair_graph_mode: - graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens) + batchsize = len(seq_lens) + padded_batch_size = self.select_torchair_padded_batchsize( + batchsize) + graph_pad_size = padded_batch_size - batchsize extra_builder_kwargs['graph_pad_size'] = graph_pad_size if self.vllm_config.model_config.use_mla: @@ -653,11 +687,8 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] if self.enable_torchair_graph_mode 
and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - padding = torch.zeros(graph_pad_size, - dtype=input_ids.dtype, - device=input_ids.device) - input_ids = torch.cat([input_ids, padding]) - positions = torch.cat([positions, padding]) + input_ids = self.input_ids[:padded_batch_size] + positions = self.positions[:padded_batch_size] # Run forward pass with set_forward_context(attn_metadata, @@ -668,15 +699,6 @@ def _process_reqs( model_kwargs["kv_caches"] = self.kv_caches model_kwargs["attn_metadata"] = attn_metadata if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static(attn_metadata.decode.input_positions) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - for kv in self.kv_caches: - if isinstance(kv, tuple): - torch._dynamo.mark_static(kv[0]) - torch._dynamo.mark_static(kv[1]) hidden_states = self.compile_model( input_ids=input_ids, positions=positions, @@ -1068,7 +1090,12 @@ def _profile_multimodal(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) @torch.inference_mode() - def _dummy_run(self, num_tokens: int) -> torch.Tensor: + def _dummy_run( + self, + num_tokens: int, + is_compile: bool = False, + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill, + ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. @@ -1112,12 +1139,38 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor: }) with set_forward_context(None, self.vllm_config): - hidden_states = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) - return hidden_states + if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly: + attn_metadata = self.attn_metadata_builder.build_dummy( + num_reqs=num_tokens, num_actual_tokens=1) + # Only mark static while compiling + if is_compile: + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static( + attn_metadata.decode.block_table) + torch._dynamo.mark_static( + attn_metadata.decode.input_positions) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + for kv in self.kv_caches: + assert isinstance( + kv, tuple), "kv_cache must be a tuple" + torch._dynamo.mark_static(kv[0]) + torch._dynamo.mark_static(kv[1]) + hidden_states = self.compile_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + kv_caches=self.kv_caches, + attn_metadata=attn_metadata, + ) + else: + hidden_states = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. 
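Reviewer-facing sketch (not part of the diff): with graph mode enabled, _process_reqs now pads each decode batch up to the nearest compiled torchair batch size instead of always padding to max_num_seqs, and _dummy_run drives compilation for exactly those sizes. The arithmetic reduces to the following self-contained sketch; the captured sizes are made up for illustration:

    # Mirrors select_torchair_padded_batchsize() and the graph_pad_size
    # computation in _process_reqs(); values are illustrative only.
    captured_sizes = [1, 8, 16]   # assumed torchair_graph_batch_sizes
    max_num_reqs = 16             # fallback when no captured size fits

    def padded_batch_size(batch: int) -> int:
        selected = max_num_reqs
        for size in captured_sizes:
            if batch <= size < selected:
                selected = size
        return selected

    batch = 5
    graph_pad_size = padded_batch_size(batch) - batch   # 8 - 5 == 3 dummy decode slots
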
@@ -1192,13 +1245,13 @@ def load_model(self) -> None: self.compile_model = torch.compile( self.model, dynamic=True, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=npu_backend) else: self.compile_model = torchair.inference.cache_compile( self.model.forward, dynamic=True, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, config=config, ge_cache=False) @@ -1316,25 +1369,49 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec def capture_model(self) -> None: - if not self.use_aclgraph: - logger.warning( - "Skipping NPU graph capture. Please add " - "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE) - return - start_time = time.perf_counter() start_free_npu_memory = torch.npu.mem_get_info()[0] - - # Trigger ACL graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - with graph_capture(device=self.device): - for num_tokens in reversed(self.aclgraph_batch_sizes): + # TODO(NeverRaR): Calling graph_capture(device=self.device) in + # torchair graph capture can cause some issues, so now we just + # temporarily split the codepath for the two different graph patterns. + if self.enable_torchair_graph_mode: + torchair_graph_batch_sizes = self.torchair_graph_batch_sizes + graph_num = len(torchair_graph_batch_sizes) + logger.info( + "Capturing torchair graph, this usually takes %.1f~%.1f mins.", + 0.5 * graph_num, 1.5 * graph_num) + attn_state = AscendAttentionState.DecodeOnly + # Trigger torchair graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + for idx, num_tokens in enumerate( + reversed(torchair_graph_batch_sizes)): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): + self._dummy_run(num_tokens, + is_compile=True, + attn_state=attn_state) + self._dummy_run(num_tokens, + is_compile=True, + attn_state=attn_state) + logger.info("Batchsize %d is compiled successfully: %d/%d.", + num_tokens, idx + 1, graph_num) + elif self.use_aclgraph: + # Trigger ACL graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with graph_capture(device=self.device): + for num_tokens in reversed(self.aclgraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(num_tokens) self._dummy_run(num_tokens) - self._dummy_run(num_tokens) - + else: + logger.warning( + "Skipping NPU graph capture. Please add -O %s to use ACL graphs. 
" + "Or add --additional_config={'enable_graph_mode': True} to use torchair graphs", + CompilationLevel.PIECEWISE) + return end_time = time.perf_counter() end_free_npu_memory = torch.npu.mem_get_info()[0] elapsed_time = end_time - start_time @@ -1443,4 +1520,27 @@ def _generate_mtp_token_ids( sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - return spec_token_ids \ No newline at end of file + return spec_token_ids + + def init_torchair_graph_batch_sizes(self): + tp_size = get_tensor_model_parallel_world_size() + batch_size_step = 8 + largest_batch_size = 1 + + if envs_ascend.VLLM_ENABLE_MC2: + batch_size_step = max(batch_size_step, tp_size) + largest_batch_size = batch_size_step + while (largest_batch_size < 8): + self.torchair_graph_batch_sizes.append(largest_batch_size) + largest_batch_size *= 2 + + while (largest_batch_size <= self.scheduler_config.max_num_seqs): + self.torchair_graph_batch_sizes.append(largest_batch_size) + largest_batch_size += batch_size_step + + def select_torchair_padded_batchsize(self, batchsize: int): + selected_batchsize = self.max_num_reqs + for padded_batchsize in self.torchair_graph_batch_sizes: + if batchsize <= padded_batchsize < selected_batchsize: + selected_batchsize = padded_batchsize + return selected_batchsize