
Commit 781ea4e

Merge commit '463bec447aaa6daaf38e3b9c6dc041e168ee7ea1' into develop
2 parents f8e5b39 + 463bec4

5 files changed: +106 -9 lines changed


training/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -27,11 +27,15 @@ Keep it human-readable, your future self will thank you!
 - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
 -
 <b> TRANSFER LEARNING</b>: enabled new functionality. You can now load checkpoints from different models and different training runs.
+- Introduce (optional) variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
+- <b> TRANSFER LEARNING</b>: enabled new functionality. You can now load checkpoints from different models and different training runs.
 - Effective batch size: `(config.dataloader.batch_size["training"] * config.hardware.num_gpus_per_node * config.hardware.num_nodes) // config.hardware.num_gpus_per_model`.
   Used for experiment reproducibility across different computing configurations (a worked example follows this diff).
 - Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
 - Added default configuration files for stretched grid and limited area model experiments [#173](https://github.com/ecmwf/anemoi-training/pull/173)
 - Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199)
+- <b> Model Freezing ❄️</b>: enabled new functionality. You can now freeze parts of your model by specifying a list of submodules to freeze with the new config parameter: submodules_to_freeze.
+- Introduce (optional) variable to configure: submodules_to_freeze -> List[str], list of submodules to freeze.
 - Add supporting arrays (numpy) to checkpoint
 - Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)
 - Improved validation metrics, allow 'all' to be scaled [#202](https://github.com/ecmwf/anemoi-training/pull/202)
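The effective batch size formula in the changelog entry above is easiest to read with concrete numbers. The values below are hypothetical and chosen purely for illustration, not taken from any anemoi-training configuration:

```python
# Hypothetical configuration values, for illustration only
batch_size_training = 2   # config.dataloader.batch_size["training"]
num_gpus_per_node = 4     # config.hardware.num_gpus_per_node
num_nodes = 2             # config.hardware.num_nodes
num_gpus_per_model = 2    # config.hardware.num_gpus_per_model (model sharded over 2 GPUs)

# Same arithmetic as the changelog formula
effective_batch_size = (batch_size_training * num_gpus_per_node * num_nodes) // num_gpus_per_model
print(effective_batch_size)  # (2 * 4 * 2) // 2 = 8 samples per optimiser step
```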

training/docs/user-guide/training.rst

Lines changed: 62 additions & 0 deletions
@@ -280,3 +280,65 @@ finished training. However if the user wants to restart the model from a
 specific point they can do this by setting
 ``config.hardware.files.warm_start`` to be the checkpoint they want to
 restart from.
+
+*******************
+ Transfer Learning
+*******************
+
+Transfer learning allows the model to reuse knowledge from a previously
+trained checkpoint. This is particularly useful when the new task is
+related to the old one, enabling faster convergence and often improving
+model performance.
+
+To enable transfer learning, set the config.training.transfer_learning
+flag to True in the configuration file.
+
+.. code:: yaml
+
+   training:
+      # start the training from a checkpoint of a previous run
+      fork_run_id: ...
+      load_weights_only: True
+      transfer_learning: True
+
+When this flag is active and a checkpoint path is specified in
+config.hardware.files.warm_start or self.last_checkpoint, the system
+loads the pre-trained weights using the ``transfer_learning_loading``
+function. This approach ensures only compatible weights are loaded and
+mismatched layers are handled appropriately.
+
+For example, transfer learning might be used to adapt a weather
+forecasting model trained on one geographic region to another region
+with similar characteristics.
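The checkpoint.py diff further down only shows the tail of ``transfer_learning_loading``, so here is a rough, self-contained sketch of the behaviour described above: keep only checkpoint tensors whose name and shape match the target model, then load with ``strict=False``. The helper name `load_compatible_weights` is hypothetical; this is an illustration, not the actual anemoi-training implementation.

```python
import torch


def load_compatible_weights(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    """Illustrative only: load checkpoint tensors that match the model by name and shape."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    ckpt_state = checkpoint.get("state_dict", checkpoint)
    model_state = model.state_dict()

    # Keep only tensors that exist in the model with an identical shape
    filtered = {
        name: tensor
        for name, tensor in ckpt_state.items()
        if name in model_state and model_state[name].shape == tensor.shape
    }

    # strict=False: layers missing from the checkpoint keep their fresh initialisation
    model.load_state_dict(filtered, strict=False)
    return model
```

Filtering by shape rather than failing on a mismatch is what allows a checkpoint from a related but not identical model to be partially reused.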
+****************
+ Model Freezing
+****************
+
+Model freezing is a technique where specific parts (submodules) of a
+model are excluded from training. This is useful when certain parts of
+the model have been sufficiently trained or should remain unchanged for
+the current task.
+
+To specify which submodules to freeze, use the
+config.training.submodules_to_freeze field in the configuration. List
+the names of submodules to be frozen. During model initialization, these
+submodules will have their parameters frozen, ensuring they are not
+updated during training.
+
+For example, with the following configuration, the processor will be
+frozen and only the encoder and decoder will be trained:
+
+.. code:: yaml
+
+   training:
+      # start the training from a checkpoint of a previous run
+      fork_run_id: ...
+      load_weights_only: True
+
+      submodules_to_freeze:
+         - processor
+
+Freezing can be particularly beneficial in scenarios such as fine-tuning,
+when only specific components (e.g., the encoder, the decoder) need to
+adapt to a new task while keeping others (e.g., the processor) fixed.
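As a quick, hedged illustration of the configuration above, freezing the ``processor`` leaves only the encoder and decoder parameters trainable. The model here is a toy ``nn.ModuleDict`` stand-in, not the real GraphForecaster; ``freeze_submodule_by_name`` is the helper added in the checkpoint.py diff below.

```python
import torch.nn as nn

from anemoi.training.utils.checkpoint import freeze_submodule_by_name

# Toy stand-in for a model with encoder/processor/decoder submodules (illustrative only)
model = nn.ModuleDict(
    {
        "encoder": nn.Linear(8, 16),
        "processor": nn.Linear(16, 16),
        "decoder": nn.Linear(16, 8),
    }
)

freeze_submodule_by_name(model, "processor")

for name, param in model.named_parameters():
    # processor.* parameters report False; encoder and decoder remain trainable
    print(name, param.requires_grad)
```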

training/src/anemoi/training/config/training/default.yaml

Lines changed: 2 additions & 0 deletions
@@ -140,3 +140,5 @@ node_loss_weights:
   _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
   target_nodes: ${graph.data}
   node_attribute: area_weight
+
+submodules_to_freeze: []

training/src/anemoi/training/train/train.py

Lines changed: 17 additions & 9 deletions
@@ -32,6 +32,7 @@
 from anemoi.training.diagnostics.logger import get_wandb_logger
 from anemoi.training.distributed.strategy import DDPGroupStrategy
 from anemoi.training.train.forecaster import GraphForecaster
+from anemoi.training.utils.checkpoint import freeze_submodule_by_name
 from anemoi.training.utils.checkpoint import transfer_learning_loading
 from anemoi.training.utils.jsonify import map_config_to_primitives
 from anemoi.training.utils.seeding import get_base_seed
@@ -155,17 +156,24 @@ def model(self) -> GraphForecaster:
 
         model = GraphForecaster(**kwargs)
 
+        # Load the model weights
         if self.load_weights_only:
-            # Sanify the checkpoint for transfer learning
-            if self.config.training.transfer_learning:
-                LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
-                return transfer_learning_loading(model, self.last_checkpoint)
+            if hasattr(self.config.training, "transfer_learning"):
+                # Sanify the checkpoint for transfer learning
+                if self.config.training.transfer_learning:
+                    LOGGER.info("Loading weights with Transfer Learning from %s", self.last_checkpoint)
+                    model = transfer_learning_loading(model, self.last_checkpoint)
+            else:
+                LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
+                model = GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+
+        if hasattr(self.config.training, "submodules_to_freeze"):
+            # Freeze the chosen model weights
+            LOGGER.info("The following submodules will NOT be trained: %s", self.config.training.submodules_to_freeze)
+            for submodule_name in self.config.training.submodules_to_freeze:
+                freeze_submodule_by_name(model, submodule_name)
+                LOGGER.info("%s frozen successfully.", submodule_name.upper())
 
-        LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
-
-        return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
-
-        LOGGER.info("Model initialised from scratch.")
         return model
 
     @rank_zero_only
training/src/anemoi/training/utils/checkpoint.py

Lines changed: 21 additions & 0 deletions
@@ -91,3 +91,24 @@ def transfer_learning_loading(model: torch.nn.Module, ckpt_path: Path | str) ->
     # Load the filtered state_dict into the model
     model.load_state_dict(state_dict, strict=False)
     return model
+
+
+def freeze_submodule_by_name(module: nn.Module, target_name: str) -> None:
+    """
+    Recursively freezes the parameters of a submodule with the specified name.
+
+    Parameters
+    ----------
+    module : torch.nn.Module
+        Pytorch model
+    target_name : str
+        The name of the submodule to freeze.
+    """
+    for name, child in module.named_children():
+        # If this is the target submodule, freeze its parameters
+        if name == target_name:
+            for param in child.parameters():
+                param.requires_grad = False
+        else:
+            # Recursively search within children
+            freeze_submodule_by_name(child, target_name)
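One property worth noting: because the function recurses through ``named_children``, every submodule carrying the target name is frozen, however deeply it is nested. A small, hypothetical check (toy modules, not anemoi models):

```python
import torch.nn as nn

from anemoi.training.utils.checkpoint import freeze_submodule_by_name


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.processor = nn.Linear(4, 4)  # nested submodule also called "processor"
        self.mlp = nn.Linear(4, 4)


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block = Block()
        self.processor = nn.Linear(4, 4)  # top-level "processor"


net = Net()
freeze_submodule_by_name(net, "processor")

print(all(not p.requires_grad for p in net.processor.parameters()))        # True
print(all(not p.requires_grad for p in net.block.processor.parameters()))  # True (nested match)
print(all(p.requires_grad for p in net.block.mlp.parameters()))            # True (untouched)
```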
