Skip to content

Commit 138b522

Browse files
committed
feat(training): add auto_lr_finder to lit exp
1 parent 9d215c3 commit 138b522

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

cellseg_models_pytorch/training/lit/lightning_experiment.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def __init__(
3131
scheduler: str = "reduce_on_plateau",
3232
scheduler_params: Dict[str, Any] = None,
3333
log_freq: int = 100,
34+
auto_lr_finder: bool = False,
35+
**kwargs,
3436
) -> None:
3537
"""Segmentation model training experiment.
3638
@@ -64,6 +66,7 @@ def __init__(
6466
optim paramas like learning rates, weight decays etc for diff parts of
6567
the network.
6668
E.g. {"encoder": {"weight_decay: 0.1, "lr": 0.1}, "sem": {"lr": 0.01}}
69+
or {"learning_rate": 0.005, "weight_decay": 0.03}
6770
lookahead : bool, default=False
6871
Flag whether the optimizer uses lookahead.
6972
scheduler : str, default="reduce_on_plateau"
@@ -75,6 +78,8 @@ def __init__(
7578
for the possible scheduler arguments.
7679
log_freq : int, default=100
7780
Return logs every n batches in logging callbacks.
81+
auto_lr_finder : bool, default=False
82+
Flag, whether to use the lightning in-built auto-lr-finder.
7883
7984
Raises
8085
------
@@ -83,6 +88,8 @@ def __init__(
8388
ValueError if illegal metric names are given.
8489
ValueError if illegal optimizer name is given.
8590
ValueError if illegal scheduler name is given.
91+
KeyError if `auto_lr_finder` is set to True and `optim_params` does not
92+
contain `lr`-key.
8693
"""
8794
super().__init__()
8895
self.model = model
@@ -95,6 +102,16 @@ def __init__(
95102
self.scheduler = scheduler
96103
self.scheduler_params = scheduler_params
97104
self.lookahead = lookahead
105+
self.auto_lr_finder = auto_lr_finder
106+
107+
if auto_lr_finder:
108+
try:
109+
self.lr = optim_params["lr"]
110+
except KeyError:
111+
raise KeyError(
112+
"To use lightning in-built auto_lr_finder, the `optim_params` "
113+
"config variable has to contain 'lr'-key for learning-rate."
114+
)
98115

99116
self.branch_losses = branch_losses
100117
self.branch_metrics = branch_metrics
@@ -309,15 +326,20 @@ def configure_optimizers(self):
309326
f"Illegal scheduler given. Got {self.scheduler}. Allowed: {allowed}."
310327
)
311328

312-
# set sensible default if None.
313-
if self.optim_params is None:
314-
self.optim_params = {
315-
"encoder": {"lr": 0.00005, "weight_decay": 0.00003},
316-
"decoder": {"lr": 0.0005, "weight_decay": 0.0003},
317-
}
329+
if not self.auto_lr_finder:
330+
# set sensible default if None.
331+
if self.optim_params is None:
332+
self.optim_params = {
333+
"encoder": {"lr": 0.00005, "weight_decay": 0.00005},
334+
"decoder": {"lr": 0.0005, "weight_decay": 0.0005},
335+
}
318336

319-
params = adjust_optim_params(self.model, self.optim_params)
320-
optimizer = OPTIM_LOOKUP[self.optimizer](params)
337+
params = adjust_optim_params(self.model, self.optim_params)
338+
optimizer = OPTIM_LOOKUP[self.optimizer](params)
339+
else:
340+
optimizer = OPTIM_LOOKUP[self.optimizer](
341+
self.model.parameters(), lr=self.lr
342+
)
321343

322344
if self.lookahead:
323345
optimizer = OPTIM_LOOKUP["lookahead"](optimizer, k=5, alpha=0.5)

0 commit comments

Comments
 (0)