@@ -123,6 +123,10 @@ def __init__(
            self.decay_factor = self.lr_max_scaled * np.sqrt(n_steps_warmup)
            self.scheduler_decay = None

+        elif policy_decay == "constant":
+            self.decay_factor = 0.0
+            self.scheduler_decay = None
+
         else:
             assert False, "Unsupported decay policy for learning rate scheduler"

@@ -173,18 +177,24 @@ def step(self):
         if self.i_step >= (self.n_steps_warmup + self.n_steps_decay + self.n_steps_cooldown):
             return self.lr

-        if (
-            self.policy_decay == "sqrt"
-            and self.i_step > self.n_steps_warmup
-            and self.i_step < self.n_steps_warmup + self.n_steps_decay
-        ):
+        end_decay = self.n_steps_warmup + self.n_steps_decay
+        phase_decay = (self.i_step > self.n_steps_warmup) and (self.i_step <= end_decay)
+
+        if self.policy_decay == "sqrt" and phase_decay:
             self.lr = (
                 (self.decay_factor / np.sqrt(self.i_step))
                 if self.i_step > 0
                 else self.lr_max_scaled
             )
             for g in self.optimizer.param_groups:
                 g["lr"] = self.lr
+        elif self.policy_decay == "constant" and phase_decay:
+            cur_lr = self.lr
+            self.lr = self.lr_max_scaled
+            # make sure lr_max_scaled is applied if the warm-up did not end at lr_max_scaled
+            if cur_lr < self.lr:
+                for g in self.optimizer.param_groups:
+                    g["lr"] = self.lr
         else:
             self.cur_scheduler.step()
             self.lr = self.cur_scheduler.get_last_lr()[0]
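
For reference, a minimal sketch of the phase logic this commit introduces: during the decay window between warm-up and cool-down, the "sqrt" policy applies inverse-square-root decay, while the new "constant" policy simply holds the peak rate. The names lr_max_scaled, n_steps_warmup, n_steps_decay, and i_step are taken from the diff; the free-standing helper function and the example values are assumptions for illustration, not the scheduler's actual interface.

import numpy as np

def decay_phase_lr(i_step, policy_decay, lr_max_scaled,
                   n_steps_warmup, n_steps_decay):
    # hypothetical helper mirroring the diff's branch logic
    end_decay = n_steps_warmup + n_steps_decay
    phase_decay = n_steps_warmup < i_step <= end_decay
    if not phase_decay:
        raise ValueError("i_step is outside the decay window")
    if policy_decay == "sqrt":
        # inverse-sqrt decay; decay_factor is chosen in __init__ so that
        # lr equals lr_max_scaled at the end of warm-up
        decay_factor = lr_max_scaled * np.sqrt(n_steps_warmup)
        return decay_factor / np.sqrt(i_step) if i_step > 0 else lr_max_scaled
    if policy_decay == "constant":
        # new policy: hold the peak learning rate through the decay window
        return lr_max_scaled
    raise ValueError("Unsupported decay policy for learning rate scheduler")

# example values (assumed): 1000 warm-up steps, 9000-step decay window
print(decay_phase_lr(1001, "sqrt", 1e-3, 1000, 9000))      # ~9.995e-04
print(decay_phase_lr(5000, "constant", 1e-3, 1000, 9000))  # 1e-03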