Commit 6b853f1

Yuxiao-Xu, songshanhu07, Xu Yuxiao, and wangxiyuan authored
Add static EPLB (#1116)
### What this PR does / why we need it?
Add EPLB expert map import capabilities.

### Does this PR introduce _any_ user-facing change?
When importing an EPLB expert map, the expert map file must be passed in through the vLLM `additional_config` argument.

### How was this patch tested?
1. Collect expert hotness and generate an expert placement file based on the hotness and the EPLB algorithm, or directly use an existing expert placement table.
2. When launching vLLM, enable EP (expert parallelism) and pass the configuration via the command-line argument: `--additional-config '{"expert_map_path": "/xxx/xxx/xx.json"}'`

Co-authored-by: songshanhu07 <1763685535@qq.com>

---------

Signed-off-by: songshanhu07 <1763685535@qq.com>
Signed-off-by: Yuxiao-Xu <664988918@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: songshanhu07 <1763685535@qq.com>
Co-authored-by: Xu Yuxiao <xuyuxiao2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
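The expert map file referenced above is the JSON consumed by the new `ExpertLoadBalancer` (see `_expert_file_to_tensor` in the diff below): it reads a top-level `moe_layer_count`, then for each entry of `layer_list` a `device_count` and a `device_list` whose items carry `device_expert`. A minimal sketch of writing such a file from Python; the layer, device, and expert ids below are made-up illustration values, not taken from this PR:

```python
import json

# Hypothetical static placement for a tiny MOE model: 2 MOE layers, 2 devices,
# 4 logical experts, and 3 physical slots per device (so some experts are
# replicated). Only the keys read by ExpertLoadBalancer are included.
expert_map = {
    "moe_layer_count": 2,
    "layer_list": [
        {
            "device_count": 2,
            "device_list": [
                {"device_expert": [0, 1, 3]},
                {"device_expert": [2, 3, 0]},
            ],
        },
        {
            "device_count": 2,
            "device_list": [
                {"device_expert": [0, 2, 1]},
                {"device_expert": [1, 3, 2]},
            ],
        },
    ],
}

with open("expert_map.json", "w") as f:
    json.dump(expert_map, f, indent=2)
```

A file produced this way (or by an EPLB placement tool) is what `--additional-config '{"expert_map_path": ...}'` should point at.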
1 parent cb341c7 commit 6b853f1

File tree: 6 files changed (+179, -31 lines)

docs/source/user_guide/additional_config.md
vllm_ascend/ascend_config.py
vllm_ascend/ops/expert_load_balancer.py
vllm_ascend/ops/fused_moe.py
vllm_ascend/quantization/quant_config.py
vllm_ascend/quantization/w8a8_dynamic.py

docs/source/user_guide/additional_config.md

Lines changed: 7 additions & 6 deletions
@@ -24,12 +24,13 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
 
 The following table lists the additional configuration options available in vLLM Ascend:
 
-| Name | Type | Default | Description |
-| ---- | ---- | ------- | ----------- |
-| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
-| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
-| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. |
-| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. |
+| Name | Type | Default | Description |
+|-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------|
+| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
+| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
+| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. |
+| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. |
+| `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
 
 The details of each config option are as follows:
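Read together with the table above, a minimal offline-inference sketch of passing the new option through `additional_config`; the model name and file path are placeholders, mirroring the `LLM(...)` example already shown at the top of this doc page:

```python
from vllm import LLM

# Placeholder model and expert-map path; `expert_map_path` is the option
# documented in the table above.
llm = LLM(
    model="Qwen/Qwen3-8B",
    additional_config={"expert_map_path": "/path/to/expert_map.json"},
)
```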

vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ def __init__(self, vllm_config):
 
         self.expert_tensor_parallel_size = int(
             additional_config.get("expert_tensor_parallel_size", 0))
+        self.expert_map_path = additional_config.get("expert_map_path", None)
 
 
 class TorchairGraphConfig:
vllm_ascend/ops/expert_load_balancer.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+import json
+import random
+from typing import Dict, List
+
+import torch
+
+
+class ExpertLoadBalancer(object):
+
+    def __init__(self, expert_map_path, global_expert_num):
+        self.expert_map_path = expert_map_path
+        self.global_expert_num = global_expert_num
+        self.expert_map_tensor, self.layers_num, self.ranks_num = (
+            self._expert_file_to_tensor())
+
+    def _expert_file_to_tensor(self):
+        with open(self.expert_map_path, "r") as f:
+            data = json.load(f)
+        layers_num = data["moe_layer_count"]
+        gpus_num = data["layer_list"][0]["device_count"]
+
+        tensor_data = []
+        for layer in data["layer_list"]:
+            device_data = []
+            for device in layer["device_list"]:
+                device_data.append(device["device_expert"])
+            tensor_data.append(device_data)
+        expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
+        return expert_map_tensor, layers_num, gpus_num
+
+    def generate_index_dicts(self, tensor_2d):
+        dict_list = []
+        current_idx = 0
+
+        for row in tensor_2d:
+            value_to_index = {}
+            for i in range(row.size(0)):
+                value = row[i].item()
+                value_to_index[value] = current_idx + i
+            dict_list.append(value_to_index)
+            current_idx += row.size(0)
+
+        return dict_list
+
+    def generate_expert_placement_map(self):
+        expert_placement_map = torch.full(
+            (self.layers_num, self.ranks_num, self.global_expert_num),
+            -1,
+            dtype=torch.int32,
+        )
+        for layer_id in range(self.layers_num):
+            for gpu_id in range(self.ranks_num):
+                e_ids = self.expert_map_tensor[layer_id, gpu_id]
+                expert_placement_map[layer_id, gpu_id,
+                                     e_ids] = torch.arange(len(e_ids),
+                                                           dtype=torch.int32)
+        return expert_placement_map
+
+    def generate_log2phy_expert_map(self, layer_id):
+        concatenated = torch.flatten(self.expert_map_tensor[layer_id])
+        rank_expert_to_global = self.generate_index_dicts(
+            self.expert_map_tensor[layer_id])
+        result_dict: Dict[int, List[int]] = {}
+        for idx, value in enumerate(concatenated):
+            key = value.item()
+            if key not in result_dict:
+                result_dict[key] = []
+            result_dict[key].append(idx)
+
+        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
+                                 -1,
+                                 dtype=torch.int32)
+        for rank in range(self.ranks_num):
+            for key in result_dict:
+                indices_in_concat = result_dict[key]
+                if key in rank_expert_to_global[rank]:
+                    log2phy_map[rank][key] = rank_expert_to_global[rank][key]
+                else:
+                    chosen_index = random.choice(indices_in_concat)
+                    log2phy_map[rank][key] = chosen_index
+        return log2phy_map
+
+    def get_rank_placement_map(self, layer_id, rank_id):
+        expert_placement_map = self.generate_expert_placement_map()
+        layer_expert_map = expert_placement_map[layer_id]
+        rank_expert_map = layer_expert_map[rank_id].to(
+            torch.npu.current_device())
+        rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
+        return rank_local_expert_num, rank_expert_map
+
+    def get_rank_log2phy_map(self, layer_id, rank_id):
+        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
+        return layer_log2phy_map[rank_id]
+
+    def get_global_redundant_expert_num(self):
+        global_redundant_expert_num = (
+            len(self.expert_map_tensor[0][0]) * self.ranks_num -
+            self.global_expert_num)
+        return global_redundant_expert_num
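A short usage sketch of the class above, assuming the toy `expert_map.json` from the commit description has been written to disk and that the `torch_npu` Ascend runtime is available (`get_rank_placement_map` moves its result to `torch.npu.current_device()`):

```python
import torch_npu  # noqa: F401  # Ascend runtime; provides torch.npu.current_device()

from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer

# 4 logical experts, matching the hypothetical two-layer map sketched earlier.
balancer = ExpertLoadBalancer("expert_map.json", global_expert_num=4)

# Rank 0's view of MOE layer 0: how many experts it hosts locally, plus a
# global->local index map holding -1 for experts not placed on this rank.
local_num, placement_map = balancer.get_rank_placement_map(layer_id=0, rank_id=0)

# Logical->physical map used to rewrite topk_ids before dispatch.
log2phy = balancer.get_rank_log2phy_map(layer_id=0, rank_id=0)

# Redundant physical expert slots across all ranks (here 3 * 2 - 4 = 2).
print(local_num, balancer.get_global_redundant_expert_num())
```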

vllm_ascend/ops/fused_moe.py

Lines changed: 33 additions & 5 deletions
@@ -15,6 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
+import os
 from typing import Callable, List, Optional
 
 import torch
@@ -34,6 +35,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -956,6 +958,10 @@ def apply(
 
 class AscendFusedMoE(FusedMoE):
 
+    # The moe_counter parameter is required during the initialization of EPLB
+    # to identify the current layer index within the MOE model.
+    moe_counter = -1
+
     def __init__(
         self,
         num_experts: int,  # Global number of experts
@@ -983,6 +989,9 @@ def __init__(
         # fixme and make __init__() of AscendFusedMoE more clear
         super(FusedMoE, self).__init__()
 
+        AscendFusedMoE.moe_counter += 1
+        self.moe_instance_id = AscendFusedMoE.moe_counter
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
@@ -1016,16 +1025,33 @@ def __init__(
         self.e_score_correction_bias = e_score_correction_bias
         self.expert_map = None
         self.activation = activation
+        self.log2phy = None
+        self.global_redundant_expert_num = 0
 
-        # Create a tensor of size num_experts filled with -1
-        self.local_num_experts, self.expert_map = determine_expert_map(
-            self.ep_size,
-            get_ep_group().rank_in_group, self.global_num_experts)
+        ascend_config = get_ascend_config()
+        expert_map_path = ascend_config.expert_map_path
+        if expert_map_path and os.path.exists(expert_map_path):
+            # moe expert load balance
+            expert_load_balancer = ExpertLoadBalancer(expert_map_path,
+                                                      self.global_num_experts)
+            self.local_num_experts, self.expert_map = \
+                expert_load_balancer.get_rank_placement_map(
+                    self.moe_instance_id,
+                    get_ep_group().rank_in_group)
+            self.log2phy = expert_load_balancer.get_rank_log2phy_map(
+                self.moe_instance_id,
+                get_ep_group().rank_in_group)
+            self.global_redundant_expert_num = \
+                expert_load_balancer.get_global_redundant_expert_num()
+        else:
+            # Create a tensor of size num_experts filled with -1
+            self.local_num_experts, self.expert_map = determine_expert_map(
+                self.ep_size,
+                get_ep_group().rank_in_group, self.global_num_experts)
 
         self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
         self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
-        ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
@@ -1122,6 +1148,8 @@ def forward(self,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance,
+            log2phy=self.log2phy,
+            global_redundant_expert_num=self.global_redundant_expert_num,
             **kwargs)
 
         if self.enable_multistream_shared_expert and not is_prefill:

vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 1 deletion
@@ -323,13 +323,16 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = False,
+        log2phy: torch.Tensor = None,
+        global_redundant_expert_num=0,
         **kwargs,
     ) -> torch.Tensor:
         return self.quant_method.apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance, **kwargs)
+            is_prefill, enable_force_load_balance, log2phy,
+            global_redundant_expert_num, **kwargs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 35 additions & 19 deletions
@@ -147,9 +147,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
                            top_k: int,
                            expert_map: torch.Tensor = None,
                            moe_all_to_all_group_name: str = "",
+                           log2phy: torch.Tensor = None,
+                           global_redundant_expert_num: int = 0,
                            **kwargs) -> torch.Tensor:
+
+    topk_ids = log2phy[topk_ids]
     global_bs = 0
-    moe_expert_num = len(expert_map)
+    moe_expert_num = len(expert_map) + global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
     kwargs_mc2 = {
         "x": hidden_states,
@@ -271,7 +275,10 @@ def fused_experts_with_all2all(
         top_k: int,
         expert_map: torch.Tensor = None,
         ep_group: GroupCoordinator = None,
+        log2phy: torch.Tensor = None,
+        global_redundant_expert_num: int = 0,
 ):
+    topk_ids = log2phy[topk_ids]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -281,7 +288,7 @@
     device = hidden_states.device
 
     if expert_map is not None:
-        global_num_experts = len(expert_map)
+        global_num_experts = len(expert_map) + global_redundant_expert_num
         local_num_experts = global_num_experts // ep_group.world_size
         row_idx_len = num_tokens * top_k
         row_idx = (torch.arange(0,
@@ -341,13 +348,14 @@
     group_list_type = 0
 
     # `hidden_states` will be disposed in the `apply_mlp` function
-    hidden_states = apply_mlp(hidden_states,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_tokens,
-                              group_list_type=group_list_type)
+    hidden_states = apply_mlp(
+        hidden_states,
+        w1,
+        w1_scale,  #17
+        w2,
+        w2_scale,
+        expert_tokens,  #16
+        group_list_type=group_list_type)
 
     if expert_map is not None:
         resorted_idx = torch.argsort(sorted_idx)
@@ -639,6 +647,8 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
         enable_force_load_balance: bool = True,
+        log2phy: torch.Tensor = None,
+        global_redundant_expert_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -693,6 +703,8 @@
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
                 **kwargs)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
@@ -709,16 +721,20 @@
             # according to tp_size before they are feed into fused_moe module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w1_scale=layer.w13_weight_scale,
-                                              w2=layer.w2_weight,
-                                              w2_scale=layer.w2_weight_scale,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=self.ep_group)
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w1_scale=layer.w13_weight_scale,
+                w2=layer.w2_weight,
+                w2_scale=layer.w2_weight_scale,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=self.ep_group,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+            )
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
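To make the `topk_ids = log2phy[topk_ids]` step added in the hunks above concrete, here is a toy remap on made-up values; in the real path the per-rank map comes from `ExpertLoadBalancer.get_rank_log2phy_map`:

```python
import torch

# Hypothetical log2phy row for 4 logical experts on one rank: logical expert e
# is served by physical slot log2phy[e] (a local slot when the expert lives on
# this rank, otherwise a randomly chosen replica on another rank).
log2phy = torch.tensor([0, 1, 4, 2], dtype=torch.int32)

# Router output in logical expert ids ...
topk_ids = torch.tensor([[0, 2], [3, 1]])

# ... rewritten to physical expert ids before the MC2 / all2all dispatch.
topk_ids = log2phy[topk_ids]
print(topk_ids)  # tensor([[0, 4], [2, 1]], dtype=torch.int32)
```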
