Commit 2b62a47

improve d2d expert weight update impl in eplb_updator.py
1 parent cfbe8b1 commit 2b62a47

File tree: 2 files changed, +40 -46 lines

vllm_ascend/eplb/eplb_updator.py

Lines changed: 39 additions & 44 deletions
@@ -40,28 +40,27 @@ def set_adaptor(self, adaptor):
 
     def init_eplb(self, expert_map_path):
         self.num_expert_load_gather = 10
-        self.redundant_enable = (expert_map_path != None)
-        self.num_iterations: torch.int64 = 130
+        self.periodic_load_gather = True
+        self.redundant_enable = (expert_map_path is not None)
+        self.num_iterations_eplb_update: torch.int64 = 130
         self.expert_map_path = expert_map_path
 
         try:
             if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
-                self.num_expert_load_gather = self.num_iterations
+                self.num_expert_load_gather = self.num_iterations_eplb_update
+                self.periodic_load_gather = False
         except Exception as e:
-            self.num_expert_load_gather = self.num_iterations
+            self.num_expert_load_gather = self.num_iterations_eplb_update
+            self.periodic_load_gather = False
 
-        self.weight_update_counter = 0
         self.expert_map_initialized = False
-        self.update_in_flight = False
-
         self.gate_eplb = True
 
         self.reqs = []
         self.update_info_all = []
 
         self.cur_iterations: torch.int64 = 0
 
-        self.wait_worker_iterations: torch.int64 = 0
         self.num_wait_worker_iterations: torch.int64 = 20
 
         self.planner_block_queue = Queue()
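
For readers skimming the diff: the renamed counter and the new periodic_load_gather switch make the env-var gate easier to follow. A minimal standalone sketch of that gating, assuming an illustrative default gather period; the helper name and values below are for illustration only and are not part of the repo:

# Illustrative sketch, not repo code: how the gather period collapses when
# expert-load collecting is disabled via VLLM_ALLOW_EXPERT_LOAD_COLLECTING.
def resolve_gather_period(allow_collecting: bool,
                          num_iterations_eplb_update: int = 130,
                          default_gather_period: int = 10):
    if allow_collecting:
        # collect expert load every `default_gather_period` iterations
        return default_gather_period, True   # (num_expert_load_gather, periodic_load_gather)
    # otherwise gather only once per EPLB update cycle
    return num_iterations_eplb_update, False

assert resolve_gather_period(True) == (10, True)
assert resolve_gather_period(False) == (130, False)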
@@ -90,11 +89,22 @@ def init_eplb(self, expert_map_path):
 
         logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})")
 
-    def get_update_iteration(self):
-        self.cur_iterations = self.cur_iterations + 1
-        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        return load_gather_iteration, upate_iteration
+    def update_iteration(self):
+        self.cur_iterations += 1
+        if self.cur_iterations == (self.num_iterations_eplb_update +\
+            self.num_wait_worker_iterations + self.num_moe_layers):
+            if not self.gate_eplb:
+                self.cur_iterations = 0
+
+    def get_update_info_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+
+    def wakeup_eplb_worker_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update - 1)
+
+    def update_expert_weight_flag(self):
+        weight_update_counter = self.cur_iterations - (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+        return (weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers)
 
     def get_init_expert_map(self):
         try:
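
The four helpers above replace the old update_in_flight / wait_worker_iterations / weight_update_counter bookkeeping with predicates over a single counter, and update_iteration only wraps that counter back to zero when gate_eplb is False, i.e. when rebalancing is meant to repeat. The schedule they encode can be checked in isolation; the sketch below re-implements the predicates against a plain integer and prints the phase boundaries. The first two constants are the defaults from this diff; NUM_MOE_LAYERS is assumed here purely for illustration:

# Standalone sketch of the iteration schedule encoded by the new flag helpers.
NUM_ITERATIONS_EPLB_UPDATE = 130   # default from init_eplb()
NUM_WAIT_WORKER_ITERATIONS = 20    # default from init_eplb()
NUM_MOE_LAYERS = 3                 # assumed value; the real one comes from the model

def wakeup_eplb_worker_flag(cur: int) -> bool:
    # one iteration before the load-gather boundary: wake the EPLB planner
    return cur == NUM_ITERATIONS_EPLB_UPDATE - 1

def get_update_info_flag(cur: int) -> bool:
    # the planner has had NUM_WAIT_WORKER_ITERATIONS iterations to produce a plan
    return cur == NUM_ITERATIONS_EPLB_UPDATE + NUM_WAIT_WORKER_ITERATIONS

def update_expert_weight_flag(cur: int) -> bool:
    # exactly one MoE layer is swapped per iteration, for NUM_MOE_LAYERS iterations
    offset = cur - (NUM_ITERATIONS_EPLB_UPDATE + NUM_WAIT_WORKER_ITERATIONS)
    return 0 <= offset < NUM_MOE_LAYERS

for cur in range(NUM_ITERATIONS_EPLB_UPDATE + NUM_WAIT_WORKER_ITERATIONS + NUM_MOE_LAYERS):
    events = [name for name, fires in (
        ("wakeup", wakeup_eplb_worker_flag(cur)),
        ("fetch-plan", get_update_info_flag(cur)),
        ("swap-layer", update_expert_weight_flag(cur)),
    ) if fires]
    if events:
        print(cur, events)
# Prints: 129 ['wakeup'], 150 ['fetch-plan', 'swap-layer'], 151 ['swap-layer'], 152 ['swap-layer']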
@@ -108,14 +118,11 @@ def wakeup_eplb_worker(self):
         self.planner_block_queue.put(1)
 
     def forward_before(self):
-
         # Batch after eplb process being triggered, get update info provided by eplb process
-        if self.update_in_flight and self.weight_update_counter == 0 and self.wait_worker_iterations == self.num_wait_worker_iterations:
-            self.wait_worker_iterations = 0
+        if self.get_update_info_flag():
             self.update_info_all = self.block_update_queue.get()
-            self.weight_loading = True
 
-        if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
+        if self.update_expert_weight_flag():
             (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
             rank_id = torch.distributed.get_rank()
             if self.redundant_enable:
@@ -125,34 +132,22 @@ def forward_before(self):
             #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
             self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info, expert_recv_info,
                 updated_expert_map_this_rank, layer_id + self.adaptor.num_dense_layers)
-            self.weight_update_counter += 1
-            if self.weight_update_counter == self.num_moe_layers:
-                self.weight_update_counter = 0
-                self.update_in_flight = False
-                self.update_info_all = []
-        # set asynchronous stream for d2d expert weight update
-        self.reqs = []
-        self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
+            # set asynchronous stream for d2d expert weight update
+            self.reqs = []
+            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
-    def forward_end(self,dummy_run=False):
-        if not self.update_in_flight:
-            load_gather_iteration, update_iteration = self.get_update_iteration()
-            if load_gather_iteration:
-                moe_load = self.compute_and_set_moe_load()
-                self.get_expert_load()
-            if update_iteration:
-                self.wakeup_eplb_worker()
-                self.update_in_flight = True
-                self.wait_worker_iterations = 0
-                self.weight_loading = False
+    def forward_end(self):
+        if self.wakeup_eplb_worker_flag():
+            moe_load = self.compute_and_set_moe_load(is_clear=True)
+            self.wakeup_eplb_worker()
 
-        if self.update_in_flight:
-            self.wait_worker_iterations = self.wait_worker_iterations + 1
+        if self.update_expert_weight_flag():
+            self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
 
-            self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
+        self.update_iteration()
 
-    def compute_and_set_moe_load(self,dummy_run=False):
+    def compute_and_set_moe_load(self, is_clear=False):
         local_load = self.adaptor.get_rank_expert_workload()
 
         self._gather_buffer = None
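
With the flags factored out, forward_before and forward_end now read purely from the shared counter: forward_before fetches the planner's output once per cycle and stages one layer's d2d transfer per iteration, while forward_end wakes the planner, applies the staged update, and advances the counter. A hypothetical driver sketch showing where each call sits relative to the model forward pass; everything here except forward_before/forward_end is a placeholder name, not repo code:

# Hypothetical driver illustrating the per-iteration call order.
# `updator` is assumed to be an EplbUpdator wired up as in init_eplb().
def run_one_iteration(updator, model, batch):
    # before the forward pass: fetch the planner's update info (once per cycle)
    # and enqueue the async d2d transfer task for exactly one MoE layer
    updator.forward_before()

    hidden_states = model(batch)

    # after the forward pass: wake the planner one step before the boundary,
    # apply the staged expert-map/weight update, then advance the counter
    updator.forward_end()
    return hidden_states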
@@ -241,10 +236,10 @@ def get_expert_load(self) -> tuple:
         return moe_load, expert_maps, num_local_experts
 
     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
-        logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")
+        logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update}...")
         self.num_expert_load_gather = num_expert_load_gather
-        self.num_iterations = num_iterations
-        logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")
+        self.num_iterations_eplb_update = num_iterations
+        logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update} success...")
 
     def shutdown(self):
         """

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 2 deletions
@@ -1547,8 +1547,7 @@ def _dummy_run(
         if is_profile_run and self.dynamic_eplb:
             self.model.clear_all_moe_loads()
         if not is_compile and not is_profile_run and self.dynamic_eplb:
-            dummy_run = True
-            self.eplb_updator.forward_end(dummy_run)
+            self.eplb_updator.forward_end()
         return hidden_states
 
     def profile_run(self) -> None:
