
Commit 8de805a

add mc2 mask
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent 88c31da commit 8de805a

4 files changed: +77 -33 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 64 additions & 29 deletions
@@ -82,6 +82,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -95,6 +97,8 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
     mc2_mask: Optional[torch.Tensor] = None
 
 
@@ -207,9 +211,14 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
-    def generate_active_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type)
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
+
+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
         mc2_mask[:actual_seqs_num].fill_(True)
         return mc2_mask
 
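
For reference, the renamed helper simply builds a boolean mask over the (possibly padded) batch, marking which slots hold real requests. A minimal, device-agnostic sketch of the same logic (the actual method allocates on current_platform.device_type):

import torch

def generate_activate_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    # True for the first `actual_seqs_num` slots (real requests),
    # False for any padding slots appended to reach the graph batch size.
    mc2_mask = torch.zeros(batch_size, dtype=torch.bool)
    mc2_mask[:actual_seqs_num].fill_(True)
    return mc2_mask

# 3 real requests padded to a batch of 8:
# tensor([ True,  True,  True, False, False, False, False, False])
print(generate_activate_mask(3, 8))
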
@@ -343,7 +352,19 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
-        mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -352,8 +373,9 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-            mc2_mask=mc2_mask,
-        )
+            sin=sin,
+            cos=cos,
+            mc2_mask=mc2_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
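
The dummy decode metadata used for torchair graph capture now also carries placeholder sin/cos tensors shaped [num_reqs, 1, 1, rope_dim], matching what the real decode path will later provide. A rough, shape-only sketch with made-up sizes (the builder uses self.runner.dtype and the NPU device):

import torch

num_reqs, rope_dim = 4, 64               # placeholder sizes
dtype, device = torch.float16, "cpu"     # stand-ins for self.runner.dtype / the NPU device

# All-ones placeholders: one rotary position per request, broadcastable over heads.
sin = torch.ones(num_reqs, 1, 1, rope_dim, dtype=dtype, device=device)
cos = torch.ones(num_reqs, 1, 1, rope_dim, dtype=dtype, device=device)
assert sin.shape == cos.shape == (num_reqs, 1, 1, rope_dim)
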
@@ -405,6 +427,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
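
build() now fetches the rotary embedding's cos/sin tables once from the first decoder layer and converts them to the runner dtype up front, instead of slicing and casting inside every forward call. A hedged sketch of that lazy-init pattern; the attribute path mirrors the diff, while _ensure_rope_cache is only a hypothetical name for the inlined logic:

def _ensure_rope_cache(self):
    # Grab the full cos/sin tables once from layer 0's rotary embedding.
    if self.cos_cache is None:
        rotary_emb = self.runner.get_model().model.layers[0].self_attn.rotary_emb
        self.cos_cache = rotary_emb.cos_cached
        self.sin_cache = rotary_emb.sin_cached
    # Cast once so later per-token indexing needs no dtype conversion.
    if self.cos_cache.dtype != self.runner.dtype:
        self.cos_cache = self.cos_cache.to(self.runner.dtype)
        self.sin_cache = self.sin_cache.to(self.runner.dtype)
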
@@ -451,18 +483,26 @@ def build(
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
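
The prefill branch gathers per-token cos/sin rows from the cached tables by input position and then adds two singleton dimensions, so the result broadcasts over the head axes when rotary embedding is applied. A standalone example of that gather-and-reshape with a random stand-in table:

import torch

rope_dim, max_pos, num_tokens = 64, 1024, 5
cos_cache = torch.randn(max_pos, rope_dim)          # stands in for self.cos_cache
input_positions = torch.randint(0, max_pos, (num_tokens,))

# Index by position, then unsqueeze twice -> [num_tokens, 1, 1, rope_dim],
# which broadcasts against queries of shape [num_tokens, num_heads, 1, rope_dim].
cos = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
print(cos.shape)  # torch.Size([5, 1, 1, 64])
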
@@ -507,9 +547,17 @@ def build(
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
-            mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
+                cos, sin = None, None
+            mc2_mask = self.generate_activate_mask(
+                num_actual_tokens, num_reqs + num_reqs_pad_size)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -519,8 +567,9 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-                mc2_mask=mc2_mask,
-            )
+                sin=sin,
+                cos=cos,
+                mc2_mask=mc2_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1112,15 +1161,8 @@ def forward(
             decode_k_nope = None
             assert attn_metadata.decode is not None
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.decode.cos
+                sin = attn_metadata.decode.sin
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1155,15 +1197,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
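
With cos/sin now precomputed by the metadata builder, forward() no longer rebuilds them from rotary_emb.cos_cached/sin_cached on every step; it just reads attn_metadata.decode.cos/sin (or the prefill equivalents) and hands them to rope_single. The actual kernel is not part of this diff, so the following is only a generic rotate-half sketch of how [num_tokens, 1, 1, rope_dim] cos/sin tensors are typically consumed:

import torch

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, num_heads, 1, rope_dim]; cos/sin: [num_tokens, 1, 1, rope_dim]
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

q_pe = torch.randn(5, 16, 1, 64)
cos = torch.ones(5, 1, 1, 64)
sin = torch.zeros(5, 1, 1, 64)
assert torch.allclose(apply_rope(q_pe, cos, sin), q_pe)  # cos=1, sin=0 is the identity
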

vllm_ascend/ops/fused_moe.py

Lines changed: 7 additions & 2 deletions
@@ -138,7 +138,7 @@ def fused_experts_with_mc2(
     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)
-
+
     # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
     a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
 
@@ -1168,7 +1168,7 @@ def forward(self,
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)
-
+
         attn_metadata = get_forward_context().attn_metadata
         mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None
 
@@ -1179,6 +1179,9 @@ def forward(self,
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
                 router_logits = nn.functional.pad(
                     router_logits, (0, 0, 0, tp_size - num_tokens))
+                if mc2_mask is not None:
+                    mc2_mask = nn.functional.pad(mc2_mask,
+                                                 (0, tp_size - num_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      tp_size,
                                                      dim=0)
@@ -1188,9 +1191,11 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
             if mc2_mask is not None:
                 chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]
+
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
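
When the MC2 path pads hidden_states and router_logits up to tp_size rows, the new code pads and splits mc2_mask the same way, so each tensor-parallel rank's mask chunk stays aligned with its token chunk. A self-contained sketch of that bookkeeping, with tp_size and tp_rank as placeholders for the distributed helpers used in this file:

import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1                     # placeholders for the TP group helpers
num_tokens = 3
hidden_states = torch.randn(num_tokens, 16)
mc2_mask = torch.tensor([True, True, False])

if num_tokens < tp_size:
    # Pad rows up to tp_size; the mask is padded with False so padded rows stay inactive.
    hidden_states = F.pad(hidden_states, (0, 0, 0, tp_size - num_tokens))
    mc2_mask = F.pad(mc2_mask, (0, tp_size - num_tokens))

# Each rank takes the matching chunk so mask entries line up with its tokens.
hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)[tp_rank]
mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)[tp_rank]
print(hidden_states.shape, mc2_mask)  # torch.Size([1, 16]) tensor([True])
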

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 4 additions & 2 deletions
@@ -233,7 +233,7 @@ def fused_experts_with_mc2(
     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)
-
+
     # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
     a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
 
@@ -803,6 +803,7 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)
 
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -819,7 +820,8 @@ def apply(
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
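
On the quantized path, apply() now pulls the mask out of kwargs with kwargs.get("mc2_mask", None) and forwards it to fused_experts_with_mc2, so callers that never pass a mask keep working unchanged. A minimal illustration of that optional-kwarg forwarding, where dispatch_mc2 is only a hypothetical stand-in for the real kernel wrapper:

from typing import Optional

import torch

def dispatch_mc2(x: torch.Tensor, mc2_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # When a mask is provided, zero out rows that correspond to padding tokens.
    return x if mc2_mask is None else x * mc2_mask.unsqueeze(-1)

def apply(x: torch.Tensor, **kwargs) -> torch.Tensor:
    # Absent key -> None, so older call sites without mc2_mask still work.
    mc2_mask = kwargs.get("mc2_mask", None)
    return dispatch_mc2(x, mc2_mask=mc2_mask)

out = apply(torch.ones(2, 4), mc2_mask=torch.tensor([True, False]))
print(out)  # second row is zeroed out
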

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1666,6 +1666,8 @@ def _dummy_run(
                         attn_metadata.decode.block_table)
                     torch._dynamo.mark_static(
                         attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
                     torch._dynamo.mark_static(
                         attn_metadata.decode.mc2_mask)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
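
The new decode.sin and decode.cos tensors are registered with torch._dynamo.mark_static alongside the other decode metadata, so graph capture treats their shapes as fixed rather than dynamic. A tiny example of the call, with toy tensors standing in for the metadata fields:

import torch

sin = torch.ones(4, 1, 1, 64)   # stand-in for attn_metadata.decode.sin
cos = torch.ones(4, 1, 1, 64)   # stand-in for attn_metadata.decode.cos

# Mark the tensors' dimensions as static so dynamo does not treat them as
# dynamic shapes when the decode graph is compiled.
torch._dynamo.mark_static(sin)
torch._dynamo.mark_static(cos)
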
