
Commit 9c886d0

[EPLB] support deepseek eplb strategy (#1196)
### What this PR does / why we need it?
This PR implements the DeepSeek Expert Parallel Load Balancing (EPLB) strategy to optimize expert distribution in vllm-ascend. The implementation:
- Adapts the expert-map format to work with vllm-ascend's architecture
- Provides the DeepSeek mechanism for balancing expert workload across devices

### Does this PR introduce _any_ user-facing change?
This PR adds a new script that allows users to:
- Generate expert-map configurations based on workload analysis
- Optimize expert distribution for their specific use case

### How was this patch tested?
To use this feature:
1. First collect expert heat information during model execution
2. Run the provided script to generate the expert-map configuration
3. Apply the generated configuration to your vllm-ascend deployment

User example:
```bash
# expert_load_view.pt: dumped expert heat info file
python3 examples/eplb/eplb_strategy.py --exp_name 'deepseek_demo' \
    --input_path expert_load_view.pt --output_path examples/eplb/results/demo \
    --num_nodes 4
```

---------

Signed-off-by: ZhengWG <zwg0606@gmail.com>
1 parent 4e29c5a commit 9c886d0
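The `expert_load_view.pt` heat file consumed by `eplb_strategy.py` is dumped during model execution; for a quick dry run of the script itself, a synthetic file with the expected shape can be fabricated. A minimal sketch, assuming the `[num_moe_layers, num_logical_experts]` layout noted in the script (the 58 × 256 sizing matches the script's defaults and is only illustrative):

```python
import torch

# Fabricated heat data: one row per MoE layer, one column per logical expert.
num_moe_layers, num_logical_experts = 58, 256  # illustrative sizes
fake_heat = torch.randint(0, 10_000, (num_moe_layers, num_logical_experts))
torch.save(fake_heat, "expert_load_view.pt")
```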

2 files changed: +388 −0 lines changed

examples/eplb/eplb_deepseek.py

Lines changed: 205 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0
"""
Expert parallelism load balancer (EPLB) for vLLM.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
"""
from typing import Tuple

import torch

def balanced_packing(weight: torch.Tensor,
                     num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects into m packs, such that each pack contains exactly
    n/m objects and the weights of all packs are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        pack_index = torch.arange(weight.size(-1),
                                  dtype=torch.int64,
                                  device=weight.device).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight,
                                 fill_value=-1,
                                 dtype=torch.int64,
                                 device='cpu')
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for i in range(num_layers):
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            # Greedily assign the heaviest remaining group to the lightest
            # pack that still has room.
            pack = min(
                (p
                 for p in range(num_packs) if pack_items[p] < groups_per_pack),
                key=pack_weights.__getitem__)
            assert pack_items[pack] < groups_per_pack
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += weight[i, group]
            pack_items[pack] += 1
    return pack_index, rank_in_pack
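
# Worked example (illustrative, not from the PR): for one layer with group
# weights [9, 1, 5, 5] and num_packs=2, groups are visited in descending
# weight order (9, 5, 5, 1) and each goes to the lightest pack that still has
# room, giving pack_index [0, 0, 1, 1] with both pack loads equal to 10 (the
# ranks within each pack depend on how the tie between the two 5s is sorted).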

def replicate_experts(
        weight: torch.Tensor,
        num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts to `num_phy` replicas, such that the maximum
    load of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    n, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device
    phy2log = torch.arange(num_phy, dtype=torch.int64,
                           device=device).repeat(n, 1)
    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
    arangen = torch.arange(n, dtype=torch.int64, device=device)
    for i in range(num_log, num_phy):
        # Give the next redundant slot to the expert with the highest
        # per-replica load.
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        rank[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, rank, logcnt
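
# Worked example (illustrative): weight [[4, 2, 2]] replicated to num_phy=4
# adds one redundant slot; the argmax of weight/logcnt picks expert 0, so
# phy2log becomes [[0, 1, 2, 0]], rank [[0, 0, 0, 1]], logcnt [[2, 1, 1]],
# halving expert 0's per-replica load from 4 to 2.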

def rebalance_experts_hierarchical(weight: torch.Tensor,
                                   num_physical_experts: int, num_groups: int,
                                   num_nodes: int, num_gpus: int):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(
            1, perm,
            torch.arange(perm.size(1), dtype=torch.int64,
                         device=perm.device).expand(perm.shape))
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(
        tokens_per_group, num_nodes)
    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
                 group_size).unsqueeze(-1) +
                torch.arange(group_size,
                             dtype=torch.int64,
                             device=group_pack_index.device)).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes)
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes)

    # Step 3: pack physical experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
                                                num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(
        -1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
        0,
        num_logical_experts,
        num_logical_experts // num_nodes,
        device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
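
# Shape walkthrough (illustrative config: 2 layers, 8 logical experts,
# 4 groups, 2 nodes, 4 GPUs, 12 physical experts): tokens_per_group is
# [2, 4]; tokens_per_mlog is [4, 4] after flattening layers x nodes;
# phy2mlog is [4, 6]; the returned pphy2log is [2, 12], i.e. 3 physical
# experts per GPU.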

def rebalance_experts(
        weight: torch.Tensor, num_replicas: int, num_groups: int,
        num_nodes: int,
        num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus)
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus)
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device)
    log2phy.view(num_layers, -1).scatter_(
        -1, phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64,
                     device=log2phy.device).expand(num_layers, -1))
    return phy2log, log2phy, logcnt


__all__ = ['rebalance_experts']
```
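
As a standalone sanity check of the entry point, a minimal sketch with made-up loads (the sizes are chosen only to satisfy the divisibility asserts; nothing here comes from the PR):

```python
import torch

from eplb_deepseek import rebalance_experts

# 2 MoE layers, 8 logical experts, 4 groups, 2 nodes, 4 GPUs, 12 replicas.
weight = torch.randint(1, 100, (2, 8))
phy2log, log2phy, logcnt = rebalance_experts(weight,
                                             num_replicas=12,
                                             num_groups=4,
                                             num_nodes=2,
                                             num_gpus=4)
print(phy2log.shape)   # torch.Size([2, 12]): logical expert id per slot
print(log2phy.shape)   # torch.Size([2, 8, maxlogcnt])
print(logcnt.sum(-1))  # tensor([12, 12]): every replica is accounted for
```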

examples/eplb/eplb_strategy.py

Lines changed: 183 additions & 0 deletions
```python
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import json
import logging
import os

import matplotlib.pyplot as plt  # type: ignore
import numpy as np
import torch

logger = logging.getLogger("msit_logger")

def save_matrix_to_json(output_path, file_name, deployment):
    num_layers = deployment.shape[0]
    num_cards = deployment.shape[1]

    data = {"moe_layer_count": num_layers}
    layer_list = []
    for i in range(num_layers):
        layer = {"layer_id": i, "device_count": num_cards}
        device_list = []
        for j in range(num_cards):
            device = {
                "device_id": j,
                "device_expert": deployment[i, j].tolist()
            }
            device_list.append(device)
        layer["device_list"] = device_list
        layer_list.append(layer)
    data["layer_list"] = layer_list

    file_name = os.path.join(output_path, f"{file_name}.json")

    # Save as JSON file
    try:
        with open(file_name, 'w') as f:
            json.dump(data, f, indent=4)
    except Exception as e:
        print(f"write {file_name} failed: {e}")

def calculate_average(lst):
    """Calculate the average of the numeric elements of a list."""
    if not lst:
        raise ValueError("list is empty")

    total = 0.0
    count = 0

    for element in lst:
        # Check if element is numeric
        if isinstance(element, (int, float, np.int64, np.float64)):
            total += float(element)
            count += 1
        else:
            # Non-numeric elements are ignored with a warning
            print(f"warning: element {element} is not a number, ignored")

    if count == 0:
        raise ValueError("list does not contain any number")

    return total / count

def layer_imbalance_plot(y_list, label_names, device_num, output_path,
                         file_name):

    plt.rcParams['font.sans-serif'] = ['Arial']
    plt.rcParams['axes.unicode_minus'] = False
    # One x tick per MoE layer, derived from the data rather than hardcoded.
    x = list(range(len(y_list[0])))
    for index, y in enumerate(y_list):
        plt.plot(x,
                 y,
                 label=rf'{label_names[index]}, avg={calculate_average(y)}')

    plt.legend()
    plt.title(rf'Load Distribution (num_gpus={device_num})')
    plt.xlabel('layer')
    plt.ylabel('Device Load')

    # Show grid lines
    plt.grid(True)

    plt.savefig(os.path.join(output_path, file_name), dpi=300)

    # Clear the current figure
    plt.close()

def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
                    num_gpus, num_original_expert):
    from eplb_deepseek import rebalance_experts
    num_replicas = num_original_expert + num_redundancy_expert
    phy2log, log2phy, logcnt = rebalance_experts(workload, num_replicas,
                                                 num_groups, num_nodes,
                                                 num_gpus)

    # Convert to global_deployment
    workload = workload.cpu().numpy()
    global_deployment = []
    layer_num = log2phy.shape[0]
    num_physical_experts_local = (num_original_expert +
                                  num_redundancy_expert) // num_gpus
    for layer_idx in range(layer_num):
        layer_deployment = []
        for gpu_idx in range(num_gpus):
            local_deployment = phy2log[layer_idx][gpu_idx *
                                                  num_physical_experts_local:
                                                  (gpu_idx + 1) *
                                                  num_physical_experts_local]
            local_deployment = local_deployment.flatten()
            layer_deployment.append(local_deployment.tolist())
        global_deployment.append(layer_deployment)

    # Recompute per-device load after remapping experts according to log2phy
    original_weights = []
    max_weights = []
    average_weights = []
    y_list = []
    for layer_idx in range(layer_num):
        new_value = workload[layer_idx].reshape(num_gpus, -1)
        row_sum = np.sum(new_value, axis=1)
        original_weights.append(row_sum.max())
        average_weights.append(np.sum(workload[layer_idx]) / num_gpus)

        opt_workload = np.zeros((num_original_expert + num_redundancy_expert),
                                dtype=np.float64)
        for expert_idx in range(num_original_expert):
            physical_expert_idxs = log2phy[layer_idx][expert_idx]
            physical_expert_idxs = physical_expert_idxs.flatten()
            physical_expert_idxs = physical_expert_idxs[
                physical_expert_idxs != -1]
            for physical_expert_idx in physical_expert_idxs:
                opt_workload[physical_expert_idx] += workload[layer_idx][
                    expert_idx] / len(physical_expert_idxs)
        opt_workload = opt_workload.reshape(num_gpus, -1)
        row_sum = np.sum(opt_workload, axis=1)
        max_weights.append(row_sum.max())

    y_list = [original_weights, max_weights, average_weights]
    return global_deployment, y_list
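
# y_list carries three per-layer curves for plotting: the max device load of
# the default (contiguous) deployment, the max device load after balancing,
# and the theoretical average load per device.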

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", type=str, default="gsm8k_temp0.0")
    parser.add_argument("--num_original_expert", type=int, default=256)
    parser.add_argument("--input_path", type=str, default="")
    parser.add_argument("--output_path", type=str, default="")
    parser.add_argument("--num_redundancy_expert", type=int, default=0)
    parser.add_argument("--num_devices", type=int, default=32)
    parser.add_argument("--num_groups", type=int, default=8)
    parser.add_argument("--num_nodes", type=int, default=4)
    args = parser.parse_args()
    exp_name = args.exp_name
    input_path = args.input_path
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)
    num_redundancy_expert = args.num_redundancy_expert
    num_devices = args.num_devices
    num_original_expert = args.num_original_expert
    num_groups = args.num_groups
    num_nodes = args.num_nodes

    # NOTE: assume input workload format: [layer_num, num_experts]
    workload = torch.load(input_path, map_location=torch.device('cpu'))
    global_deployment, y_list = deepseek_deploy(workload,
                                                num_redundancy_expert,
                                                num_groups, num_nodes,
                                                num_devices,
                                                num_original_expert)

    file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}"
    save_matrix_to_json(output_path, file_name, np.array(global_deployment))
    label_names = [
        'default deployment max load', 'balanced load max load',
        'balanced load avg load'
    ]
    new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
    layer_imbalance_plot(y_list, label_names, num_devices, output_path,
                         new_file_name)
```
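
To inspect a generated expert map, a short sketch (the concrete path is hypothetical, following the `{exp_name}_{num_devices}_{num_redundancy_expert}.json` naming above with the example invocation and the script's default device settings):

```python
import json

# Hypothetical output of the example run with defaults
# --num_devices 32 and --num_redundancy_expert 0.
with open("examples/eplb/results/demo/deepseek_demo_32_0.json") as f:
    expert_map = json.load(f)

# Print the experts hosted on each device of the first MoE layer.
for device in expert_map["layer_list"][0]["device_list"]:
    print(device["device_id"], device["device_expert"])
```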
