Skip to content

Commit e9fd658

Browse files
authored
[Feature] Expert Parallelism Load Balancer (EPLB) (#18343)
Signed-off-by: Bowen Wang <abmfy@icloud.com>
1 parent 07b8fae commit e9fd658

File tree

24 files changed

+2446
-54
lines changed

24 files changed

+2446
-54
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,23 @@ steps:
168168
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
169169
- popd
170170

171+
- label: EPLB Algorithm Test
172+
working_dir: "/vllm-workspace/tests"
173+
source_file_dependencies:
174+
- vllm/distributed/eplb
175+
- tests/distributed/test_eplb_algo.py
176+
commands:
177+
- pytest -v -s distributed/test_eplb_algo.py
178+
179+
- label: EPLB Execution Test # 5min
180+
working_dir: "/vllm-workspace/tests"
181+
num_gpus: 4
182+
source_file_dependencies:
183+
- vllm/distributed/eplb
184+
- tests/distributed/test_eplb_execute.py
185+
commands:
186+
- pytest -v -s distributed/test_eplb_execute.py
187+
171188
- label: Metrics, Tracing Test # 10min
172189
mirror_hardwares: [amdexperimental, amdproduction]
173190
num_gpus: 2

tests/distributed/test_eplb_algo.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import pytest
5+
import torch
6+
7+
from vllm.distributed.eplb.rebalance_algo import rebalance_experts
8+
9+
10+
def test_basic_rebalance():
11+
"""Test basic rebalancing functionality"""
12+
# Example from https://github.com/deepseek-ai/eplb
13+
weight = torch.tensor([
14+
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
15+
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
16+
])
17+
18+
num_layers = weight.shape[0]
19+
num_replicas = 16
20+
num_groups = 4
21+
num_nodes = 2
22+
num_gpus = 8
23+
24+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
25+
num_groups, num_nodes,
26+
num_gpus)
27+
28+
# Verify output shapes
29+
assert phy2log.shape == (
30+
2,
31+
16,
32+
), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
33+
assert (log2phy.shape[0] == 2
34+
), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
35+
assert (
36+
log2phy.shape[1] == 12
37+
), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
38+
assert logcnt.shape == (
39+
2,
40+
12,
41+
), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
42+
43+
# Verify physical to logical expert mapping range is correct
44+
assert torch.all(phy2log >= 0) and torch.all(
45+
phy2log < 12), "Physical to logical mapping should be in range [0, 12)"
46+
47+
# Verify expert count reasonableness
48+
assert torch.all(
49+
logcnt >= 1), "Each logical expert should have at least 1 replica"
50+
assert (
51+
torch.sum(logcnt, dim=1).sum() == num_replicas *
52+
num_layers), f"Total replicas should be {num_replicas * num_layers}"
53+
54+
# Verify expected output
55+
expected_phy2log = torch.tensor([
56+
[5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
57+
[7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
58+
])
59+
assert torch.all(phy2log == expected_phy2log)
60+
61+
expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
62+
[1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
63+
assert torch.all(logcnt == expected_logcnt)
64+
65+
66+
def test_single_gpu_case():
67+
"""Test single GPU case"""
68+
weight = torch.tensor([[10, 20, 30, 40]])
69+
num_replicas = 4
70+
num_groups = 1
71+
num_nodes = 1
72+
num_gpus = 1
73+
74+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
75+
num_groups, num_nodes,
76+
num_gpus)
77+
78+
# Verify shapes
79+
assert phy2log.shape == (1, 4)
80+
assert log2phy.shape[0] == 1
81+
assert log2phy.shape[1] == 4
82+
assert logcnt.shape == (1, 4)
83+
84+
# Verify all logical experts are mapped
85+
assert set(phy2log[0].tolist()) == {0, 1, 2, 3}
86+
87+
88+
def test_equal_weights():
89+
"""Test case with equal weights"""
90+
weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]])
91+
num_replicas = 8
92+
num_groups = 2
93+
num_nodes = 2
94+
num_gpus = 4
95+
96+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
97+
num_groups, num_nodes,
98+
num_gpus)
99+
100+
# Verify shapes
101+
assert phy2log.shape == (1, 8)
102+
assert logcnt.shape == (1, 8)
103+
104+
# With equal weights, each expert should have exactly one replica
105+
assert torch.all(
106+
logcnt == 1
107+
), "With equal weights and no replication, " \
108+
"each expert should have exactly 1 replica"
109+
110+
111+
def test_extreme_weight_imbalance():
112+
"""Test extreme weight imbalance case"""
113+
weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]])
114+
num_replicas = 12
115+
num_groups = 2
116+
num_nodes = 2
117+
num_gpus = 4
118+
119+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
120+
num_groups, num_nodes,
121+
num_gpus)
122+
123+
# Verify shapes
124+
assert phy2log.shape == (1, 12)
125+
assert logcnt.shape == (1, 8)
126+
127+
# Expert with highest weight (index 0) should have more replicas
128+
assert (
129+
logcnt[0, 0]
130+
> logcnt[0, 1]), "Expert with highest weight should have more replicas"
131+
132+
133+
def test_multiple_layers():
134+
"""Test multiple layers case"""
135+
weight = torch.tensor([
136+
[10, 20, 30, 40, 50, 60], # First layer
137+
[60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
138+
[25, 25, 25, 25, 25, 25], # Third layer (equal weights)
139+
])
140+
num_replicas = 8
141+
num_groups = 2
142+
num_nodes = 2
143+
num_gpus = 4
144+
145+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
146+
num_groups, num_nodes,
147+
num_gpus)
148+
149+
# Verify shapes
150+
assert phy2log.shape == (3, 8)
151+
assert logcnt.shape == (3, 6)
152+
153+
# Verify expert allocation is reasonable for each layer
154+
for layer in range(3):
155+
assert torch.all(phy2log[layer] >= 0) and torch.all(
156+
phy2log[layer] < 6
157+
), f"Layer {layer} physical to logical mapping" \
158+
"should be in range [0, 6)"
159+
assert (torch.sum(logcnt[layer]) == num_replicas
160+
), f"Layer {layer} total replicas should be {num_replicas}"
161+
162+
163+
def test_parameter_validation():
164+
"""Test parameter validation"""
165+
weight = torch.tensor([[10, 20, 30, 40]])
166+
167+
# Test non-divisible case - this should handle normally without throwing
168+
# errors because the function will fall back to global load balancing
169+
# strategy
170+
phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
171+
assert phy2log.shape == (1, 8)
172+
assert logcnt.shape == (1, 4)
173+
174+
# Test cases that will actually cause errors:
175+
# num_physical_experts not divisible by num_gpus
176+
with pytest.raises(AssertionError):
177+
rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
178+
179+
180+
def test_small_scale_hierarchical():
181+
"""Test small-scale hierarchical load balancing"""
182+
weight = torch.tensor([
183+
[100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
184+
])
185+
num_replicas = 12
186+
num_groups = 4 # 4 groups, 2 experts each
187+
num_nodes = 2 # 2 nodes
188+
num_gpus = 4 # 4 GPUs
189+
190+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
191+
num_groups, num_nodes,
192+
num_gpus)
193+
194+
# Verify basic constraints
195+
assert phy2log.shape == (1, 12)
196+
assert logcnt.shape == (1, 8)
197+
assert torch.sum(logcnt) == num_replicas
198+
assert torch.all(logcnt >= 1)
199+
200+
# Expert with highest weight should have more replicas
201+
max_weight_expert = torch.argmax(weight[0])
202+
assert (logcnt[0, max_weight_expert]
203+
>= 2), "Highest weight expert should have multiple replicas"
204+
205+
206+
def test_global_load_balance_fallback():
207+
"""Test global load balancing fallback case"""
208+
# When num_groups % num_nodes != 0, should fall back to global load
209+
# balancing
210+
weight = torch.tensor([[10, 20, 30, 40, 50, 60]])
211+
num_replicas = 8
212+
num_groups = 3 # Cannot be divided evenly by num_nodes=2
213+
num_nodes = 2
214+
num_gpus = 4
215+
216+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
217+
num_groups, num_nodes,
218+
num_gpus)
219+
220+
# Should work normally, just using global load balancing strategy
221+
assert phy2log.shape == (1, 8)
222+
assert logcnt.shape == (1, 6)
223+
assert torch.sum(logcnt) == num_replicas
224+
225+
226+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
227+
def test_device_compatibility(device):
228+
"""Test device compatibility"""
229+
if device == "cuda" and not torch.cuda.is_available():
230+
pytest.skip("CUDA not available")
231+
232+
weight = torch.tensor([[10, 20, 30, 40]], device=device)
233+
num_replicas = 6
234+
num_groups = 2
235+
num_nodes = 1
236+
num_gpus = 2
237+
238+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
239+
num_groups, num_nodes,
240+
num_gpus)
241+
242+
# Function will convert to CPU internally, but should handle different
243+
# device inputs normally
244+
assert phy2log.shape == (1, 6)
245+
assert logcnt.shape == (1, 4)
246+
247+
248+
def test_additional_cases():
249+
"""Test more edge cases and different parameter combinations"""
250+
251+
# Test case 1: Large-scale distributed setup
252+
weight1 = torch.tensor(
253+
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
254+
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
255+
256+
assert phy2log1.shape == (1, 24)
257+
assert logcnt1.shape == (1, 16)
258+
assert torch.sum(logcnt1) == 24
259+
260+
# Test case 2: Different weight distributions
261+
weight2 = torch.tensor([
262+
[200, 150, 100, 50, 25, 12], # Decreasing weights
263+
[12, 25, 50, 100, 150, 200], # Increasing weights
264+
])
265+
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
266+
267+
assert phy2log2.shape == (2, 10)
268+
assert logcnt2.shape == (2, 6)
269+
270+
# Verify high-weight experts have more replicas
271+
for layer in range(2):
272+
max_weight_idx = torch.argmax(weight2[layer])
273+
assert logcnt2[layer, max_weight_idx] >= 2
274+
275+
276+
if __name__ == "__main__":
277+
weight = torch.tensor([
278+
[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
279+
[20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
280+
])
281+
282+
num_replicas = 16
283+
num_groups = 4
284+
num_nodes = 2
285+
num_gpus = 8
286+
287+
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
288+
num_groups, num_nodes,
289+
num_gpus)
290+
print(phy2log)
291+
292+
test_basic_rebalance()

0 commit comments

Comments
 (0)