
Commit 99be815

feat: support compile torchair graph while warming up
Signed-off-by: boying <897013703@qq.com>
1 parent 59e0250

File tree: 3 files changed (+196, -55 lines)

vllm_ascend/attention/mla_v1.py: 40 additions & 3 deletions
```diff
@@ -224,7 +224,44 @@ def _get_graph_runner_block_tables(
                                max_blocks] = block_tables[:num_seqs, :
                                                           max_blocks]
 
-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def dummy_build(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )
 
     def build(self,
               num_reqs: int,
```
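The new `dummy_build` constructs a minimal DecodeOnly metadata object: one placeholder token per request, all sequence lengths set to one, and every slot mapped to `PAD_SLOT_ID` so the capture run never writes real KV-cache pages. A minimal sketch of how a warm-up loop might use it while capturing the torchair graph; `builder`, `capture_sizes`, and `run_model` are illustrative stand-ins, not names from this commit:

```python
# Hypothetical warm-up loop (sketch only): feed placeholder decode metadata to
# the model while a graph is captured for each padded batch size.
def warm_up_graphs(builder, capture_sizes, run_model):
    for num_reqs in sorted(capture_sizes, reverse=True):
        # One dummy decode token per request; slot_mapping is all PAD_SLOT_ID,
        # so this pass touches no real KV-cache pages.
        attn_metadata = builder.dummy_build(num_reqs=num_reqs,
                                            num_actual_tokens=num_reqs)
        run_model(num_tokens=num_reqs, attn_metadata=attn_metadata)
```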
```diff
@@ -300,7 +337,7 @@ def build(self,
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)
```
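Because `_get_graph_runner_block_tables` now slices its result down to the first `num_seqs` rows, the caller in `build` has to count the `graph_pad_size` padding rows it just concatenated, otherwise the rows added to match the captured graph size would be sliced off again. A toy illustration of the arithmetic (shapes are made up, not from the commit):

```python
import torch

# 3 real requests padded up to a captured graph batch of 8.
num_seqs, graph_pad_size, max_blocks = 3, 5, 4
block_table = torch.zeros((num_seqs + graph_pad_size, max_blocks),
                          dtype=torch.int32)

# The helper now returns graph_block_tables[:n, :max_blocks], so:
assert block_table[:num_seqs].shape[0] == 3                   # old call: padded rows dropped
assert block_table[:num_seqs + graph_pad_size].shape[0] == 8  # fixed call: full padded batch
```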
```diff
@@ -795,4 +832,4 @@ def forward(
             output[:num_decode_tokens] = self._forward_decode(
                 decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
                 kv_cache, attn_metadata)
-        return output_padded
+        return output_padded
```

vllm_ascend/models/deepseek_v2.py: 29 additions & 24 deletions
```diff
@@ -36,7 +36,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
```
```diff
@@ -205,17 +205,16 @@ def __init__(
         )
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
 
-        vllm_config = get_current_vllm_config()
-        self.dp_size = get_dp_group().world_size
-        batch_size = vllm_config.scheduler_config.max_num_seqs
-
-        params_dtype = torch.get_default_dtype()
-        self.final_hidden_states = torch.zeros(
-            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
+        self.params_dtype = torch.get_default_dtype()
+        self.tp_rank_in_group = get_tp_group().rank_in_group
         self.tp_group = get_tp_group().device_group
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
             is_prefill = True
```
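The reworked signature lets callers thread `attn_metadata` in explicitly instead of always reading it from the per-step forward context, which matters when the layer runs inside graph capture where the context may not be populated. A minimal sketch of the fallback pattern, with generic names standing in for the vLLM types:

```python
from typing import Callable, Optional

def resolve_attn_metadata(
        explicit: Optional[object],
        from_context: Callable[[], Optional[object]]) -> Optional[object]:
    # Prefer metadata passed by the caller (e.g. the decoder layer during
    # torchair graph capture); fall back to the global forward context.
    if explicit is not None:
        return explicit
    # May still be None in a profile run, which the MoE layer treats as
    # "assume prefill".
    return from_context()
```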
```diff
@@ -224,34 +223,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
         if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
-            chunks = torch.chunk(hidden_states,
-                                 get_tp_group().world_size,
-                                 dim=0)
-            hidden_states = chunks[get_tp_group().rank_in_group]
+            chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+            hidden_states = chunks[self.tp_rank_in_group]
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        final_hidden_states = self.experts(
+        hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
-                dist.all_gather_into_tensor(self.final_hidden_states,
-                                            final_hidden_states, self.tp_group)
-                final_hidden_states = self.final_hidden_states
+                final_hidden_states = torch.zeros([num_tokens, hidden_dim],
+                                                  dtype=self.params_dtype,
+                                                  device="npu")
+                dist.all_gather_into_tensor(final_hidden_states, hidden_states,
+                                            self.tp_group)
+                hidden_states = final_hidden_states
             else:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_dim)
 
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
```
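In the MC2 decode path, each TP rank now runs the experts on its own 1/tp_size chunk of the token batch, and `dist.all_gather_into_tensor` reassembles the full batch into a buffer allocated per call (sized `[num_tokens, hidden_dim]`) instead of the old cached tensor sized for `max_num_seqs`. A single-process simulation of that chunk/gather round trip, with `torch.cat` standing in for the collective so it runs without a process group:

```python
import torch

# Toy sizes, not from the commit.
tp_size, num_tokens, hidden_dim = 4, 8, 16
hidden_states = torch.randn(num_tokens, hidden_dim)

def fake_experts(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # placeholder for the routed-expert computation

# Each rank computes on its own 1/tp_size slice of the tokens...
chunks = torch.chunk(hidden_states, tp_size, dim=0)
per_rank_out = [fake_experts(c) for c in chunks]

# ...and the all-gather writes the slices back into a freshly allocated
# [num_tokens, hidden_dim] buffer, so the output always matches the actual
# (possibly graph-padded) batch size.
gathered = torch.cat(per_rank_out, dim=0)
assert gathered.shape == (num_tokens, hidden_dim)
```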
```diff
@@ -524,7 +525,11 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)
 
         if isinstance(
             self.mlp,
```
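The `isinstance` dispatch is needed because only `CustomDeepseekV2MoE.forward` gained the optional `attn_metadata` parameter; dense-MLP layers keep the one-argument call. A minimal sketch of the pattern with hypothetical stand-in classes:

```python
import torch
import torch.nn as nn

class DenseMLP(nn.Module):  # stand-in for a dense MLP layer (hypothetical)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

class MoEMLP(nn.Module):  # stand-in for CustomDeepseekV2MoE (hypothetical)
    def forward(self, x: torch.Tensor, attn_metadata=None) -> torch.Tensor:
        # In the real layer, attn_metadata drives the prefill-vs-decode branch.
        return x

def run_mlp(mlp: nn.Module, x: torch.Tensor, attn_metadata) -> torch.Tensor:
    # Mirrors the decoder layer's dispatch: only the MoE variant receives the
    # attention metadata.
    if isinstance(mlp, MoEMLP):
        return mlp(x, attn_metadata)
    return mlp(x)
```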
