Commit 6b36faf

Author: lt
Commit message: update format
1 parent 53e8949 commit 6b36faf

File tree

1 file changed: +24 -20 lines changed


vllm_ascend/eplb/eplb_updator.py

Lines changed: 24 additions & 20 deletions
@@ -1,3 +1,4 @@
+#
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,6 +23,7 @@
 from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
 from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader
 
+
 class EplbUpdator:
 
     def __init__(self, expert_map_path):
@@ -42,7 +44,7 @@ def init_eplb(self, expert_map_path):
             if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
                 self.num_expert_load_gather = self.num_iterations
         except Exception as e:
-            self.num_expert_load_gather = self.num_iterations
+            self.num_expert_load_gather = self.num_iterations
 
         self.weight_update_counter = 0
         self.expert_map_initialized = False
@@ -72,19 +74,18 @@ def init_eplb(self, expert_map_path):
         })
 
         self.eplb = EplbProcess(
-            shared_dict = self.shared_dict,
-            planner_q = self.planner_block_queue,
-            block_update_q = self.block_update_queue,
-            redundant_enable = self.redundant_enable,
-            policy_type = 6,
-            enable_d2d = True
+            shared_dict=self.shared_dict,
+            planner_q=self.planner_block_queue,
+            block_update_q=self.block_update_queue,
+            redundant_enable=self.redundant_enable,
+            policy_type=6,
+            enable_d2d=True
         )
 
         self.eplb_process = self.eplb._launch_process()
 
         logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})")
 
-
     def get_update_iteration(self):
         self.cur_iterations = self.cur_iterations + 1
         load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
@@ -94,7 +95,8 @@ def get_update_iteration(self):
     def get_init_expert_map(self):
         try:
             if not self.expert_map_initialized:
-                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
+                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers,
+                                                                                             self.expert_map_path)
                 self.expert_map_initialized = True
         except Exception as e:
             logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", exc_info=True)
@@ -103,6 +105,7 @@ def wakeup_eplb_worker(self):
         self.planner_block_queue.put(1)
 
     def forward_before(self):
+
         # Batch after eplb process being triggered, get update info provided by eplb process
         if self.update_in_flight and self.weight_update_counter == 0 and self.wait_worker_iterations == self.num_wait_worker_iterations:
             self.wait_worker_iterations = 0
@@ -111,14 +114,16 @@ def forward_before(self):
             self.weight_loading = True
 
         if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
-            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
+            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
+                0)
             rank_id = torch.distributed.get_rank()
             self.eplb_loader.set_log2phy_map(log2phy_map)
             expert_send_info_this_rank = expert_send_info[rank_id] if rank_id in expert_send_info else []
             expert_recv_info_this_rank = expert_recv_info[rank_id] if rank_id in expert_recv_info else []
-            #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
+            # logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
             self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info_this_rank,
-                expert_recv_info_this_rank, updated_expert_map, layer_id + 3)
+                                                               expert_recv_info_this_rank, updated_expert_map,
+                                                               layer_id + 3)
             self.weight_update_counter += 1
             if self.weight_update_counter == self.num_moe_layers:
                 self.weight_update_counter = 0
@@ -129,8 +134,8 @@ def forward_before(self):
         self.reqs = []
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
-    def forward_end(self,dummy_run=False):
-        self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+    def forward_end(self, dummy_run=False):
+        self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
         if not self.update_in_flight:
             load_gather_iteration, update_iteration = self.get_update_iteration()
             if load_gather_iteration:
@@ -146,8 +151,8 @@ def forward_end(self,dummy_run=False):
 
         self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
 
-    def compute_and_set_moe_load(self,dummy_run=False):
-        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+    def compute_and_set_moe_load(self, dummy_run=False):
+        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()
@@ -192,7 +197,7 @@ def warm_up_eplb(self):
                 continue
             comm_op_list.append(
                 dist.P2POp(dist.irecv, src_tensor, src_rank)
-            )
+            )
         if comm_op_list:
             reqs = dist.batch_isend_irecv(comm_op_list)
 
@@ -205,7 +210,7 @@ def unpack_update_batch(self, packed_update_info):
         """
         send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info
 
-        maps = stacked_maps.unbind(0)
+        maps = stacked_maps.unbind(0)
         layer_ids = layer_id_tensor.tolist()
 
         if self.redundant_enable:
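
Note: the hunk above only re-indents the maps = stacked_maps.unbind(0) line. For context, a small self-contained sketch of the unbind/tolist unpacking pattern that unpack_update_batch relies on; the tensor shapes and layer ids below are invented for the demo, not taken from the real expert maps:

    import torch

    # Toy stand-in for the packed update info: 3 per-layer expert maps stacked on dim 0,
    # plus a tensor carrying the MoE layer id of each row (values invented for the demo).
    stacked_maps = torch.arange(12).reshape(3, 4)
    layer_id_tensor = torch.tensor([3, 4, 5])

    maps = stacked_maps.unbind(0)          # tuple of 3 one-dimensional tensors, one per layer
    layer_ids = layer_id_tensor.tolist()   # plain Python ints, convenient as keys

    per_layer = dict(zip(layer_ids, maps))
    print(per_layer[4])                    # tensor([4, 5, 6, 7])
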
@@ -217,7 +222,7 @@ def unpack_update_batch(self, packed_update_info):
             _send = send_all
             _recv = recv_all
             _maps = maps
-            _l2p = log2phy_list
+            _l2p = log2phy_list
             _lids = layer_ids
 
         recovered = [
@@ -249,7 +254,6 @@ def update_expert_load_statistical_period(self, num_expert_load_gather: int, num
         self.num_iterations = num_iterations
         logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")
 
-
     def shutdown(self):
         """
         Clean up the EPLB process.
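
As a closing note on the warm_up_eplb hunk above: a minimal standalone sketch of the batched point-to-point pattern it uses (dist.P2POp objects collected into a list, then launched with dist.batch_isend_irecv). The two-rank setup, gloo backend, file name, and tensor shape are illustrative assumptions, not taken from this repository:

    # p2p_demo.py (hypothetical file name) -- run with: torchrun --nproc_per_node=2 p2p_demo.py
    import torch
    import torch.distributed as dist

    def main():
        dist.init_process_group(backend="gloo")  # backend chosen only so the demo runs on CPU
        rank = dist.get_rank()

        ops = []
        if rank == 0:
            payload = torch.full((4,), 42.0)
            ops.append(dist.P2POp(dist.isend, payload, 1))   # queue a send to rank 1
        else:
            buf = torch.empty(4)
            ops.append(dist.P2POp(dist.irecv, buf, 0))       # queue a receive from rank 0

        # batch_isend_irecv launches every queued op and returns async work handles
        for req in dist.batch_isend_irecv(ops):
            req.wait()                                       # block until the transfer completes

        if rank == 1:
            print(buf)                                       # tensor([42., 42., 42., 42.])

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()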
