@@ -31,6 +31,8 @@ def __init__(
         scheduler: str = "reduce_on_plateau",
         scheduler_params: Dict[str, Any] = None,
         log_freq: int = 100,
+        auto_lr_finder: bool = False,
+        **kwargs,
     ) -> None:
         """Segmentation model training experiment.

@@ -64,6 +66,7 @@ def __init__(
             optim params like learning rates, weight decays etc. for diff parts of
             the network.
             E.g. {"encoder": {"weight_decay": 0.1, "lr": 0.1}, "sem": {"lr": 0.01}}
+            or {"learning_rate": 0.005, "weight_decay": 0.03}
         lookahead : bool, default=False
             Flag whether the optimizer uses lookahead.
         scheduler : str, default="reduce_on_plateau"
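
The docstring above now documents two accepted shapes for `optim_params`: a nested per-part dict and a flat dict. The snippet below is only an illustration of how such dicts typically map onto torch optimizer parameter groups; the repo's own `adjust_optim_params` helper is not shown in this diff, so the grouping logic here is an assumption, and the toy `nn.Sequential` model is a placeholder for the real segmentation network.

import torch
import torch.nn as nn

# Placeholder model; stands in for the real encoder/decoder network.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 1, 3))

# Nested shape: per-part settings (assumed to become separate parameter groups).
nested = {"encoder": {"lr": 5e-5, "weight_decay": 3e-5}, "decoder": {"lr": 5e-4}}
groups = [
    {"params": model[0].parameters(), **nested["encoder"]},
    {"params": model[1].parameters(), **nested["decoder"]},
]
optimizer = torch.optim.Adam(groups)

# Flat shape: one setting applied to every parameter.
flat = {"learning_rate": 0.005, "weight_decay": 0.03}
optimizer = torch.optim.Adam(
    model.parameters(), lr=flat["learning_rate"], weight_decay=flat["weight_decay"]
)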
@@ -75,6 +78,8 @@ def __init__(
             for the possible scheduler arguments.
         log_freq : int, default=100
             Return logs every n batches in logging callbacks.
+        auto_lr_finder : bool, default=False
+            Flag whether to use the lightning in-built auto-lr-finder.

         Raises
         ------
@@ -83,6 +88,8 @@ def __init__(
             ValueError if illegal metric names are given.
             ValueError if illegal optimizer name is given.
             ValueError if illegal scheduler name is given.
+            KeyError if `auto_lr_finder` is set to True and `optim_params` does not
+            contain an `lr` key.
         """
         super().__init__()
         self.model = model
@@ -95,6 +102,16 @@ def __init__(
         self.scheduler = scheduler
         self.scheduler_params = scheduler_params
         self.lookahead = lookahead
+        self.auto_lr_finder = auto_lr_finder
+
+        if auto_lr_finder:
+            try:
+                self.lr = optim_params["lr"]
+            except KeyError:
+                raise KeyError(
+                    "To use lightning in-built auto_lr_finder, the `optim_params` "
+                    "config variable has to contain 'lr'-key for learning-rate."
+                )

         self.branch_losses = branch_losses
         self.branch_metrics = branch_metrics
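
Storing `self.lr` in `__init__` is what lets the Lightning learning-rate finder interact with this class: in the Lightning 1.x API, `Trainer(auto_lr_find=True)` together with `trainer.tune(...)` looks for an `lr` (or `learning_rate`) attribute on the module and overwrites it with the suggested value. A hedged usage sketch follows, assuming that 1.x tuner API; `SegmentationExperiment`, `model` and `datamodule` are placeholders, and the constructor call is reduced to the arguments relevant here.

import pytorch_lightning as pl

# Placeholder objects; swap in the real model and data module.
experiment = SegmentationExperiment(
    model=model,
    optim_params={"lr": 1e-3},  # required when auto_lr_finder=True
    auto_lr_finder=True,
)

# Lightning 1.x style tuner: runs a short LR-range test and writes the
# suggestion back into `experiment.lr` before full training starts.
trainer = pl.Trainer(auto_lr_find=True, max_epochs=50)
trainer.tune(experiment, datamodule=datamodule)
trainer.fit(experiment, datamodule=datamodule)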
@@ -309,15 +326,20 @@ def configure_optimizers(self):
                 f"Illegal scheduler given. Got {self.scheduler}. Allowed: {allowed}."
             )

-        # set sensible default if None.
-        if self.optim_params is None:
-            self.optim_params = {
-                "encoder": {"lr": 0.00005, "weight_decay": 0.00003},
-                "decoder": {"lr": 0.0005, "weight_decay": 0.0003},
-            }
+        if not self.auto_lr_finder:
+            # set sensible default if None.
+            if self.optim_params is None:
+                self.optim_params = {
+                    "encoder": {"lr": 0.00005, "weight_decay": 0.00005},
+                    "decoder": {"lr": 0.0005, "weight_decay": 0.0005},
+                }

-        params = adjust_optim_params(self.model, self.optim_params)
-        optimizer = OPTIM_LOOKUP[self.optimizer](params)
+            params = adjust_optim_params(self.model, self.optim_params)
+            optimizer = OPTIM_LOOKUP[self.optimizer](params)
+        else:
+            optimizer = OPTIM_LOOKUP[self.optimizer](
+                self.model.parameters(), lr=self.lr
+            )

         if self.lookahead:
             optimizer = OPTIM_LOOKUP["lookahead"](optimizer, k=5, alpha=0.5)
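
The unchanged tail of this hunk wraps whichever optimizer was built above in Lookahead with k=5, alpha=0.5. For orientation, the Lookahead scheme (Zhang et al., 2019) keeps a copy of slow weights and, after every k fast-optimizer steps, pulls them toward the fast weights by a factor alpha, then resets the fast weights to the slow ones. The sketch below only illustrates that update rule; it is not the implementation behind OPTIM_LOOKUP["lookahead"].

import torch

def lookahead_sync(slow_params, fast_params, alpha=0.5):
    """Illustrative slow-weight update: phi <- phi + alpha * (theta - phi).

    Call this once every k fast-optimizer steps; the fast weights are then
    reset to the slow weights. Not the repo's actual Lookahead code.
    Usage sketch: slow = [p.detach().clone() for p in model.parameters()],
    then after every 5 optimizer.step() calls run
    lookahead_sync(slow, list(model.parameters())).
    """
    with torch.no_grad():
        for phi, theta in zip(slow_params, fast_params):
            phi.add_(theta - phi, alpha=alpha)
            theta.copy_(phi)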