
Commit 96fe998

Merge pull request #102 from raindaywhu/br_main_into_eplb_wjh
fix bug in moe load & add expert load to json
2 parents: 0897ccc + 4980f2c

File tree: 3 files changed, +41 −49 lines

    vllm_ascend/eplb/adaptor/vllm_adaptor.py
    vllm_ascend/eplb/core/worker/eplb_worker.py
    vllm_ascend/eplb/eplb_updator.py

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 13 additions & 32 deletions
@@ -90,8 +90,16 @@ def collect_topk_ids(self, dummy_run=False):
         self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))

     def get_rank_expert_workload(self) -> torch.Tensor:
-
         device = self.all_topk_ids[0][0].device
+        if not hasattr(self, "moe_load"):
+            self.moe_load = torch.zeros(
+                (self.num_moe_layers), self.global_expert_num,
+                dtype=torch.int64,
+                device=self.all_topk_ids[0][0].device,
+            )
+        else:
+            self.moe_load.zero_()
+        # pass
         flat_list_per_layer = [[] for _ in range(self.num_moe_layers)]

         for period_data in self.all_topk_ids:
@@ -108,12 +116,11 @@ def get_rank_expert_workload(self) -> torch.Tensor:
         index_2d = index_2d.masked_select(mask).reshape(self.num_moe_layers, -1)
         src_2d = torch.ones_like(index_2d, dtype=torch.int64)

-        moe_load = torch.zeros((self.num_moe_layers), self.global_expert_num,
-                               dtype=torch.int64, device=device)
-        moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)
+        self.moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)

-        self.all_topk_ids = []
-        return moe_load
+        if self.all_topk_ids:
+            self.all_topk_ids[:] = self.all_topk_ids[-1:]
+        return self.moe_load

     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
@@ -135,32 +142,6 @@ def get_init_expert_map(self, num_moe_layers):

         return all_expert_maps

-    def local2global(self,
-                     placement_local: torch.Tensor
-                     ) -> torch.Tensor:
-
-        L, G, E_local = placement_local.shape
-        device = placement_local.device
-
-        max_id = torch.max(placement_local)
-        E_global = (max_id + 1).item() if max_id >= 0 else 0
-
-        if E_global == 0:
-            return torch.empty((L, G, 0), dtype=torch.long, device=device)
-
-        placement_global = torch.full((L, G, E_global),
-                                      fill_value=-1,
-                                      dtype=torch.long,
-                                      device=device)
-
-        valid = placement_local >= 0
-        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
-        gid_idx = placement_local[l_idx, g_idx, slot_idx]
-
-        placement_global[l_idx, g_idx, gid_idx] = slot_idx
-
-        return placement_global
-
     def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):

         try:
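Note on the change above: the workload buffer is now allocated once as self.moe_load and zeroed between collection windows instead of being rebuilt on every call, and the top-k history keeps its latest entry rather than being emptied. A minimal standalone sketch of the underlying counting pattern — a reusable [num_layers, num_experts] buffer filled with scatter_add_ over flattened top-k expert ids. Sizes and names are made up for illustration; this mirrors, but is not, the adaptor's code.

import torch

num_moe_layers, global_expert_num = 2, 8

# Allocated once and reused; zeroed at the start of each collection window.
moe_load = torch.zeros((num_moe_layers, global_expert_num), dtype=torch.int64)

def count_window(topk_ids_2d: torch.Tensor) -> torch.Tensor:
    # topk_ids_2d: [num_moe_layers, num_tokens * top_k] selected expert ids.
    moe_load.zero_()
    src = torch.ones_like(topk_ids_2d, dtype=torch.int64)
    # Each selected expert id adds 1 to its per-layer counter.
    moe_load.scatter_add_(dim=1, index=topk_ids_2d, src=src)
    return moe_load

print(count_window(torch.tensor([[0, 1, 1, 3], [2, 2, 5, 7]])))
# layer 0: expert 1 hit twice; layer 1: expert 2 hit twice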

vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def do_update(self):

         # Based on the load info, get the updated expert placement table
         load_info, old_placement = self.global2local(load_info, self.old_expert_maps, self.num_local_experts)
+        self.shared_dict["load_info"] = load_info
         changed, priority, new_placement = self.calculate_rebalance_experts(load_info, old_placement)

         if not torch.is_tensor(new_placement):
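The single added line publishes the per-rank load_info into the shared dict so that get_expert_load in eplb_updator.py (below) can read it. A minimal sketch of that producer/consumer hand-off, under the assumption that the shared dict is a multiprocessing Manager dict as the imports in eplb_updator.py suggest; the worker body and shapes here are illustrative only, not the project's actual wiring.

import torch
from multiprocessing import Manager, Process

def worker(shared_dict):
    # Stand-in for the rebalance worker: compute a load tensor of shape
    # [num_layers, world_size, local_experts] and publish it for the parent.
    load_info = torch.randint(0, 10, (2, 2, 4), dtype=torch.int64)
    shared_dict["load_info"] = load_info

if __name__ == "__main__":
    shared_dict = Manager().dict({"load_info": None})
    p = Process(target=worker, args=(shared_dict,))
    p.start()
    p.join()
    print(shared_dict["load_info"].shape)  # torch.Size([2, 2, 4])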

vllm_ascend/eplb/eplb_updator.py

Lines changed: 27 additions & 17 deletions
@@ -14,14 +14,17 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+
 import torch
+from typing import Dict, List
 import torch.distributed as dist
 import vllm.envs as envs
 from multiprocessing import Queue, Manager

 from vllm.logger import logger
 from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
 from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader
+from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils

 class EplbUpdator:

@@ -32,6 +35,7 @@ def set_adaptor(self, adaptor):
         self.adaptor = adaptor
         self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor)
         self.num_moe_layers = self.adaptor.num_moe_layers
+        self.global_expert_num = self.adaptor.global_expert_num

     def init_eplb(self, expert_map_path):
         self.num_expert_load_gather = 10
@@ -69,7 +73,9 @@ def init_eplb(self, expert_map_path):
             # expert load (heat) info [num_layers, world_size, num_experts]
             "moe_load": None,
             # all expert placement tables [num_layers, world_size, num_experts]
-            "expert_maps": None
+            "expert_maps": None,
+            # expert load (heat) info [num_layers, world_size, local_num_experts]
+            "load_info": None,
         })

         self.eplb = EplbProcess(
@@ -125,11 +131,11 @@ def forward_before(self):
             self.weight_update_counter = 0
             self.update_in_flight = False
             self.update_info_all = []
-
         # set asynchronous stream for d2d expert weight update
         self.reqs = []
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

+
     def forward_end(self,dummy_run=False):
         self.adaptor.collect_topk_ids(dummy_run)
         if not self.update_in_flight:
@@ -149,6 +155,7 @@ def forward_end(self,dummy_run=False):

     def compute_and_set_moe_load(self,dummy_run=False):
         local_load = self.adaptor.get_rank_expert_workload()
+
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()
@@ -229,28 +236,31 @@ def unpack_update_batch(self, packed_update_info):
         return recovered

     def get_expert_load(self) -> str:
-
-        # todo wjh: give the real return value
-        # return self.shared_dict['moe_load']
-        # mock json_str
-        experts_load = ('{\"expert_load\":['
-                        '{\"ip\":\"141.xxx.xxx.181\",'
-                        '\"node_0\":'
-                        '{\"card_0\":'
-                        '[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
-                        'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
-                        '{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
-        return experts_load
+
+        load_info = self.shared_dict["load_info"]  # Tensor [L, W, local_experts_num]
+        L, W, _ = load_info.shape
+
+        expert_load: Dict[str, List[dict]] = {}
+        for c in range(W):
+            layers: List[dict] = []
+            for l in range(L):
+                counts_1d = load_info[l, c]
+
+                layer_val = {
+                    f"expert_{e}": int(v)
+                    for e, v in enumerate(counts_1d.tolist())
+                }
+                layers.append({f"layer_{l}": layer_val})
+            expert_load[f"card_{c}"] = layers
+
+        return {"expert_load": expert_load}

     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
         logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")
         self.num_expert_load_gather = num_expert_load_gather
         self.num_iterations = num_iterations
         logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")

-
     def shutdown(self):
         """
         Clean up the EPLB process.
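To make the "add expert load to json" half of the commit message concrete: get_expert_load now builds a nested card → layer → expert mapping from the shared load_info tensor instead of returning the hard-coded mock string (note the diff leaves the -> str annotation unchanged even though the method now returns a dict). A small self-contained sketch of that conversion plus a JSON dump; the counts are invented, only the structure follows the diff.

import json
import torch

def build_expert_load(load_info: torch.Tensor) -> dict:
    # load_info: [num_layers, num_cards, local_experts] hit counts.
    num_layers, num_cards, _ = load_info.shape
    expert_load = {}
    for c in range(num_cards):
        layers = []
        for l in range(num_layers):
            counts = load_info[l, c].tolist()
            layers.append({f"layer_{l}": {f"expert_{e}": int(v)
                                          for e, v in enumerate(counts)}})
        expert_load[f"card_{c}"] = layers
    return {"expert_load": expert_load}

# 2 layers, 2 cards, 2 local experts per card (made-up numbers).
load_info = torch.tensor([[[3, 1], [0, 2]],
                          [[4, 0], [1, 1]]])
print(json.dumps(build_expert_load(load_info), indent=2))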
