Commit d556d49

Author: weijinqian_v1
Commit message: handle conflict
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
2 parents 56ded48 + 57664f0 commit d556d49

File tree

8 files changed: +289 -44 lines changed

tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 4 additions & 2 deletions
@@ -114,7 +114,8 @@ def test_mtp_torchair_correctness(
             enforce_eager=False,
             additional_config={
                 "torchair_graph_config": {
-                    "enabled": True
+                    "enabled": True,
+                    "graph_batch_size": [256]
                 },
                 "ascend_scheduler_config": {
                     "enabled": True
@@ -132,7 +133,8 @@ def test_mtp_torchair_correctness(
             },
             additional_config={
                 "torchair_graph_config": {
-                    "enabled": True
+                    "enabled": True,
+                    "graph_batch_size": [256]
                 },
                 "ascend_scheduler_config": {
                     "enabled": True
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+# fused moe ops test will hit the infer_schema error, we need add the patch
+# here to make the test pass.
+import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa
+
+import json
+import unittest
+from typing import List, TypedDict
+from unittest import mock
+
+import torch
+
+from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+
+
+class Device(TypedDict):
+    device_id: int
+    device_expert: List[int]
+
+
+class Layer(TypedDict):
+    layer_id: int
+    device_count: int
+    device_list: List[Device]
+
+
+class MockData(TypedDict):
+    moe_layer_count: int
+    layer_list: List[Layer]
+
+
+MOCK_DATA: MockData = {
+    "moe_layer_count":
+    1,
+    "layer_list": [{
+        "layer_id":
+        0,
+        "device_count":
+        2,
+        "device_list": [{
+            "device_id": 0,
+            "device_expert": [7, 2, 0, 3, 5]
+        }, {
+            "device_id": 1,
+            "device_expert": [6, 1, 4, 7, 2]
+        }]
+    }]
+}
+
+
+class TestExpertLoadBalancer(unittest.TestCase):
+
+    def setUp(self):
+        json_file = "expert_map.json"
+        with open(json_file, 'w') as f:
+            json.dump(MOCK_DATA, f)
+
+        self.expert_load_balancer = ExpertLoadBalancer(json_file,
+                                                       global_expert_num=8)
+
+    def test_init(self):
+
+        self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
+                              torch.Tensor)
+        self.assertEqual(self.expert_load_balancer.layers_num,
+                         MOCK_DATA["moe_layer_count"])
+        self.assertEqual(self.expert_load_balancer.ranks_num,
+                         MOCK_DATA["layer_list"][0]["device_count"])
+
+    def test_generate_index_dicts(self):
+        tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
+        result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
+        expected_result = [{
+            7: 0,
+            2: 1,
+            0: 2,
+            3: 3,
+            5: 4
+        }, {
+            6: 5,
+            1: 6,
+            4: 7,
+            7: 8,
+            2: 9
+        }]
+        self.assertEqual(result, expected_result)
+
+    def test_generate_expert_placement_map(self):
+        expert_placement_map = self.expert_load_balancer.generate_expert_placement_map(
+        )
+        self.assertEqual(expert_placement_map.shape,
+                         (self.expert_load_balancer.layers_num,
+                          self.expert_load_balancer.ranks_num, 8))
+        self.assertTrue(torch.all(expert_placement_map >= -1))
+
+    def test_generate_log2phy_expert_map(self):
+        layer_id = 0
+        log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
+            layer_id)
+        self.assertEqual(log2phy_map.shape,
+                         (self.expert_load_balancer.ranks_num, 8))
+        self.assertTrue(torch.all(log2phy_map >= -1))
+
+    @mock.patch("torch_npu.npu._lazy_init")
+    @mock.patch("torch.npu.current_device", return_value="cpu")
+    def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
+        layer_id = 0
+        rank_id = 0
+        rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
+            layer_id, rank_id)
+        self.assertEqual(rank_local_expert_num, 5)
+        expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
+                                       dtype=torch.int32).to(
+                                           rank_expert_map.device)
+        self.assertTrue(rank_expert_map.equal(expected_tensor))
+
+        rank_id = 1
+        rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
+            layer_id, rank_id)
+        expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
+                                       dtype=torch.int32).to(
+                                           rank_expert_map.device)
+        self.assertTrue(rank_expert_map.equal(expected_tensor))
+
+    def test_get_rank_log2phy_map(self):
+        layer_id = 0
+        rank_id = 0
+        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
+            layer_id, rank_id)
+        expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
+                                       dtype=torch.int32).to(
+                                           log2phy_map.device)
+        self.assertTrue(log2phy_map.equal(expected_tensor))
+
+        rank_id = 1
+        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
+            layer_id, rank_id)
+        expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
+                                       dtype=torch.int32).to(
+                                           log2phy_map.device)
+        self.assertTrue(log2phy_map.equal(expected_tensor))
+
+    def test_get_global_redundant_expert_num(self):
+        redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num(
+        )
+        expected_redundant_expert_num = len(MOCK_DATA["layer_list"][0]["device_list"][0]["device_expert"]) * \
+            MOCK_DATA["layer_list"][0]["device_count"] - 8
+        self.assertEqual(redundant_expert_num, expected_redundant_expert_num)
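
For orientation, a minimal usage sketch built only from the APIs exercised by the tests above; the JSON layout mirrors MOCK_DATA, and anything beyond the tested method names is illustrative (get_rank_placement_map expects an Ascend/NPU device, which the test mocks out):

import json

from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer

# One MoE layer whose 8 logical experts are spread, with duplicates, over 2 devices.
expert_map = {
    "moe_layer_count": 1,
    "layer_list": [{
        "layer_id": 0,
        "device_count": 2,
        "device_list": [
            {"device_id": 0, "device_expert": [7, 2, 0, 3, 5]},
            {"device_id": 1, "device_expert": [6, 1, 4, 7, 2]},
        ],
    }],
}
with open("expert_map.json", "w") as f:
    json.dump(expert_map, f)

balancer = ExpertLoadBalancer("expert_map.json", global_expert_num=8)

# Per-rank placement for layer 0, rank 0: how many experts the rank hosts and a
# global-expert-id -> local-slot map (-1 where the expert is not on this rank).
local_num, placement = balancer.get_rank_placement_map(0, 0)

# Logical-to-physical expert map for the same layer/rank, and the number of
# redundant (duplicated) physical experts across all ranks.
log2phy = balancer.get_rank_log2phy_map(0, 0)
redundant = balancer.get_global_redundant_expert_num()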

vllm_ascend/attention/mla_v1.py

Lines changed: 74 additions & 26 deletions
@@ -11,6 +11,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down

 from vllm_ascend import envs
@@ -81,6 +82,8 @@ class ChunkedContextMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None


 @dataclass
@@ -94,6 +97,9 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
+    mc2_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -205,6 +211,16 @@ def __init__(self,
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
+
+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
+        mc2_mask[:actual_seqs_num].fill_(True)
+        return mc2_mask

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
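
The new generate_activate_mask helper marks the first actual_seqs_num slots of a padded batch as active. A minimal CPU-only illustration with hypothetical sizes (the real helper allocates on the current platform's device):

import torch

mc2_mask = torch.zeros(4, dtype=torch.bool)  # batch padded to 4 slots
mc2_mask[:2].fill_(True)                     # 2 real requests
print(mc2_mask.tolist())                     # [True, True, False, False]
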
@@ -317,7 +333,7 @@ def build_torchair_graph_dummy(
                                                    num_reqs, block_table)
         num_tokens = num_reqs * self.runner.decode_token_per_req
         seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
-        seq_lens_list = seq_lens.tolist()
+        seq_lens_list = [0] * num_reqs
         input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
@@ -336,6 +352,19 @@
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -344,7 +373,9 @@
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
-        )
+            sin=sin,
+            cos=cos,
+            mc2_mask=mc2_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -396,6 +427,16 @@ def build(
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
         query_start_loc = common_attn_metadata.query_start_loc
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+            if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+                self.cos_cache = self.cos_cache.to(  # type: ignore
+                    self.runner.dtype)  # type: ignore
+                self.sin_cache = self.sin_cache.to(  # type: ignore
+                    self.runner.dtype)  # type: ignore

         prefill_metadata = None
         chunked_context_metadata = None
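
The cached cos/sin tables grabbed above are later indexed by token positions. A rough sketch of that indexing, assuming a cache of shape [max_position, rope_dim] (shapes here are hypothetical), showing how two unsqueeze calls give the [num_tokens, 1, 1, rope_dim] layout that the dummy-graph builder also produces:

import torch

max_position, rope_dim = 16, 8
cos_cache = torch.randn(max_position, rope_dim)
input_positions = torch.tensor([0, 1, 1, 3])  # one entry per token

cos = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
print(cos.shape)  # torch.Size([4, 1, 1, 8])
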
@@ -442,24 +483,32 @@
                 chunk_seq_lens=chunk_seq_lens,
                 workspace=self.chunked_prefill_workspace,
             )
-
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )

         decode_metadata = None
         use_torchair_graph = num_token_pad_size != -1
         if self._num_decodes > 0:
-            actual_seq_q_lens = None
+            actual_seq_q_lens = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
@@ -498,8 +547,17 @@
                 actual_seq_q_lens = query_start_loc[1:].tolist(
                 ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs +
                                                   num_reqs_pad_size]
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
             else:
                 seq_lens_list = seq_lens.tolist()
+                cos, sin = None, None
+            mc2_mask = self.generate_activate_mask(
+                num_actual_tokens, num_reqs + num_reqs_pad_size)

             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -509,7 +567,9 @@
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
-            )
+                sin=sin,
+                cos=cos,
+                mc2_mask=mc2_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -968,11 +1028,13 @@ def _forward_decode(
                                self.qk_rope_head_dim)
             input_layout = "BNSD"

-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
+                if self.enable_kv_nz:
+                    input_layout = "TND_NTD"
+                else:
+                    input_layout = "TND"
                 # [bs * q_seq_len, num_heads_per_rank, dim]
-                input_layout = "TND"
                 q_nope = q_nope.view(num_tokens, self.num_heads, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, -1)
                 sparse_mode = 3
@@ -1101,15 +1163,8 @@ def forward(
         decode_k_nope = None
         assert attn_metadata.decode is not None
         if self.running_in_graph:
-            seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-            cos = self.rotary_emb.cos_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            cos = cos[attn_metadata.decode.input_positions]
-            sin = sin[attn_metadata.decode.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
+            cos = attn_metadata.decode.cos
+            sin = attn_metadata.decode.sin
             # Without explicitly controlling the order, IndexByTensor operations
             # would be placed after `matmul W_KV_T` hindering the overlapping of
             # KvRmsNormRopeCache and SingleRope.
@@ -1144,15 +1199,8 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin

                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
