@@ -130,22 +130,22 @@ def forward_before(self):
130
130
self .reqs = []
131
131
self .eplb_loader .asyn_expert_weight_transfer (self .reqs )
132
132
133
def forward_end(self, dummy_run=False):
    """Run the end-of-forward bookkeeping for expert rebalancing.

    Order of operations:

    1. Query the adaptor's per-rank expert workload (the return value is
       not used here).
    2. When no update is in flight, ask ``get_update_iteration()`` whether
       this step should gather load and/or start an update, and act on
       each flag.
    3. While an update is in flight, count this step in
       ``wait_worker_iterations``.
    4. Hand the pending requests to the loader.

    Args:
        dummy_run: Forwarded unchanged to
            ``adaptor.get_rank_expert_workload`` and
            ``compute_and_set_moe_load``.
    """
    # The result is discarded; the call is kept in case it has side effects.
    # NOTE(review): compute_and_set_moe_load appears to query the same
    # workload again -- confirm whether this extra call is still required.
    self.adaptor.get_rank_expert_workload(self.num_moe_layers, dummy_run)

    if not self.update_in_flight:
        load_gather_iteration, update_iteration = self.get_update_iteration()
        # NOTE(review): the two checks below are treated as siblings; the
        # nesting was reconstructed from a listing without indentation.
        if load_gather_iteration:
            # Called for its effect only; the returned load is not needed.
            self.compute_and_set_moe_load(dummy_run)
        if update_iteration:
            self.wakeup_eplb_worker()
            self.update_in_flight = True
            self.wait_worker_iterations = 0
            self.weight_loading = False

    # Also runs on the step that just started an update, so the counter
    # reads 1 at the end of that step.
    if self.update_in_flight:
        self.wait_worker_iterations += 1

    self.eplb_loader.update_expert_map_and_weight(self.reqs,
                                                  self.redundant_enable)
149
149
150
150
def compute_and_set_moe_load (self ,dummy_run = False ):
151
151
local_load = self .adaptor .get_rank_expert_workload (self .num_moe_layers ,dummy_run )
0 commit comments