@@ -40,28 +40,27 @@ def set_adaptor(self, adaptor):
 
     def init_eplb(self, expert_map_path):
         self.num_expert_load_gather = 10
-        self.redundant_enable = (expert_map_path != None)
-        self.num_iterations: torch.int64 = 130
+        self.periodic_load_gather = True
+        self.redundant_enable = (expert_map_path is not None)
+        self.num_iterations_eplb_update: torch.int64 = 130
         self.expert_map_path = expert_map_path
 
         try:
             if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
-                self.num_expert_load_gather = self.num_iterations
+                self.num_expert_load_gather = self.num_iterations_eplb_update
+                self.periodic_load_gather = False
         except Exception as e:
-            self.num_expert_load_gather = self.num_iterations
+            self.num_expert_load_gather = self.num_iterations_eplb_update
+            self.periodic_load_gather = False
 
-        self.weight_update_counter = 0
         self.expert_map_initialized = False
-        self.update_in_flight = False
-
         self.gate_eplb = True
 
         self.reqs = []
         self.update_info_all = []
 
         self.cur_iterations: torch.int64 = 0
 
-        self.wait_worker_iterations: torch.int64 = 0
         self.num_wait_worker_iterations: torch.int64 = 20
 
         self.planner_block_queue = Queue()
@@ -90,11 +89,22 @@ def init_eplb(self, expert_map_path):
 
         logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})")
 
-    def get_update_iteration(self):
-        self.cur_iterations = self.cur_iterations + 1
-        load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        upate_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
-        return load_gather_iteration, upate_iteration
+    def update_iteration(self):
+        self.cur_iterations += 1
+        if self.cur_iterations == (self.num_iterations_eplb_update + \
+            self.num_wait_worker_iterations + self.num_moe_layers):
+            if not self.gate_eplb:
+                self.cur_iterations = 0
+
+    def get_update_info_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+
+    def wakeup_eplb_worker_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update - 1)
+
+    def update_expert_weight_flag(self):
+        weight_update_counter = self.cur_iterations - (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+        return (weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers)
 
     def get_init_expert_map(self):
         try:
@@ -108,14 +118,11 @@ def wakeup_eplb_worker(self):
         self.planner_block_queue.put(1)
 
     def forward_before(self):
-
         # Batch after eplb process being triggered, get update info provided by eplb process
-        if self.update_in_flight and self.weight_update_counter == 0 and self.wait_worker_iterations == self.num_wait_worker_iterations:
-            self.wait_worker_iterations = 0
+        if self.get_update_info_flag():
             self.update_info_all = self.block_update_queue.get()
-            self.weight_loading = True
 
-        if self.update_in_flight and self.weight_loading and self.weight_update_counter < self.num_moe_layers:
+        if self.update_expert_weight_flag():
             (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
             rank_id = torch.distributed.get_rank()
             if self.redundant_enable:
@@ -125,34 +132,22 @@ def forward_before(self):
             #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
             self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info, expert_recv_info,
                 updated_expert_map_this_rank, layer_id + self.adaptor.num_dense_layers)
-            self.weight_update_counter += 1
-            if self.weight_update_counter == self.num_moe_layers:
-                self.weight_update_counter = 0
-                self.update_in_flight = False
-                self.update_info_all = []
-        # set asynchronous stream for d2d expert weight update
-        self.reqs = []
-        self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
+            # set asynchronous stream for d2d expert weight update
+            self.reqs = []
+            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
-    def forward_end(self, dummy_run=False):
-        if not self.update_in_flight:
-            load_gather_iteration, update_iteration = self.get_update_iteration()
-            if load_gather_iteration:
-                moe_load = self.compute_and_set_moe_load()
-                self.get_expert_load()
-            if update_iteration:
-                self.wakeup_eplb_worker()
-                self.update_in_flight = True
-                self.wait_worker_iterations = 0
-                self.weight_loading = False
+    def forward_end(self):
+        if self.wakeup_eplb_worker_flag():
+            moe_load = self.compute_and_set_moe_load(is_clear=True)
+            self.wakeup_eplb_worker()
 
-        if self.update_in_flight:
-            self.wait_worker_iterations = self.wait_worker_iterations + 1
+        if self.update_expert_weight_flag():
+            self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
 
-        self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
+        self.update_iteration()
 
-    def compute_and_set_moe_load(self, dummy_run=False):
+    def compute_and_set_moe_load(self, is_clear=False):
         local_load = self.adaptor.get_rank_expert_workload()
 
         self._gather_buffer = None
@@ -241,10 +236,10 @@ def get_expert_load(self) -> tuple:
         return moe_load, expert_maps, num_local_experts
 
     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
-        logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")
+        logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update}...")
         self.num_expert_load_gather = num_expert_load_gather
-        self.num_iterations = num_iterations
-        logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")
+        self.num_iterations_eplb_update = num_iterations
+        logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update} success...")
 
     def shutdown(self):
         """