feat: support compile torchair graph while warming up #839

Open · wants to merge 1 commit into base: main
43 changes: 40 additions & 3 deletions vllm_ascend/attention/mla_v1.py
@@ -224,7 +224,44 @@ def _get_graph_runner_block_tables(
                                max_blocks] = block_tables[:num_seqs, :
                                                           max_blocks]

-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]

+    def dummy_build(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )

     def build(self,
               num_reqs: int,
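The new dummy_build helper fabricates decode-only MLA metadata for num_reqs fake requests (zeroed block tables, sequence length 1, PAD_SLOT_ID slot mapping), and _get_graph_runner_block_tables now returns only the populated [:num_seqs, :max_blocks] view of the preallocated table, so the metadata shapes match the padded batch exactly. Below is a minimal sketch of how a warm-up loop might use this to compile the torchair graph before any real requests arrive; runner, builder, capture_sizes and the model call signature are illustrative assumptions, not APIs defined by this PR.

```python
import torch

def warm_up_torchair_graph(runner, builder, capture_sizes):
    """Hypothetical warm-up loop: one dummy decode pass per captured batch size."""
    for num_reqs in capture_sizes:  # e.g. [1, 2, 4, 8]
        # Fake decode-only metadata: every request has seq_len == 1 and a zeroed block table.
        attn_metadata = builder.dummy_build(num_reqs=num_reqs,
                                            num_actual_tokens=num_reqs)
        input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=runner.device)
        positions = torch.zeros(num_reqs, dtype=torch.int64, device=runner.device)
        # The dummy forward pass drives graph compilation for this batch size.
        runner.model(input_ids, positions, attn_metadata=attn_metadata)
```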
@@ -300,7 +337,7 @@ def build(self,
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)
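In build, block_table has already been padded with graph_pad_size extra rows before this call, so the helper must be told the padded row count; passing only num_seqs would slice the padding away again and the captured graph would see a smaller block table than it was compiled for. A rough shape check with illustrative numbers (the sizes are made up; only the slicing mirrors the patch):

```python
import torch

num_seqs, graph_pad_size, max_blocks = 3, 5, 64   # illustrative sizes
max_batch_size = 16

block_table = torch.zeros((num_seqs + graph_pad_size, max_blocks), dtype=torch.int32)
graph_block_tables = torch.zeros((max_batch_size, max_blocks), dtype=torch.int32)

# What the patched call effectively does: keep every padded row in the returned view.
rows = num_seqs + graph_pad_size
graph_block_tables[:rows, :max_blocks] = block_table[:rows, :max_blocks]
result = graph_block_tables[:rows, :max_blocks]
assert result.shape == (num_seqs + graph_pad_size, max_blocks)
```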
@@ -795,4 +832,4 @@ def forward(
             output[:num_decode_tokens] = self._forward_decode(
                 decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
                 kv_cache, attn_metadata)
-        return output_padded
+        return output_padded
53 changes: 29 additions & 24 deletions vllm_ascend/models/deepseek_v2.py
@@ -36,7 +36,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
@@ -205,17 +205,16 @@ def __init__(
         )
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

-        vllm_config = get_current_vllm_config()
-        self.dp_size = get_dp_group().world_size
-        batch_size = vllm_config.scheduler_config.max_num_seqs
-
-        params_dtype = torch.get_default_dtype()
-        self.final_hidden_states = torch.zeros(
-            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
+        self.params_dtype = torch.get_default_dtype()
+        self.tp_rank_in_group = get_tp_group().rank_in_group
         self.tp_group = get_tp_group().device_group

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
             is_prefill = True
@@ -224,34 +223,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-
         if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
-            chunks = torch.chunk(hidden_states,
-                                 get_tp_group().world_size,
-                                 dim=0)
-            hidden_states = chunks[get_tp_group().rank_in_group]
+            chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+            hidden_states = chunks[self.tp_rank_in_group]

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        final_hidden_states = self.experts(
+        hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor

         if self.tp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
-                dist.all_gather_into_tensor(self.final_hidden_states,
-                                            final_hidden_states, self.tp_group)
-                final_hidden_states = self.final_hidden_states
+                final_hidden_states = torch.zeros([num_tokens, hidden_dim],
+                                                  dtype=self.params_dtype,
+                                                  device="npu")
+                dist.all_gather_into_tensor(final_hidden_states, hidden_states,
+                                            self.tp_group)
+                hidden_states = final_hidden_states
             else:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         if self.n_shared_experts is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            shared_output = self.shared_experts(hidden_states)
+            hidden_states = hidden_states + shared_output

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_dim)


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
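The MoE forward pass no longer writes into a buffer preallocated in __init__ from max_num_seqs. In the MC2 decode path each TP rank now chunks the tokens with self.tp_size and self.tp_rank_in_group, runs the experts on its chunk, and gathers the results into a fresh [num_tokens, hidden_dim] tensor allocated per call, with the shared-expert output added after the gather. A single-process shape sketch of that data flow follows; tp_size, tp_rank and the identity expert call are stand-ins for the real distributed pieces.

```python
import torch

tp_size, tp_rank = 4, 1
num_tokens, hidden_dim = 8, 7168        # illustrative sizes

hidden_states = torch.randn(num_tokens, hidden_dim)

# Each TP rank keeps only its chunk of the tokens...
local = torch.chunk(hidden_states, tp_size, dim=0)[tp_rank]

# ...runs the routed experts on that chunk (identity stand-in here)...
local = local * 1.0

# ...and all ranks gather the chunks into a buffer sized to the actual token
# count (dist.all_gather_into_tensor in the patch), not to max_num_seqs.
gathered = torch.zeros(num_tokens, hidden_dim)
chunk = num_tokens // tp_size
gathered[tp_rank * chunk:(tp_rank + 1) * chunk] = local
assert gathered.shape == (num_tokens, hidden_dim)
```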
@@ -524,7 +525,11 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)

         if isinstance(
             self.mlp,
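Finally, the decoder layer passes attn_metadata into the MLP only when the MLP is the custom MoE, since the dense MLP keeps its single-argument signature; this is what lets the warm-up path hand the dummy metadata to the MoE explicitly instead of relying on get_forward_context(). A stub sketch of that dispatch pattern, with placeholder classes rather than the real vllm ones:

```python
from typing import Any, Optional

class DenseMLP:                      # placeholder for the dense MLP
    def __call__(self, x):
        return x

class MoEMLP:                        # placeholder for CustomDeepseekV2MoE
    def __call__(self, x, attn_metadata: Optional[Any] = None):
        # attn_metadata lets the MoE choose its prefill/decode path even when
        # the forward context carries no metadata (e.g. during graph warm-up).
        return x

def run_mlp(mlp, hidden_states, attn_metadata: Optional[Any] = None):
    if isinstance(mlp, MoEMLP):
        return mlp(hidden_states, attn_metadata)
    return mlp(hidden_states)
```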