
Commit e7b7186

Author: lt (committed)
update get_expert_load logic
2 parents 6b36faf + 9d9c93a commit e7b7186

File tree

8 files changed (+123, -114 lines)


vllm_ascend/eplb/adaptor/vllm_adaptor.py (43 additions, 63 deletions)

@@ -62,6 +62,8 @@ def __init__(self, model, **args):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
 
+        self.all_topk_ids = []
+
     def init_buffer_tensor(self, num_buffer_tensor):
         for name in self.expert_weight_names:
             complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
@@ -82,39 +84,43 @@ def init_expert_param_per_layer(self):
                  for name in self.expert_weight_names]
             )
 
-    def get_rank_expert_workload(
-        self,
-        num_moe_layers: int,
-        dummy_run = False
-    ) -> torch.Tensor:
-
-        all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
-        stacked = torch.stack(all_topk_ids, dim=0)
-        L, B, K = stacked.shape
-        N = B * K
-        device = stacked.device
-        G = self.global_expert_num
-
-        if not hasattr(self, "cum_moe_load") or self.cum_moe_load is None:
-            self.cum_moe_load = torch.zeros((L, G),
-                                            dtype=torch.int64,
-                                            device=device)
-
+    def collect_topk_ids(self, dummy_run=False):
         if dummy_run:
-            return self.cum_moe_load
-
-        ids1d = stacked.view(-1).to(torch.int64)
-
-        row_idx = torch.arange(L, device=device).repeat_interleave(N)
-
-        combined = row_idx * G + ids1d
-
-        counts = torch.bincount(combined, minlength=L * G)
-        workload = counts.view(L, G)
-
-        self.cum_moe_load.add_(workload)
-
-        return self.cum_moe_load
+            return
+        self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))
+
+    def get_rank_expert_workload(self) -> torch.Tensor:
+        device = self.all_topk_ids[0][0].device
+        if not hasattr(self, "moe_load"):
+            self.moe_load = torch.zeros(
+                (self.num_moe_layers), self.global_expert_num,
+                dtype=torch.int64,
+                device=self.all_topk_ids[0][0].device,
+            )
+        else:
+            self.moe_load.zero_()
+        # pass
+        flat_list_per_layer = [[] for _ in range(self.num_moe_layers)]
+
+        for period_data in self.all_topk_ids:
+            for l in range(self.num_moe_layers):
+                t = period_data[l]
+                flat_list_per_layer[l].append(t.reshape(-1))
+
+        index_2d = torch.nn.utils.rnn.pad_sequence(
+            [torch.cat(flat_list_per_layer[l]) for l in range(self.num_moe_layers)],
+            batch_first=True, padding_value=-1
+        ).to(device)
+
+        mask = index_2d != -1
+        index_2d = index_2d.masked_select(mask).reshape(self.num_moe_layers, -1)
+        src_2d = torch.ones_like(index_2d, dtype=torch.int64)
+
+        self.moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)
+
+        if self.all_topk_ids:
+            self.all_topk_ids[:] = self.all_topk_ids[-1:]
+        return self.moe_load
 
     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
@@ -136,32 +142,6 @@ def get_init_expert_map(self, num_moe_layers):
 
         return all_expert_maps
 
-    def local2global(self,
-                     placement_local: torch.Tensor
-                     ) -> torch.Tensor:
-
-        L, G, E_local = placement_local.shape
-        device = placement_local.device
-
-        max_id = torch.max(placement_local)
-        E_global = (max_id + 1).item() if max_id >= 0 else 0
-
-        if E_global == 0:
-            return torch.empty((L, G, 0), dtype=torch.long, device=device)
-
-        placement_global = torch.full((L, G, E_global),
-                                      fill_value=-1,
-                                      dtype=torch.long,
-                                      device=device)
-
-        valid = placement_local >= 0
-        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
-        gid_idx = placement_local[l_idx, g_idx, slot_idx]
-
-        placement_global[l_idx, g_idx, gid_idx] = slot_idx
-
-        return placement_global
-
     def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
 
         try:
@@ -244,13 +224,13 @@ def determine_expert_map_all(self):
 
         for r in range(self.world_size):
             if r < self.world_size - 1:
-                start = r * local_num_experts
-                end = (r + 1) * local_num_experts
-                local_count = local_num_experts
+                start = r * local_num_experts
+                end = (r + 1) * local_num_experts
+                local_count = local_num_experts
             else:
-                start = r * local_num_experts
+                start = r * local_num_experts
                 end = self.global_expert_num
-                local_count = self.global_expert_num - r * local_num_experts
+                local_count = self.global_expert_num - r * local_num_experts
 
             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(self.num_moe_layers, -1)
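
Note (not part of the commit): the diff above replaces the per-step bincount accumulation in get_rank_expert_workload with a two-stage path. collect_topk_ids buffers each step's top-k expert ids, and get_rank_expert_workload later turns the buffered ids into a per-layer histogram via pad_sequence, masked_select, and scatter_add_. The standalone Python sketch below reproduces just that counting trick with invented shapes and values (two MoE layers, four global experts, two collection periods); it is an illustration, not code from the repository.

import torch

num_moe_layers, global_expert_num = 2, 4

# Two collection periods; each holds one [num_tokens, top_k] id tensor per layer.
all_topk_ids = [
    [torch.tensor([[0, 1], [2, 1]]), torch.tensor([[3, 3], [0, 2]])],
    [torch.tensor([[1, 1]]),         torch.tensor([[2, 0]])],
]

# Flatten every period's ids into one 1-D list per layer.
flat_per_layer = [
    torch.cat([period[l].reshape(-1) for period in all_topk_ids])
    for l in range(num_moe_layers)
]

# Every layer routes the same tokens, so the flattened lists have equal length
# and the pad / masked_select round-trip yields a rectangular [L, N] index.
index_2d = torch.nn.utils.rnn.pad_sequence(flat_per_layer, batch_first=True,
                                           padding_value=-1)
mask = index_2d != -1
index_2d = index_2d.masked_select(mask).reshape(num_moe_layers, -1)

# Add 1 at each (layer, expert_id) occurrence.
moe_load = torch.zeros(num_moe_layers, global_expert_num, dtype=torch.int64)
moe_load.scatter_add_(dim=1, index=index_2d, src=torch.ones_like(index_2d))
print(moe_load)  # tensor([[1, 3, 1, 0], [2, 0, 2, 2]])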

vllm_ascend/eplb/core/worker/eplb_worker.py (1 addition, 1 deletion)

@@ -68,7 +68,7 @@ def do_update(self):
             new_placement = torch.tensor(new_placement)
             self.check_expert_placement(old_placement, new_placement)
             new_expert_maps = self.local2global(new_placement)
-
+            self.update_expert_map(new_expert_maps)
             logger.debug(f"[EPLB Process new_map differs, performing D2D")
 
             update_info = self.compose_expert_update_info_bipartite(new_expert_maps, self.old_expert_maps)\

vllm_ascend/eplb/eplb_updator.py (39 additions, 44 deletions)

@@ -14,15 +14,17 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+
 import torch
+from typing import Dict, List
 import torch.distributed as dist
 import vllm.envs as envs
 from multiprocessing import Queue, Manager
 
 from vllm.logger import logger
 from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
 from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader
-
+from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils
 
 class EplbUpdator:
 
@@ -33,6 +35,7 @@ def set_adaptor(self, adaptor):
         self.adaptor = adaptor
         self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor)
         self.num_moe_layers = self.adaptor.num_moe_layers
+        self.global_expert_num = self.adaptor.global_expert_num
 
     def init_eplb(self, expert_map_path):
         self.num_expert_load_gather = 10
@@ -44,7 +47,7 @@ def init_eplb(self, expert_map_path):
             if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
                 self.num_expert_load_gather = self.num_iterations
         except Exception as e:
-            self.num_expert_load_gather = self.num_iterations
+            self.num_expert_load_gather = self.num_iterations
 
         self.weight_update_counter = 0
         self.expert_map_initialized = False
@@ -58,7 +61,7 @@ def init_eplb(self, expert_map_path):
         self.cur_iterations: torch.int64 = 0
 
         self.wait_worker_iterations: torch.int64 = 0
-        self.num_wait_worker_iterations: torch.int64 = 10
+        self.num_wait_worker_iterations: torch.int64 = 20
 
         self.planner_block_queue = Queue()
         self.block_update_queue = Queue(maxsize=1)
@@ -70,16 +73,16 @@ def init_eplb(self, expert_map_path):
             # expert heat / load information [num_layers, world_size, num_experts]
             "moe_load": None,
             # all expert maps [num_layers, world_size, num_experts]
-            "expert_maps": None
+            "expert_maps": None,
         })
 
         self.eplb = EplbProcess(
-            shared_dict=self.shared_dict,
-            planner_q=self.planner_block_queue,
-            block_update_q=self.block_update_queue,
-            redundant_enable=self.redundant_enable,
-            policy_type=6,
-            enable_d2d=True
+            shared_dict = self.shared_dict,
+            planner_q = self.planner_block_queue,
+            block_update_q = self.block_update_queue,
+            redundant_enable = self.redundant_enable,
+            policy_type = 6,
+            enable_d2d = True
         )
 
         self.eplb_process = self.eplb._launch_process()
@@ -88,15 +91,14 @@ def init_eplb(self, expert_map_path):
 
     def get_update_iteration(self):
         self.cur_iterations = self.cur_iterations + 1
-        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
         return load_gather_iteration, upate_iteration
 
     def get_init_expert_map(self):
         try:
             if not self.expert_map_initialized:
-                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers,
-                                                                                             self.expert_map_path)
+                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
                 self.expert_map_initialized = True
         except Exception as e:
             logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", exc_info=True)
@@ -114,32 +116,31 @@ def forward_before(self):
         self.weight_loading = True
 
         if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
-            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
-                0)
+            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
             rank_id = torch.distributed.get_rank()
             self.eplb_loader.set_log2phy_map(log2phy_map)
             expert_send_info_this_rank = expert_send_info[rank_id] if rank_id in expert_send_info else []
             expert_recv_info_this_rank = expert_recv_info[rank_id] if rank_id in expert_recv_info else []
-            # logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
+            #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
             self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info_this_rank,
-                expert_recv_info_this_rank, updated_expert_map,
-                layer_id + 3)
+                expert_recv_info_this_rank, updated_expert_map, layer_id + 3)
             self.weight_update_counter += 1
             if self.weight_update_counter == self.num_moe_layers:
                 self.weight_update_counter = 0
                 self.update_in_flight = False
                 self.update_info_all = []
-
         # set asynchronous stream for d2d expert weight update
         self.reqs = []
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
-    def forward_end(self, dummy_run=False):
-        self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
+
+    def forward_end(self,dummy_run=False):
+        self.adaptor.collect_topk_ids(dummy_run)
         if not self.update_in_flight:
             load_gather_iteration, update_iteration = self.get_update_iteration()
             if load_gather_iteration:
-                moe_load = self.compute_and_set_moe_load(dummy_run)
+                moe_load = self.compute_and_set_moe_load()
+                self.get_expert_load()
             if update_iteration:
                 self.wakeup_eplb_worker()
                 self.update_in_flight = True
@@ -151,8 +152,9 @@ def forward_end(self, dummy_run=False):
 
         self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
 
-    def compute_and_set_moe_load(self, dummy_run=False):
-        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
+    def compute_and_set_moe_load(self,dummy_run=False):
+        local_load = self.adaptor.get_rank_expert_workload()
+
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()
@@ -177,7 +179,7 @@ def compute_and_set_moe_load(self, dummy_run=False):
     def warm_up_eplb(self):
 
         self.get_init_expert_map()
-
+        self.adaptor.collect_topk_ids(dummy_run=False)
        self.compute_and_set_moe_load()
 
         src_tensor = torch.empty((1,), device=self.device)
@@ -197,7 +199,7 @@ def warm_up_eplb(self):
                 continue
             comm_op_list.append(
                 dist.P2POp(dist.irecv, src_tensor, src_rank)
-            )
+                )
         if comm_op_list:
             reqs = dist.batch_isend_irecv(comm_op_list)
 
@@ -210,7 +212,7 @@ def unpack_update_batch(self, packed_update_info):
         """
         send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info
 
-        maps = stacked_maps.unbind(0)
+        maps = stacked_maps.unbind(0)
         layer_ids = layer_id_tensor.tolist()
 
         if self.redundant_enable:
@@ -222,7 +224,7 @@ def unpack_update_batch(self, packed_update_info):
             _send = send_all
             _recv = recv_all
             _maps = maps
-            _l2p = log2phy_list
+            _l2p = log2phy_list
             _lids = layer_ids
 
         recovered = [
@@ -232,21 +234,14 @@ def unpack_update_batch(self, packed_update_info):
         ]
         return recovered
 
-    def get_expert_load(self) -> str:
-
-        # todo wjh: provide the real return value
-        # return self.shared_dict['moe_load']
-        # mock json_str
-        experts_load = ('{\"expert_load\":['
-                        '{\"ip\":\"141.xxx.xxx.181\",'
-                        '\"node_0\":'
-                        '{\"card_0\":'
-                        '[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
-                        'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
-                        '{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
-        return experts_load
+    def get_expert_load(self) -> torch.Tensor:
+        expert_maps = self.shared_dict["expert_maps"]
+        moe_load = self.shared_dict["moe_load"]  # Tensor [L, W, global_experts_num]
+        if not moe_load:
+            return None
+        num_local_experts = expert_maps.max() + 1
+        load_info, _ = ExpertMapUtils.global2local_load(moe_load, expert_maps, num_local_experts)
+        return load_info
 
     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
         logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")

vllm_ascend/eplb/tool/eplb_utils.py (30 additions, 0 deletions)

@@ -83,3 +83,33 @@ def global2local(cls,
         pt_local[g_idx, slot_idx] = k_idx
 
         return pt_local
+
+    @classmethod
+    def global2local_load(self,
+                          workload: torch.Tensor,
+                          placement: torch.Tensor,
+                          E_local: int
+                          ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        L, G, _ = placement.shape
+        device = placement.device
+
+        wt_local = torch.full((L, G, E_local),
+                              fill_value=-1,
+                              dtype=workload.dtype,
+                              device=device)
+        pt_local = torch.full((L, G, E_local),
+                              fill_value=-1,
+                              dtype=torch.long,
+                              device=device)
+
+        valid = placement >= 0
+        l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
+
+        slot_idx = placement[l_idx, g_idx, k_idx]
+        values = workload[l_idx, g_idx, k_idx]
+
+        wt_local[l_idx, g_idx, slot_idx] = values
+        pt_local[l_idx, g_idx, slot_idx] = k_idx
+
+        return wt_local, pt_local
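
To make the shape contract of the new helper concrete, here is a small worked example (independent of the repository): placement[l, g, e] holds the local slot that global expert e occupies on rank g, or -1 if that rank does not host e, and global2local_load rearranges a per-expert workload tensor into local-slot order while recording which global expert sits in each slot. All numbers below are invented.

import torch

L, G, E_global, E_local = 1, 2, 4, 2

# Rank 0 hosts global experts 0 and 1 in slots 0 and 1; rank 1 hosts experts 2 and 3.
placement = torch.tensor([[[0, 1, -1, -1],
                           [-1, -1, 0, 1]]])      # [L, G, E_global]
workload = torch.tensor([[[5, 2, 7, 1],
                          [5, 2, 7, 1]]])         # [L, G, E_global] hit counts

wt_local = torch.full((L, G, E_local), -1, dtype=workload.dtype)
pt_local = torch.full((L, G, E_local), -1, dtype=torch.long)

valid = placement >= 0
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
slot_idx = placement[l_idx, g_idx, k_idx]

wt_local[l_idx, g_idx, slot_idx] = workload[l_idx, g_idx, k_idx]
pt_local[l_idx, g_idx, slot_idx] = k_idx

print(wt_local)  # tensor([[[5, 2], [7, 1]]]) -- load per local expert slot
print(pt_local)  # tensor([[[0, 1], [2, 3]]]) -- global expert id in each slot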
