|
1 |
| -from typing import Dict, List |
| 1 | +from typing import Any, Dict, List |
2 | 2 |
|
3 | 3 | import torch.nn as nn
|
4 | 4 |
|
5 | 5 |
|
def adjust_optim_params(
    model: nn.Module, optim_params: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Dict]]:
    """Adjust the learning parameters for optimizer.

    Every named parameter whose name contains one of the keys of
    ``optim_params`` (plain substring match) receives that key's settings.
    If a parameter name matches several keys, the settings are merged and
    later keys overwrite earlier ones (dict insertion order).

    Parameters
    ----------
    model : nn.Module
        The encoder-decoder segmentation model.
    optim_params : Dict[str, Dict[str, Any]]
        Optim params like learning rates, weight decays etc. for different
        parts of the network. E.g.
        {"encoder": {"weight_decay": 0.1, "lr": 0.1}, "sem": {"lr": 0.1}}

    Returns
    -------
    List[Dict[str, Dict]]:
        A list of kwargs (str, Dict pairs) containing the model params,
        in the format torch optimizers accept as per-parameter groups.
    """
    adjust_params = []
    # named_parameters() is a generator; iterate it directly instead of
    # materializing an intermediate list.
    for name, parameters in model.named_parameters():
        opts: Dict[str, Any] = {}
        for block, block_params in optim_params.items():
            # Substring match so e.g. "encoder" applies to
            # "encoder.layer1.weight" as well.
            if block in name:
                opts.update(block_params)

        adjust_params.append({"params": parameters, **opts})

    return adjust_params
|
0 commit comments