From 99be815c8ee2061e469fb36e333b4a48f3f67aa2 Mon Sep 17 00:00:00 2001 From: boying <897013703@qq.com> Date: Tue, 13 May 2025 16:37:42 +0000 Subject: [PATCH] feat: support compile torchair graph while warming up Signed-off-by: boying <897013703@qq.com> --- vllm_ascend/attention/mla_v1.py | 43 ++++++- vllm_ascend/models/deepseek_v2.py | 53 +++++---- vllm_ascend/worker/model_runner_v1.py | 155 +++++++++++++++++++++----- 3 files changed, 196 insertions(+), 55 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ff6742d27..a677fd3e7 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -224,7 +224,44 @@ def _get_graph_runner_block_tables( max_blocks] = block_tables[:num_seqs, : max_blocks] - return graph_block_tables + return graph_block_tables[:num_seqs, :max_blocks] + + def dummy_build(self, num_reqs: int, + num_actual_tokens: int) -> AscendMLAMetadata: + device = self.runner.device + _, max_blocks = self.runner.graph_block_tables.shape + block_table = torch.zeros((num_reqs, max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) + input_positions = torch.zeros(num_reqs, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_reqs, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + decode_metadata = AscendMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_seq_lens=1) + return self.metadata_cls( # type: ignore + num_input_tokens=num_actual_tokens, + num_actual_tokens=num_actual_tokens, + slot_mapping=slot_mapping, + head_dim=self.runner.model_config.get_head_size(), + num_decodes=1, + num_decode_tokens=1, + num_prefills=0, + attn_mask=self.runner.attn_mask, + attn_state=AscendAttentionState.DecodeOnly, + prefill=None, + decode=decode_metadata, + ) def build(self, num_reqs: int, @@ -300,7 +337,7 @@ def build(self, block_table = torch.cat([block_table, block_table_padding], dim=0) block_table = self._get_graph_runner_block_tables( - num_seqs, block_table) + num_seqs + graph_pad_size, block_table) padding_0 = torch.zeros(graph_pad_size, dtype=input_positions.dtype, device=input_positions.device) @@ -795,4 +832,4 @@ def forward( output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, kv_cache, attn_metadata) - return output_padded \ No newline at end of file + return output_padded diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 5bf1126f6..2ae99ece3 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -36,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) -from vllm.distributed import (get_dp_group, get_pp_group, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context @@ -205,17 +205,16 @@ def __init__( ) CustomDeepseekV2MoE.top_k = config.num_experts_per_tok - vllm_config = get_current_vllm_config() - self.dp_size = get_dp_group().world_size - batch_size = vllm_config.scheduler_config.max_num_seqs - - params_dtype = torch.get_default_dtype() - self.final_hidden_states = torch.zeros( - [batch_size, 
config.hidden_size], dtype=params_dtype, device="npu") + self.params_dtype = torch.get_default_dtype() + self.tp_rank_in_group = get_tp_group().rank_in_group self.tp_group = get_tp_group().device_group - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata if attn_metadata is None: # for profile run is_prefill = True @@ -224,16 +223,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill): - chunks = torch.chunk(hidden_states, - get_tp_group().world_size, - dim=0) - hidden_states = chunks[get_tp_group().rank_in_group] + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank_in_group] # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( + hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits, is_prefill=is_prefill, @@ -241,17 +241,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.tp_size > 1: if VLLM_ENABLE_MC2 and not is_prefill: - dist.all_gather_into_tensor(self.final_hidden_states, - final_hidden_states, self.tp_group) - final_hidden_states = self.final_hidden_states + final_hidden_states = torch.zeros([num_tokens, hidden_dim], + dtype=self.params_dtype, + device="npu") + dist.all_gather_into_tensor(final_hidden_states, hidden_states, + self.tp_group) + hidden_states = final_hidden_states else: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + hidden_states = tensor_model_parallel_all_reduce(hidden_states) if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) - final_hidden_states = final_hidden_states + shared_output + hidden_states = hidden_states + shared_output - return final_hidden_states.view(num_tokens, hidden_dim) + return hidden_states.view(num_tokens, hidden_dim) class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): @@ -524,7 +525,11 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, CustomDeepseekV2MoE): + hidden_states = self.mlp(hidden_states, attn_metadata) + else: + hidden_states = self.mlp(hidden_states) if isinstance( self.mlp, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e46b7e572..1d0db04ab 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -28,10 +28,12 @@ import numpy as np import numpy.typing as npt import torch +import torch._dynamo.cache_size import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY @@ -62,7 +64,9 @@ else: xgr = 
LazyLoader("xgr", globals(), "xgrammar") -import vllm.envs as envs +import vllm.envs as envs_vllm + +import vllm_ascend.envs as envs_ascend @dataclass @@ -316,6 +320,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.sampler = Sampler() self.enable_torchair_graph_mode = False self.use_cached_npu_graph = False + self.torchair_graph_batch_sizes = [] additional_config = vllm_config.additional_config if additional_config: self.enable_torchair_graph_mode = additional_config.get( @@ -323,6 +328,31 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): False) and self.vllm_config.model_config.use_mla self.use_cached_npu_graph = additional_config.get( "use_cached_npu_graph", False) + if additional_config.get("trace_recompiles", False): + torch._logging.set_logs(recompiles=True) + self.torchair_graph_batch_sizes = additional_config.get( + "torchair_graph_batch_sizes", []) + if not isinstance(self.torchair_graph_batch_sizes, list): + logger.warning("torchair_graph_batch_sizes must be list[int]") + self.torchair_graph_batch_sizes = [] + if len(self.torchair_graph_batch_sizes + ) == 0 and additional_config.get( + "init_torchair_graph_batch_sizes", False): + self.init_torchair_graph_batch_sizes() + + if len(self.torchair_graph_batch_sizes) == 0: + #If MC2 is enabled, torchair_graph_batch_size should pad to tp_size + if envs_ascend.VLLM_ENABLE_MC2: + self.torchair_graph_batch_sizes = [ + self.scheduler_config.max_num_seqs + ] + else: + self.torchair_graph_batch_sizes = [ + 1, self.scheduler_config.max_num_seqs + ] + + torch._dynamo.cache_size.config.cache_size_limit += len( + self.torchair_graph_batch_sizes) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -590,7 +620,10 @@ def _process_reqs( # Add graph_pad_size here if self.enable_torchair_graph_mode: - graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens) + batchsize = len(seq_lens) + padded_batch_size = self.select_torchair_padded_batchsize( + batchsize) + graph_pad_size = padded_batch_size - batchsize extra_builder_kwargs['graph_pad_size'] = graph_pad_size attn_metadata = self.attn_metadata_builder.build( # type: ignore @@ -614,11 +647,8 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - padding = torch.zeros(graph_pad_size, - dtype=input_ids.dtype, - device=input_ids.device) - input_ids = torch.cat([input_ids, padding]) - positions = torch.cat([positions, padding]) + input_ids = self.input_ids[:padded_batch_size] + positions = self.positions[:padded_batch_size] # Run forward pass with set_forward_context(attn_metadata, @@ -856,7 +886,11 @@ def _profile_multimodal(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) @torch.inference_mode() - def _dummy_run(self, num_tokens: int) -> torch.Tensor: + def _dummy_run( + self, + num_tokens: int, + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + ) -> torch.Tensor: model = self.model if self.is_multimodal_model: input_ids = None @@ -885,10 +919,32 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor: }) with set_forward_context(None, self.vllm_config): - hidden_states = model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly: + 
attn_metadata = self.attn_metadata_builder.dummy_build( + num_reqs=num_tokens, num_actual_tokens=1) + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static(attn_metadata.decode.input_positions) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + for kv in self.kv_caches: + assert isinstance(kv, tuple), "kv_cache must be a tuple" + torch._dynamo.mark_static(kv[0]) + torch._dynamo.mark_static(kv[1]) + hidden_states = self.compile_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + kv_caches=self.kv_caches, + attn_metadata=attn_metadata, + ) + else: + hidden_states = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def profile_run(self) -> None: @@ -957,13 +1013,13 @@ def load_model(self) -> None: self.compile_model = torch.compile( self.model, dynamic=True, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=npu_backend) else: self.compile_model = torchair.inference.cache_compile( self.model.forward, dynamic=True, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, config=config, ge_cache=False) @@ -1070,25 +1126,45 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec def capture_model(self) -> None: - if not self.use_aclgraph: - logger.warning( - "Skipping NPU graph capture. Please add " - "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE) - return - start_time = time.perf_counter() start_free_npu_memory = torch.npu.mem_get_info()[0] - - # Trigger ACL graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - with graph_capture(device=self.device): - for num_tokens in reversed(self.aclgraph_batch_sizes): + # TODO(NeverRaR): Calling graph_capture(device=self.device) in + # torchair graph capture can cause some issues, so now we just + # temporarily split the codepath for the two different graph patterns. + if self.enable_torchair_graph_mode: + torchair_graph_batch_sizes = self.torchair_graph_batch_sizes + graph_num = len(torchair_graph_batch_sizes) + logger.info( + "Capturing torchair graph, this usually takes %.1f~%.1f mins.", + 0.5 * graph_num, 1.5 * graph_num) + attn_state = AscendAttentionState.DecodeOnly + # Trigger torchair graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + for idx, num_tokens in enumerate( + reversed(torchair_graph_batch_sizes)): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): + self._dummy_run(num_tokens, attn_state) + self._dummy_run(num_tokens, attn_state) + logger.info("Batchsize %d is compiled successfully: %d/%d.", + num_tokens, idx + 1, graph_num) + elif self.use_aclgraph: + # Trigger ACL graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with graph_capture(device=self.device): + for num_tokens in reversed(self.aclgraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. 
+ cudagraph_num_of_warmups): + self._dummy_run(num_tokens) self._dummy_run(num_tokens) - self._dummy_run(num_tokens) - + else: + logger.warning( + "Skipping NPU graph capture. Please add -O %s to use ACL graphs. " + "Or add --additional_config={'enable_graph_mode': True} to use torchair graphs", + CompilationLevel.PIECEWISE) + return end_time = time.perf_counter() end_free_npu_memory = torch.npu.mem_get_info()[0] elapsed_time = end_time - start_time @@ -1096,3 +1172,26 @@ def capture_model(self) -> None: # This usually takes 5~20 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) + + def init_torchair_graph_batch_sizes(self): + tp_size = get_tensor_model_parallel_world_size() + batch_size_step = 8 + largest_batch_size = 1 + + if envs_ascend.VLLM_ENABLE_MC2: + batch_size_step = max(batch_size_step, tp_size) + largest_batch_size = batch_size_step + while (largest_batch_size < 8): + self.torchair_graph_batch_sizes.append(largest_batch_size) + largest_batch_size *= 2 + + while (largest_batch_size <= self.scheduler_config.max_num_seqs): + self.torchair_graph_batch_sizes.append(largest_batch_size) + largest_batch_size += batch_size_step + + def select_torchair_padded_batchsize(self, batchsize: int): + selected_batchsize = self.max_num_reqs + for padded_batchsize in self.torchair_graph_batch_sizes: + if batchsize <= padded_batchsize < selected_batchsize: + selected_batchsize = padded_batchsize + return selected_batchsize
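
Note for reviewers: the torchair graph path added here is driven entirely by `additional_config`. The keys read in the runner's `__init__` are `enable_graph_mode` (only honoured for MLA models), `use_cached_npu_graph`, `trace_recompiles`, `torchair_graph_batch_sizes`, and `init_torchair_graph_batch_sizes`. The snippet below is a sketch of a configuration dict that exercises the new path; the keys match what this patch reads, but how the dict reaches `VllmConfig.additional_config` (for example the `--additional_config={'enable_graph_mode': True}` form mentioned in the capture_model warning) depends on the vLLM / vllm-ascend version in use, so treat it as an illustration rather than a verified invocation.

    # Sketch only: keys mirror what this patch reads from additional_config;
    # the plumbing that delivers the dict to the runner is deployment-specific.
    additional_config = {
        "enable_graph_mode": True,            # enable the torchair graph path (MLA models only)
        "use_cached_npu_graph": False,        # True switches to torchair.inference.cache_compile
        "trace_recompiles": False,            # True turns on torch._logging recompile tracing
        "torchair_graph_batch_sizes": [1, 8, 16, 32],  # explicit capture sizes (optional)
        "init_torchair_graph_batch_sizes": False,      # derive sizes automatically when the list is empty
    }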
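
To make the padding behaviour of `init_torchair_graph_batch_sizes` and `select_torchair_padded_batchsize` easier to review, here is a minimal standalone sketch of the same bucketing logic with the runner state replaced by plain arguments; `tp_size`, `max_num_seqs`, `mc2_enabled`, and `max_num_reqs` are stand-ins for the attributes the methods read. Illustration only, not part of the patch.

    # Standalone sketch of the batch-size bucketing introduced by this patch.
    # Assumptions: the arguments mirror the runner attributes of the same names;
    # max_num_reqs plays the role of self.max_num_reqs as the fallback bucket.

    def init_torchair_graph_batch_sizes(tp_size: int, max_num_seqs: int,
                                        mc2_enabled: bool) -> list[int]:
        sizes: list[int] = []
        step = 8
        size = 1
        if mc2_enabled:
            # With MC2, start from and grow by at least the TP size,
            # following the patch's comment about padding to tp_size.
            step = max(step, tp_size)
            size = step
        while size < 8:            # 1, 2, 4 below the step threshold
            sizes.append(size)
            size *= 2
        while size <= max_num_seqs:
            sizes.append(size)
            size += step
        return sizes

    def select_torchair_padded_batchsize(batchsize: int, buckets: list[int],
                                         max_num_reqs: int) -> int:
        # Pick the smallest captured bucket that still fits the live batch.
        selected = max_num_reqs
        for padded in buckets:
            if batchsize <= padded < selected:
                selected = padded
        return selected

    if __name__ == "__main__":
        buckets = init_torchair_graph_batch_sizes(tp_size=4, max_num_seqs=64,
                                                  mc2_enabled=False)
        print(buckets)  # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
        print(select_torchair_padded_batchsize(9, buckets, max_num_reqs=64))  # 16

A live batch of 9 requests is therefore padded up to the 16-entry graph, which matches the padding applied in `_process_reqs` before the decode-only torchair graph is replayed.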