Commit 2ead010

fix(optim): fix optim params adj util to mod any

1 parent: 943a9cd
File tree: 3 files changed, +17 -47 lines


cellseg_models_pytorch/optimizers/tests/test_optim_setup.py

Lines changed: 5 additions & 9 deletions
```diff
@@ -7,15 +7,11 @@
 @pytest.mark.parametrize("optim", ["adam"])
 def test_optim_setup(optim):
     model = cellpose_base(type_classes=3)
-    params = adjust_optim_params(
-        model,
-        encoder_lr=0.5,
-        encoder_wd=0.4,
-        decoder_lr=0.3,
-        decoder_wd=0.2,
-        remove_bias_wd=True,
-    )
-
+    optim_params = {
+        "encoder": {"lr": 0.004, "weight_decay": 0.3},
+        "decoder": {"lr": 0.003, "weight_decay": 0.2},
+    }
+    params = adjust_optim_params(model, optim_params)
     optimizer = OPTIM_LOOKUP[optim](params)
 
     assert all([p_g["lr"] == p["lr"] for p_g, p in zip(optimizer.param_groups, params)])
```
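For context, each dict that `adjust_optim_params` returns becomes one optimizer param group, which is what the final assertion verifies. A minimal sketch of that mechanism, assuming `OPTIM_LOOKUP["adam"]` resolves to `torch.optim.Adam` (implied by the test, not shown on this page):

```python
import torch
import torch.nn as nn

# Each {"params": ..., "lr": ..., ...} dict becomes one param group, so
# per-block settings survive optimizer construction unchanged.
params = [
    {"params": nn.Parameter(torch.zeros(1)), "lr": 0.004, "weight_decay": 0.3},
    {"params": nn.Parameter(torch.zeros(1)), "lr": 0.003, "weight_decay": 0.2},
]
optimizer = torch.optim.Adam(params)

# Mirrors the test's assertion: group-wise lr matches the input dicts.
assert all(g["lr"] == p["lr"] for g, p in zip(optimizer.param_groups, params))
```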
Lines changed: 9 additions & 38 deletions
```diff
@@ -1,67 +1,38 @@
-from typing import Dict, List
+from typing import Any, Dict, List
 
 import torch.nn as nn
 
 
 def adjust_optim_params(
-    model: nn.Module,
-    encoder_lr: float,
-    encoder_wd: float,
-    decoder_lr: float,
-    decoder_wd: float,
-    remove_bias_wd: bool = True,
+    model: nn.Module, optim_params: Dict[str, Dict[str, Any]]
 ) -> List[Dict[str, Dict]]:
     """Adjust the learning parameters for optimizer.
 
-    1. Adjust learning rate and weight decay in the pre-trained
-       encoder and decoders.
-    2. Remove weight decay from bias terms to reduce overfitting.
-
-    "Bag of Tricks for Image Classification with Convolutional Neural Networks"
-        - https://arxiv.org/pdf/1812.01187
-
     Parameters
     ----------
     model : nn.Module
         The encoder-decoder segmentation model.
-    encoder_lr : float
-        Learning rate of the model encoder.
-    encoder_wd : float
-        Weight decay for the model encoder.
-    decoder_lr : float
-        Learning rate of the model decoder.
-    decoder_wd : float
-        Weight decay for the model decoder.
-    remove_bias_wd : bool, default=True
-        If True, the weight decay from the bias terms is removed from the model
-        params. Ignored if `remove_wd`=True.
+    optim_params : Dict[str, Dict[str, Any]]
+        Optim params like learning rates and weight decays for different
+        parts of the network. E.g.
+        {"encoder": {"weight_decay": 0.1, "lr": 0.1}, "sem": {"lr": 0.1}}
 
     Returns
     -------
     List[Dict[str, Dict]]:
         A list of kwargs (str, Dict pairs) containing the model params.
     """
     params = list(model.named_parameters())
-    encoder_params = {"encoder": {"lr": encoder_lr, "weight_decay": encoder_wd}}
-    decoder_params = {"decoder": {"lr": decoder_lr, "weight_decay": decoder_wd}}
 
     adjust_params = []
     for name, parameters in params:
         opts = {}
-        for enc, enc_opts in encoder_params.items():
-            if enc in name:
-                for key, item in enc_opts.items():
-                    opts[key] = item
 
-        for dec, dec_opts in decoder_params.items():
-            if dec in name:
-                for key, item in dec_opts.items():
+        for block, block_params in optim_params.items():
+            if block in name:
+                for key, item in block_params.items():
                     opts[key] = item
 
-        if remove_bias_wd:
-            if name.endswith("bias"):
-                opts["weight_decay"] = 0.0
-
         adjust_params.append({"params": parameters, **opts})
 
     return adjust_params
```
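To make the new API concrete, here is a self-contained usage sketch. The function body is copied from the diff above; the toy model and hyperparameter values are illustrative only. The dict keys are matched as substrings of `model.named_parameters()` names, so `"encoder"` catches e.g. `encoder.weight` and `encoder.bias`:

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn


def adjust_optim_params(
    model: nn.Module, optim_params: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Dict]]:
    """Copy of the refactored utility from the diff above."""
    adjust_params = []
    for name, parameters in model.named_parameters():
        opts = {}
        for block, block_params in optim_params.items():
            if block in name:  # substring match against the parameter name
                for key, item in block_params.items():
                    opts[key] = item
        adjust_params.append({"params": parameters, **opts})
    return adjust_params


class ToyEncDec(nn.Module):
    """Tiny stand-in for an encoder-decoder segmentation model."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 4)  # -> "encoder.weight", "encoder.bias"
        self.decoder = nn.Linear(4, 8)  # -> "decoder.weight", "decoder.bias"


model = ToyEncDec()
optim_params = {
    "encoder": {"lr": 4e-3, "weight_decay": 0.3},
    "decoder": {"lr": 3e-3, "weight_decay": 0.2},
}
param_groups = adjust_optim_params(model, optim_params)
optimizer = torch.optim.Adam(param_groups)

for group in optimizer.param_groups:
    print(group["lr"], group["weight_decay"])  # encoder groups then decoder groups
```

Note that unlike the old signature, the refactored function no longer zeroes weight decay on bias terms; with substring matching, that behavior could be recovered by the caller if needed.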
Lines changed: 3 additions & 0 deletions
```diff
@@ -0,0 +1,3 @@
+## Fixes
+
+- Modify the optimizer adjustment utility function to adjust any optim/weight params.
```
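To illustrate the "any optim/weight params" part: the dict keys are free-form substrings of parameter names, so any named block of the network can get its own settings. A hypothetical example (the `"sem"` key follows the docstring example in the diff above):

```python
# Hypothetical settings: a low lr for a pretrained encoder, a custom lr for
# a semantic head whose parameter names contain "sem"; parameters matching
# no key fall back to the optimizer's defaults.
optim_params = {
    "encoder": {"lr": 1e-4, "weight_decay": 5e-2},
    "sem": {"lr": 1e-3},
}
```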
