Commit 2918e6d

qmkakaxiyangcheng (AJ) authored and committed

fix bugs
1 parent 8b6b6ca · commit 2918e6d

File tree

1 file changed: +37 -13 lines changed


vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 37 additions & 13 deletions
@@ -17,9 +17,12 @@
 
 import torch
 import torch.distributed as dist
+import numpy as np
 
 from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
 from vllm.logger import logger
+import random
+
 
 class VllmEplbAdaptor(EplbAdaptor):
 
@@ -29,6 +32,7 @@ def __init__(self, model, **args):
         self.param_dict = dict(self.model.named_parameters())
         self.num_dense_layers = self.model.config.first_k_dense_replace
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.global_expert_num = 256
 
         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
         self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
@@ -66,18 +70,16 @@ def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
         if dist.is_initialized():
             world_size = dist.get_world_size()
+            rank = dist.get_rank()
 
-            rank = dist.get_rank()
-
-            tensor_list = [
-                torch.zeros_like(expert_map) for _ in range(world_size)
-            ]
+            gathered = torch.empty((world_size, *expert_map.shape),  # [W, L, E]
+                                   dtype=expert_map.dtype,
+                                   device=expert_map.device)
 
-            dist.all_gather(tensor_list, expert_map)
-            gathered = torch.stack(tensor_list, dim=0)
-            all_maps = gathered.permute(1, 0, 2).contiguous()
+            dist.all_gather_into_tensor(gathered, expert_map)
+            all_maps = gathered.permute(1, 0, 2)
+            all_expert_maps = all_maps.cpu()
 
-        all_expert_maps = all_maps.to(torch.device("cpu"))
         return all_expert_maps
 
     def do_update_expert_map(self, layer_id, updated_expert_map):
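Note: the gather rewrite above replaces the list-based dist.all_gather plus torch.stack with a single dist.all_gather_into_tensor into a preallocated [world_size, layers, experts] buffer, which avoids the per-rank intermediate tensors and the stacking copy. A minimal single-process sketch of why the two paths agree (the dimensions and values here are illustrative assumptions, not taken from the commit):

import torch

# Assumed dimensions: W ranks, L MoE layers, E expert slots per map.
W, L, E = 4, 2, 8
per_rank = [torch.randint(-1, 16, (L, E)) for _ in range(W)]  # each rank's expert_map

# Old path: gather into a Python list, then stack into [W, L, E].
stacked = torch.stack(per_rank, dim=0)

# New path: write each rank's map into a preallocated buffer along dim 0,
# which is what dist.all_gather_into_tensor does collectively.
gathered = torch.empty((W, L, E), dtype=per_rank[0].dtype)
for r, m in enumerate(per_rank):
    gathered[r] = m

assert torch.equal(stacked, gathered)
# permute(1, 0, 2) then yields the per-layer view [L, W, E] used downstream.
assert torch.equal(stacked.permute(1, 0, 2), gathered.permute(1, 0, 2))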
@@ -105,6 +107,8 @@ def generate_index_dicts(self,tensor_2d):
         return dict_list
 
     def generate_log2phy_map(self, expert_map):
+        num_local_experts = expert_map.max() + 1
+        expert_map = self.global2local(expert_map, num_local_experts)
         ranks_num, global_expert_num = expert_map.shape
         concatenated = torch.flatten(expert_map)
         rank_expert_to_global = self.generate_index_dicts(
@@ -116,7 +120,7 @@ def generate_log2phy_map(self, expert_map):
                 result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((ranks_num, global_expert_num),
+        log2phy_map = torch.full((ranks_num, self.global_expert_num),
                                  -1,
                                  dtype=torch.int32)
         for rank in range(ranks_num):
@@ -130,7 +134,27 @@ def generate_log2phy_map(self, expert_map):
         return log2phy_map
 
     def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
-
+        rank_id = torch.distributed.get_rank()
         if self.log2phy_map_per_layer[layer_id] is not None:
-            rank_id = torch.distributed.get_rank()
-            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[rank_id])
+            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[rank_id])
+
+    def global2local(self,
+                     placement: torch.Tensor,
+                     E_local: int
+                     ) -> torch.Tensor:
+
+        G, _ = placement.shape
+        device = placement.device
+
+        pt_local = torch.full((G, E_local),
+                              fill_value=-1,
+                              dtype=torch.long,
+                              device=device)
+
+        valid = placement >= 0
+        g_idx, k_idx = valid.nonzero(as_tuple=True)
+        slot_idx = placement[g_idx, k_idx]
+
+        pt_local[g_idx, slot_idx] = k_idx
+
+        return pt_local
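Note: the new global2local helper inverts each rank's global-to-local placement: placement[g][e] holds the local slot of logical expert e on rank g, or -1 when that expert is absent, and the scatter writes the expert id back so pt_local[g][slot] names the expert occupying each local slot. A toy run with assumed values (two ranks, four logical experts, two local slots per rank):

import torch

def global2local(placement: torch.Tensor, E_local: int) -> torch.Tensor:
    # placement: [G, E_global]; placement[g, e] = local slot of expert e on rank g, or -1.
    G, _ = placement.shape
    pt_local = torch.full((G, E_local), -1, dtype=torch.long, device=placement.device)
    g_idx, e_idx = (placement >= 0).nonzero(as_tuple=True)
    pt_local[g_idx, placement[g_idx, e_idx]] = e_idx  # invert: local slot -> expert id
    return pt_local

placement = torch.tensor([[0, 1, -1, -1],
                          [-1, -1, 1, 0]])
print(global2local(placement, E_local=2))
# tensor([[0, 1],
#         [3, 2]])

generate_log2phy_map consumes this [ranks, E_local] view, which is why it computes num_local_experts as expert_map.max() + 1 before calling the helper.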
