
Commit 06188d9

Add support for constant learning rate (ecmwf#1186)
* Added support for constant learning rate and minor clean-up in code
* Fixed issues with overlap between lr phases
* Changing default lr to constant
1 parent 5a1707c commit 06188d9

File tree: 2 files changed (+16 lines, −6 lines)

config/default_config.yml (1 addition, 1 deletion)

```diff
@@ -126,7 +126,7 @@ lr_final: 0.0
 lr_steps_warmup: 512
 lr_steps_cooldown: 512
 lr_policy_warmup: "cosine"
-lr_policy_decay: "linear"
+lr_policy_decay: "constant"
 lr_policy_cooldown: "linear"
 
 grad_clip: 1.0
```
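For orientation, here is a rough sketch (not part of the commit) of the schedule shape these settings describe: a cosine warmup to the peak rate, a flat plateau while `lr_policy_decay` is "constant", then a linear cooldown towards `lr_final`. Every name, default value, and formula below is an illustrative assumption; the actual logic lives in `src/weathergen/train/lr_scheduler.py`.

```python
import math

# Illustrative only: a piecewise schedule mirroring the config above
# (cosine warmup -> constant plateau -> linear cooldown). The phase lengths
# are placeholder defaults, not values taken from the repository.
def lr_at_step(step, lr_max, lr_start=0.0, lr_final=0.0,
               n_warmup=512, n_decay=10_000, n_cooldown=512):
    if step <= n_warmup:
        # cosine ramp from lr_start up to lr_max
        t = step / max(n_warmup, 1)
        return lr_start + (lr_max - lr_start) * 0.5 * (1.0 - math.cos(math.pi * t))
    if step <= n_warmup + n_decay:
        # "constant" decay policy: hold the peak rate for the whole phase
        return lr_max
    if step <= n_warmup + n_decay + n_cooldown:
        # linear cooldown from lr_max down to lr_final
        t = (step - n_warmup - n_decay) / max(n_cooldown, 1)
        return lr_max + (lr_final - lr_max) * t
    return lr_final
```

Plotted over steps this gives a ramp, a plateau, and a ramp-down; the two hunks below implement the plateau inside the existing scheduler class.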

src/weathergen/train/lr_scheduler.py (15 additions, 5 deletions)

```diff
@@ -123,6 +123,10 @@ def __init__(
             self.decay_factor = self.lr_max_scaled * np.sqrt(n_steps_warmup)
             self.scheduler_decay = None
 
+        elif policy_decay == "constant":
+            self.decay_factor = 0.0
+            self.scheduler_decay = None
+
         else:
             assert False, "Unsupported decay policy for learning rate scheduler"
```
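Reading the branch above: the constant policy has nothing to precompute, so `decay_factor` is zeroed and no decay scheduler object is created; the rate itself is handled in `step()` (next hunk). A condensed, hypothetical view of that dispatch, with the surrounding class stripped away and any name not in the diff assumed:

```python
import numpy as np

def init_decay_phase(policy_decay: str, lr_max_scaled: float, n_steps_warmup: int):
    """Hypothetical condensed view of the decay-policy dispatch in __init__."""
    if policy_decay == "sqrt":
        # later applied as decay_factor / sqrt(step), so the decay starts near lr_max_scaled
        return lr_max_scaled * np.sqrt(n_steps_warmup), None
    if policy_decay == "constant":
        # nothing to precompute: step() just holds lr_max_scaled during this phase
        return 0.0, None
    raise ValueError("Unsupported decay policy for learning rate scheduler")
```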

```diff
@@ -173,18 +177,24 @@ def step(self):
         if self.i_step >= (self.n_steps_warmup + self.n_steps_decay + self.n_steps_cooldown):
             return self.lr
 
-        if (
-            self.policy_decay == "sqrt"
-            and self.i_step > self.n_steps_warmup
-            and self.i_step < self.n_steps_warmup + self.n_steps_decay
-        ):
+        end_decay = self.n_steps_warmup + self.n_steps_decay
+        phase_decay = (self.i_step > self.n_steps_warmup) and (self.i_step <= end_decay)
+
+        if self.policy_decay == "sqrt" and phase_decay:
             self.lr = (
                 (self.decay_factor / np.sqrt(self.i_step))
                 if self.i_step > 0
                 else self.lr_max_scaled
             )
             for g in self.optimizer.param_groups:
                 g["lr"] = self.lr
+        elif self.policy_decay == "constant" and phase_decay:
+            cur_lr = self.lr
+            self.lr = self.lr_max_scaled
+            # make sure lr_max_scaled rate is used if warm-up end is not lr_max_scaled
+            if cur_lr < self.lr:
+                for g in self.optimizer.param_groups:
+                    g["lr"] = self.lr
         else:
             self.cur_scheduler.step()
             self.lr = self.cur_scheduler.get_last_lr()[0]
```
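The reworked `step()` computes the decay window once (`end_decay`), tests it with an inclusive upper bound, and, for the constant policy, raises the rate back to `lr_max_scaled` in case warmup ended below it, touching the optimizer's param groups only when that is actually needed. A minimal standalone sketch of just this phase logic, under assumed names and without the torch schedulers the real method also drives:

```python
def decay_phase_lr(i_step: int, lr: float, lr_max_scaled: float, decay_factor: float,
                   n_steps_warmup: int, n_steps_decay: int, policy_decay: str) -> float:
    """Hypothetical, simplified mirror of the decay-phase handling in step()."""
    end_decay = n_steps_warmup + n_steps_decay
    phase_decay = (i_step > n_steps_warmup) and (i_step <= end_decay)
    if not phase_decay:
        return lr  # warmup and cooldown are driven by separate schedulers in the real code
    if policy_decay == "sqrt":
        return (decay_factor / i_step**0.5) if i_step > 0 else lr_max_scaled
    if policy_decay == "constant":
        return lr_max_scaled  # hold the peak rate for the whole decay window
    return lr
```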
