Commit 57664f0

[cherry-pick] static EPLB fix bug, add unit test to v0.9.1-dev (#1667)
### What this PR does / why we need it?
[cherry-pick master -> 0.9.1-dev](#1186)
1. Add a static EPLB unit test.
2. Fix a bug: a torch.Tensor cannot be evaluated directly in an `if` statement.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Run the unit test.

Signed-off-by: songshanhu07 <1763685535@qq.com>
1 parent 31208b4 commit 57664f0

2 files changed: +149 additions, -2 deletions

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# fused moe ops test will hit the infer_schema error, we need add the patch
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa

import json
import unittest
from typing import List, TypedDict
from unittest import mock

import torch

from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer


class Device(TypedDict):
    device_id: int
    device_expert: List[int]


class Layer(TypedDict):
    layer_id: int
    device_count: int
    device_list: List[Device]


class MockData(TypedDict):
    moe_layer_count: int
    layer_list: List[Layer]


MOCK_DATA: MockData = {
    "moe_layer_count": 1,
    "layer_list": [{
        "layer_id": 0,
        "device_count": 2,
        "device_list": [{
            "device_id": 0,
            "device_expert": [7, 2, 0, 3, 5]
        }, {
            "device_id": 1,
            "device_expert": [6, 1, 4, 7, 2]
        }]
    }]
}


class TestExpertLoadBalancer(unittest.TestCase):

    def setUp(self):
        json_file = "expert_map.json"
        with open(json_file, 'w') as f:
            json.dump(MOCK_DATA, f)

        self.expert_load_balancer = ExpertLoadBalancer(json_file,
                                                       global_expert_num=8)

    def test_init(self):
        self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
                              torch.Tensor)
        self.assertEqual(self.expert_load_balancer.layers_num,
                         MOCK_DATA["moe_layer_count"])
        self.assertEqual(self.expert_load_balancer.ranks_num,
                         MOCK_DATA["layer_list"][0]["device_count"])

    def test_generate_index_dicts(self):
        tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
        result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
        expected_result = [{
            7: 0,
            2: 1,
            0: 2,
            3: 3,
            5: 4
        }, {
            6: 5,
            1: 6,
            4: 7,
            7: 8,
            2: 9
        }]
        self.assertEqual(result, expected_result)

    def test_generate_expert_placement_map(self):
        expert_placement_map = self.expert_load_balancer.generate_expert_placement_map()
        self.assertEqual(expert_placement_map.shape,
                         (self.expert_load_balancer.layers_num,
                          self.expert_load_balancer.ranks_num, 8))
        self.assertTrue(torch.all(expert_placement_map >= -1))

    def test_generate_log2phy_expert_map(self):
        layer_id = 0
        log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
            layer_id)
        self.assertEqual(log2phy_map.shape,
                         (self.expert_load_balancer.ranks_num, 8))
        self.assertTrue(torch.all(log2phy_map >= -1))

    @mock.patch("torch_npu.npu._lazy_init")
    @mock.patch("torch.npu.current_device", return_value="cpu")
    def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
        layer_id = 0
        rank_id = 0
        rank_local_expert_num, rank_expert_map = \
            self.expert_load_balancer.get_rank_placement_map(layer_id, rank_id)
        self.assertEqual(rank_local_expert_num, 5)
        expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
                                       dtype=torch.int32).to(
                                           rank_expert_map.device)
        self.assertTrue(rank_expert_map.equal(expected_tensor))

        rank_id = 1
        rank_local_expert_num, rank_expert_map = \
            self.expert_load_balancer.get_rank_placement_map(layer_id, rank_id)
        expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
                                       dtype=torch.int32).to(
                                           rank_expert_map.device)
        self.assertTrue(rank_expert_map.equal(expected_tensor))

    def test_get_rank_log2phy_map(self):
        layer_id = 0
        rank_id = 0
        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
                                       dtype=torch.int32).to(log2phy_map.device)
        self.assertTrue(log2phy_map.equal(expected_tensor))

        rank_id = 1
        log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
            layer_id, rank_id)
        expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
                                       dtype=torch.int32).to(log2phy_map.device)
        self.assertTrue(log2phy_map.equal(expected_tensor))

    def test_get_global_redundant_expert_num(self):
        redundant_expert_num = \
            self.expert_load_balancer.get_global_redundant_expert_num()
        expected_redundant_expert_num = \
            len(MOCK_DATA["layer_list"][0]["device_list"][0]["device_expert"]) * \
            MOCK_DATA["layer_list"][0]["device_count"] - 8
        self.assertEqual(redundant_expert_num, expected_redundant_expert_num)
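
The expected tensors asserted above can be read directly off MOCK_DATA: physical expert slots are numbered by concatenating the two device_expert lists, so device 0 owns slots 0-4 and device 1 owns slots 5-9. The sketch below is a standalone reading of that mock placement, not the ExpertLoadBalancer implementation; the tie-breaking rule for duplicated experts (a rank keeps its own copy) is inferred from the expected values, not taken from the library source.

import torch

# Mock placement: device 0 hosts physical slots 0-4, device 1 hosts slots 5-9.
device_experts = [[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]]
global_expert_num = 8

# Rank-0 placement map: logical expert id -> local slot index, -1 if not hosted locally.
placement_rank0 = [-1] * global_expert_num
for local_idx, logical_id in enumerate(device_experts[0]):
    placement_rank0[logical_id] = local_idx
print(torch.tensor(placement_rank0, dtype=torch.int32))
# [2, -1, 1, 3, -1, 4, -1, 0]  -> matches test_get_rank_placement_map

# Rank-0 log2phy map: logical expert id -> physical slot, preferring rank 0's own copy.
log2phy_rank0 = [-1] * global_expert_num
for local_idx, logical_id in enumerate(device_experts[0]):  # local copies first
    log2phy_rank0[logical_id] = local_idx
for local_idx, logical_id in enumerate(device_experts[1]):  # fall back to remote copies
    if log2phy_rank0[logical_id] == -1:
        log2phy_rank0[logical_id] = len(device_experts[0]) + local_idx
print(torch.tensor(log2phy_rank0, dtype=torch.int32))
# [2, 6, 1, 3, 7, 4, 5, 0]  -> matches test_get_rank_log2phy_map

# 2 devices x 5 slots - 8 logical experts = 2 redundant physical experts.
print(2 * 5 - global_expert_num)  # -> matches test_get_global_redundant_expert_num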

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 2 additions & 2 deletions
@@ -217,7 +217,7 @@ def fused_experts_with_mc2(
                            dynamic_scale_for_share: Optional[Any] = None,
                            mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    if log2phy:
+    if log2phy is not None:
         topk_ids = log2phy[topk_ids]
     quant_mode = 2
     ep_group = get_ep_group()
@@ -352,7 +352,7 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
                                global_redundant_expert_num: int = 0,
                                w1_scale_bias: torch.Tensor = None,
                                w2_scale_bias: torch.Tensor = None):
-    if log2phy:
+    if log2phy is not None:
         topk_ids = log2phy[topk_ids]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
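
Both hunks replace a truthiness check with an explicit `is not None` test, which is the bug named in the commit message: a multi-element torch.Tensor has no unambiguous boolean value, so `if log2phy:` raises at runtime (and a single-element map would be silently skipped whenever that element is 0). The following is a minimal standalone sketch of the failure and the intended guard; the tensor values are borrowed from the unit test above, not from the quantization code.

import torch

log2phy = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0], dtype=torch.int32)
topk_ids = torch.tensor([[0, 7], [3, 5]])

try:
    if log2phy:  # old check: truth-testing a multi-element tensor
        pass
except RuntimeError as err:
    # "Boolean value of Tensor with more than one element is ambiguous"
    print(err)

if log2phy is not None:  # fixed check: only asks whether a map was provided
    topk_ids = log2phy[topk_ids]  # remap logical expert ids to physical ids
    print(topk_ids)  # tensor([[2, 0], [3, 4]])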
