
Commit 74c93c8

add PolyScheduler
1 parent 93952cd commit 74c93c8

2 files changed: +168 -21 lines

change_detection_pytorch/utils/lr_scheduler.py

Lines changed: 167 additions & 20 deletions
@@ -1,11 +1,14 @@
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+import warnings
 
-__all__ = ['GradualWarmupScheduler']
+from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
+from torch.optim.optimizer import Optimizer
+
+__all__ = ['GradualWarmupScheduler', 'PolyScheduler']
 
 
 class GradualWarmupScheduler(_LRScheduler):
-    """ Gradually warm-up(increasing) learning rate in optimizer.
+    """https://github.com/ildoonet/pytorch-gradual-warmup-lr
+    Gradually warm-up (increasing) learning rate in optimizer.
     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
     Args:
         optimizer (Optimizer): Wrapped optimizer.
@@ -64,27 +67,171 @@ def step(self, epoch=None, metrics=None):
         else:
             self.step_ReduceLROnPlateau(metrics, epoch)
 
-# TODO: Poly
+
+class PolyScheduler(_LRScheduler):
+    r"""Decays the learning rate of each parameter group using a polynomial function.
+    When last_epoch=-1, sets the initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        power (float): Polynomial factor of learning rate decay.
+        total_steps (int): The total number of steps in the cycle. Note that
+            if a value is not provided here, it must be inferred by providing
+            values for epochs and steps_per_epoch.
+            Default: None
+        epochs (int): The number of epochs to train for. This is used along
+            with steps_per_epoch to infer the total number of steps in the cycle
+            if a value for total_steps is not provided.
+            Default: None
+        steps_per_epoch (int): The number of steps per epoch to train for. This is
+            used along with epochs to infer the total number of steps in the
+            cycle if a value for total_steps is not provided.
+            Default: None
+        by_epoch (bool): If ``True``, the learning rate is updated once per epoch
+            and `steps_per_epoch` and `total_steps` are ignored. If ``False``,
+            the learning rate is updated once per batch, and you must define either
+            `total_steps` or (`epochs` and `steps_per_epoch`).
+            Default: ``False``.
+        min_lr (float or list): A scalar or a list of scalars. A
+            lower bound on the learning rate of all param groups
+            or of each group respectively. Default: 0.
+        last_epoch (int): The index of the last batch. This parameter is used when
+            resuming a training job. Since `step()` should be invoked after each
+            batch instead of after each epoch, this number represents the total
+            number of *batches* computed, not the total number of epochs computed.
+            When last_epoch=-1, the schedule is started from the beginning.
+            Default: -1
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+
+    Example:
+        >>> data_loader = torch.utils.data.DataLoader(...)
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> scheduler = PolyScheduler(optimizer, power=0.9, steps_per_epoch=len(data_loader), epochs=10)
+        >>> for epoch in range(10):
+        >>>     for batch in data_loader:
+        >>>         train_batch(...)
+        >>>         scheduler.step()
+
+        OR
+
+        >>> data_loader = torch.utils.data.DataLoader(...)
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> scheduler = PolyScheduler(optimizer, power=0.9, epochs=10, by_epoch=True)
+        >>> for epoch in range(10):
+        >>>     train_epoch(...)
+        >>>     scheduler.step()
+
+    https://github.com/likyoo/change_detection.pytorch/blob/main/change_detection_pytorch/utils/lr_scheduler.py
+    """
+
+    def __init__(self,
+                 optimizer,
+                 power=1.0,
+                 total_steps=None,
+                 epochs=None,
+                 steps_per_epoch=None,
+                 by_epoch=False,
+                 min_lr=0,
+                 last_epoch=-1,
+                 verbose=False):
+
+        # Validate optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+        self.by_epoch = by_epoch
+        self.epochs = epochs
+        self.min_lr = min_lr
+        self.power = power
+
+        # Validate total_steps
+        if by_epoch:
+            if not isinstance(epochs, int) or epochs <= 0:
+                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
+            if steps_per_epoch is not None or total_steps is not None:
+                warnings.warn("`steps_per_epoch` and `total_steps` will be ignored if `by_epoch` is True, "
+                              "please use `epochs`.", UserWarning)
+            self.total_steps = epochs
+        elif total_steps is None and epochs is None and steps_per_epoch is None:
+            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
+        elif total_steps is not None:
+            if not isinstance(total_steps, int) or total_steps <= 0:
+                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
+            self.total_steps = total_steps
+        else:
+            if not isinstance(epochs, int) or epochs <= 0:
+                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
+            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
+                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
+            self.total_steps = epochs * steps_per_epoch
+
+        super(PolyScheduler, self).__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.", UserWarning)
+
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
+                             .format(step_num + 1, self.total_steps))
+
+        if step_num == 0:
+            return self.base_lrs
+
+        coeff = (1 - step_num / self.total_steps) ** self.power
+
+        return [(base_lr - self.min_lr) * coeff + self.min_lr
+                for base_lr in self.base_lrs]
+
 
 if __name__ == '__main__':
     # https://github.com/ildoonet/pytorch-gradual-warmup-lr
-    import torch
-    from torch.optim.lr_scheduler import StepLR, ExponentialLR
-    from torch.optim.sgd import SGD
+    # import torch
+    # from torch.optim.lr_scheduler import StepLR
+    # from torch.optim.sgd import SGD
+    #
+    # model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
+    # optim = SGD(model, 0.1)
+    #
+    # # scheduler_warmup is chained with scheduler_steplr
+    # scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
+    # scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
+    #
+    # # this zero gradient update is needed to avoid a warning message, issue #8.
+    # optim.zero_grad()
+    # optim.step()
+    #
+    # for epoch in range(1, 20):
+    #     scheduler_warmup.step(epoch)
+    #     print(epoch, optim.param_groups[0]['lr'])
+    #
+    # optim.step()  # backward pass (update network)
 
-    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
-    optim = SGD(model, 0.1)
+    import matplotlib.pyplot as plt
+    import torch
 
-    # scheduler_warmup is chained with schduler_steplr
-    scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
-    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
+    EPOCH = 10
+    LEN_DATA = 10
 
-    # this zero gradient update is needed to avoid a warning message, issue #8.
-    optim.zero_grad()
-    optim.step()
+    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
+    optimizer = torch.optim.SGD(params=model, lr=0.1)
+    scheduler = PolyScheduler(optimizer, power=0.9, min_lr=1e-4, epochs=EPOCH, steps_per_epoch=LEN_DATA, by_epoch=False)
+    plt.figure()
 
-    for epoch in range(1, 20):
-        scheduler_warmup.step(epoch)
-        print(epoch, optim.param_groups[0]['lr'])
+    x = list(range(EPOCH * LEN_DATA))
+    y = []
+    for epoch in range(EPOCH):
+        for batch in range(LEN_DATA):
+            print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
+            y.append(scheduler.get_last_lr()[0])
+            scheduler.step()
 
-    optim.step()  # backward pass (update network)
+    plt.plot(x, y)
+    plt.show()
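For reference, a minimal sketch of what the new scheduler computes, written against the code above (the import path follows the file location in this commit; the toy numbers are illustrative, not part of the commit). After t completed steps, get_lr() returns lr_t = (base_lr - min_lr) * (1 - t / total_steps) ** power + min_lr, so the learning rate decays polynomially from base_lr to min_lr over total_steps steps.

# Sketch only: compare PolyScheduler against the closed-form decay (toy values assumed).
import torch
from change_detection_pytorch.utils.lr_scheduler import PolyScheduler

params = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = PolyScheduler(optimizer, power=0.9, min_lr=1e-4, total_steps=20)

for t in range(20):
    # get_last_lr() reflects t completed steps; it should match the closed form
    expected = (0.1 - 1e-4) * (1 - t / 20) ** 0.9 + 1e-4
    assert abs(scheduler.get_last_lr()[0] - expected) < 1e-9
    scheduler.step()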

local_test.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@
 max_score = 0
 MAX_EPOCH = 60
 
-for i in range(1, MAX_EPOCH + 1):
+for i in range(MAX_EPOCH):
 
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
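The loop in local_test.py now runs epochs 0 through MAX_EPOCH - 1. This diff does not show which scheduler local_test.py actually uses, but a 0-based loop of this shape is how the by_epoch mode documented above would be driven; a hypothetical, self-contained sketch (the optimizer and the placeholder training call are assumptions, not taken from local_test.py):

# Hypothetical sketch: per-epoch stepping (by_epoch=True) in a loop shaped like local_test.py's.
import torch
from change_detection_pytorch.utils.lr_scheduler import PolyScheduler

MAX_EPOCH = 60
params = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = PolyScheduler(optimizer, power=0.9, epochs=MAX_EPOCH, by_epoch=True)

for i in range(MAX_EPOCH):
    print('\nEpoch: {} lr: {:.6f}'.format(i, scheduler.get_last_lr()[0]))
    # train_epoch.run(train_loader) would go here in local_test.py
    scheduler.step()  # one step per epoch; lr reaches min_lr (0 by default) after the last epoch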
