
Commit 1a8d238

Merge pull request #85 from raindaywhu/lt_dev
add mock experts_load data
2 parents 976eb9f + d537fb2 commit 1a8d238

3 files changed, +48 -15 lines changed

vllm_ascend/eplb/eplb_updator.py

Lines changed: 42 additions & 15 deletions
@@ -16,6 +16,7 @@
 #
 import torch
 import torch.distributed as dist
+import vllm.envs as envs
 from multiprocessing import Queue, Manager

 from vllm.logger import logger
@@ -33,11 +34,17 @@ def set_adaptor(self, adaptor):
         self.num_moe_layers = self.adaptor.num_moe_layers

     def init_eplb(self, expert_map_path):
-
+        self.num_expert_load_gather = 10
         self.redundant_enable = (expert_map_path != None)
         self.num_iterations: torch.int64 = 130
         self.expert_map_path = expert_map_path

+        try:
+            if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
+                self.num_expert_load_gather = self.num_iterations
+        except Exception as e:
+            self.num_expert_load_gather = self.num_iterations
+
         self.weight_update_counter = 0
         self.expert_map_initialized = False
         self.update_in_flight = False
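The try/except above guards against VLLM_ALLOW_EXPERT_LOAD_COLLECTING not being defined in vllm.envs: whenever the flag is missing or falsy, the load-gather period falls back to the full num_iterations, so expert load is only collected when an EPLB update fires. A minimal standalone sketch of that fallback (illustrative only; resolve_gather_period is not part of this commit):

# Sketch of the fallback in init_eplb above; VLLM_ALLOW_EXPERT_LOAD_COLLECTING
# is treated as an optional attribute of vllm.envs that may not exist.
import vllm.envs as envs

def resolve_gather_period(default_gather: int = 10, num_iterations: int = 130) -> int:
    # getattr with a default mirrors the try/except: a missing or falsy flag
    # collapses the gather period to the EPLB update period.
    if getattr(envs, "VLLM_ALLOW_EXPERT_LOAD_COLLECTING", False):
        return default_gather
    return num_iterations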
@@ -70,7 +77,7 @@ def init_eplb(self, expert_map_path):
             planner_q = self.planner_block_queue,
             block_update_q = self.block_update_queue,
             redundant_enable = self.redundant_enable,
-            policy_type = 2,
+            policy_type = 6,
             enable_d2d = True
         )

@@ -80,10 +87,9 @@ def init_eplb(self, expert_map_path):

     def get_update_iteration(self):
         self.cur_iterations = self.cur_iterations + 1
-        if not self.gate_eplb:
-            return self.cur_iterations % self.num_iterations == 0
-        else:
-            return self.cur_iterations == self.num_iterations
+        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        update_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
+        return load_gather_iteration, update_iteration

     def get_init_expert_map(self):
         try:
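get_update_iteration now returns two flags instead of one: a load-gather flag that fires every num_expert_load_gather iterations and an update flag that keeps the original num_iterations cadence (in gated mode both fire once, at exactly num_iterations). forward_end below consumes the pair, so load collection is decoupled from weight rebalancing. A small sketch of the arithmetic with the default periods from init_eplb (iteration_flags is an illustrative stand-in, not the class method):

def iteration_flags(cur: int, gather_every: int = 10, update_every: int = 130,
                    gate_eplb: bool = False) -> tuple[bool, bool]:
    # Returns (load_gather_iteration, update_iteration), mirroring get_update_iteration.
    if gate_eplb:
        hit = cur == update_every          # gated mode: one-shot at update_every
        return hit, hit
    return cur % gather_every == 0, cur % update_every == 0

assert iteration_flags(10) == (True, False)    # gather load only
assert iteration_flags(130) == (True, True)    # gather load and trigger an EPLB update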
@@ -125,12 +131,15 @@ def forward_before(self):
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

     def forward_end(self):
-        if not self.update_in_flight and self.get_update_iteration():
-            moe_load = self.compute_and_set_moe_load()
-            self.wakeup_eplb_worker()
-            self.update_in_flight = True
-            self.wait_worker_iterations = 0
-            self.weight_loading = False
+        if not self.update_in_flight:
+            load_gather_iteration, update_iteration = self.get_update_iteration()
+            if load_gather_iteration:
+                self.moe_load = self.compute_and_set_moe_load()
+            if update_iteration:
+                self.wakeup_eplb_worker()
+                self.update_in_flight = True
+                self.wait_worker_iterations = 0
+                self.weight_loading = False

         if self.update_in_flight:
             self.wait_worker_iterations = self.wait_worker_iterations + 1
@@ -220,9 +229,27 @@ def unpack_update_batch(self, packed_update_info):
         return recovered

     def get_expert_load(self) -> str:
-        """todo: confirm what type the value of moe_load is"""
-        # return '{"a":"b"}' # mock
-        return self.shared_dict['moe_load']
+
+        # todo (wjh): return the real value here
+        # return self.shared_dict['moe_load']
+        # mock json_str
+        experts_load = ('{\"expert_load\":['
+                        '{\"ip\":\"141.xxx.xxx.181\",'
+                        '\"node_0\":'
+                        '{\"card_0\":'
+                        '[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
+                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
+                        'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
+                        '{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
+                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
+        return experts_load
+
+    def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
+        logger.info(f"start update {self.num_expert_load_gather=}, {self.num_iterations=}...")
+        self.num_expert_load_gather = num_expert_load_gather
+        self.num_iterations = num_iterations
+        logger.info(f"update {self.num_expert_load_gather=}, {self.num_iterations=} success...")
+

     def shutdown(self):
         """

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 0 deletions
@@ -1589,6 +1589,9 @@ def profile_run(self) -> None:
     def do_get_expert_load(self) -> str:
         return self.eplb_updator.get_expert_load()

+    def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
+        return self.eplb_updator.update_expert_load_statistical_period(num_expert_load_gather, num_iterations)
+
     def eplb_warmup(self):
         #EPLB
         if self.dynamic_eplb and not self.is_eplb_warmuped:

vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 0 deletions
@@ -214,6 +214,9 @@ def get_expert_load(self) -> str:
         moe_load = self.model_runner.do_get_expert_load()
         return moe_load

+    def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
+        self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations)
+
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()
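The last two files are plumbing: the worker delegates the new setter to the model runner, which delegates to the EPLB updator where num_expert_load_gather and num_iterations actually live. A self-contained sketch of that delegation chain with stand-in classes (the real classes sit in vllm_ascend/worker/worker_v1.py, vllm_ascend/worker/model_runner_v1.py and vllm_ascend/eplb/eplb_updator.py; the Fake* names below are illustrative only):

class FakeUpdator:
    """Stand-in for the EPLB updator that holds the statistics periods."""
    def __init__(self):
        self.num_expert_load_gather = 10
        self.num_iterations = 130

    def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
        self.num_expert_load_gather = num_expert_load_gather
        self.num_iterations = num_iterations

class FakeModelRunner:
    """Stand-in for the model runner; forwards to the updator."""
    def __init__(self):
        self.eplb_updator = FakeUpdator()

    def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
        return self.eplb_updator.update_expert_load_statistical_period(num_expert_load_gather, num_iterations)

class FakeWorker:
    """Stand-in for the worker; forwards to the model runner."""
    def __init__(self):
        self.model_runner = FakeModelRunner()

    def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
        self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations)

# e.g. gather expert load every 5 steps while keeping the 130-step update period:
worker = FakeWorker()
worker.update_expert_load_statistical_period(num_expert_load_gather=5, num_iterations=130)
assert worker.model_runner.eplb_updator.num_expert_load_gather == 5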
