Commit 1c4a966

jakob-schloer, pre-commit-ci[bot], and anaprietonem authored
feat: Log variable scaling in mlflow (ecmwf#327)
## Description

The variable scaling of the loss function is no longer trivial, since many scalers can be multiplied together. To keep track of this, the final variable scaling is logged to the MLflow (or other loggers') hyperparameters.

![image](https://github.com/user-attachments/assets/1450c624-9038-441c-9923-5c61bd770120)

## Type of Change

- [x] New feature (non-breaking change which adds functionality)

## Code Compatibility

- [x] I have performed a self-review of my code
- [ ] Test with wandb and other loggers

### Code Performance and Testing

- [x] I ran the [complete Pytest test suite](https://anemoi.readthedocs.io/projects/training/en/latest/dev/testing.html) locally, and it passes
- [x] I have tested the changes on a single GPU
- [x] I have tested the changes on multiple GPUs / multi-node setups
- [ ] I have run the Benchmark Profiler against the old version of the code

### Documentation

- [x] My code follows the style guidelines of this project
- [x] I have added comments to my code, particularly in hard-to-understand areas

📚 Documentation previews 📚:

- anemoi-training: https://anemoi-training--327.org.readthedocs.build/en/327/
- anemoi-graphs: https://anemoi-graphs--327.org.readthedocs.build/en/327/
- anemoi-models: https://anemoi-models--327.org.readthedocs.build/en/327/

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com>
1 parent 42c0194 commit 1c4a966

File tree

3 files changed: +50 -12 lines changed

- training/src/anemoi/training/diagnostics/logger.py
- training/src/anemoi/training/losses/utils.py
- training/src/anemoi/training/train/tasks/base.py

training/src/anemoi/training/diagnostics/logger.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -17,7 +17,6 @@
 from omegaconf import OmegaConf

 from anemoi.training.schemas.base_schema import BaseSchema
-from anemoi.training.schemas.base_schema import convert_to_omegaconf

 LOGGER = logging.getLogger(__name__)

@@ -91,11 +90,6 @@ def get_mlflow_logger(config: BaseSchema) -> None:
         on_resume_create_child=config.diagnostics.log.mlflow.on_resume_create_child,
         max_params_length=max_params_length,
     )
-    config_params = OmegaConf.to_container(convert_to_omegaconf(config), resolve=True)
-    logger.log_hyperparams(
-        config_params,
-        expand_keys=config.diagnostics.log.mlflow.expand_hyperparams,
-    )

     if config.diagnostics.log.mlflow.terminal:
         logger.log_terminal_output(artifact_save_dir=config.hardware.paths.plots)
```
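With this change, `get_mlflow_logger` no longer logs the hyperparameters when the logger is constructed; that responsibility moves to the Lightning `setup` hook added in `training/src/anemoi/training/train/tasks/base.py` below, where the variable loss scaling can be attached to the logged configuration.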

training/src/anemoi/training/losses/utils.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -8,27 +8,48 @@
 # nor does it submit to any jurisdiction.


+from __future__ import annotations
+
 import logging
+from typing import TYPE_CHECKING

-from anemoi.models.data_indices.collection import IndexCollection
-from anemoi.training.losses.base import BaseLoss
 from anemoi.training.utils.enums import TensorDim

+if TYPE_CHECKING:
+    from anemoi.models.data_indices.collection import IndexCollection
+    from anemoi.training.losses.base import BaseLoss
+
 LOGGER = logging.getLogger(__name__)


-def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> None:
-    """Log the final variable scaling for each variable in the model.
+def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> dict[str, float]:
+    """Log the final variable scaling for each variable in the model and return the scaling values.

     Parameters
     ----------
     loss : BaseLoss
         Loss function to get the variable scaling from.
     data_indices : IndexCollection
         Index collection to get the variable names from.
+
+    Returns
+    -------
+    dict[str, float]
+        Dictionary mapping variable names to their scaling values, plus the key 'total_sum'
+        with the sum of all scaling values.
     """
     variable_scaling = loss.scaler.subset_by_dim(TensorDim.VARIABLE.value).get_scaler(len(TensorDim)).reshape(-1)
     log_text = "Final Variable Scaling: "
+    scaling_values, scaling_sum = {}, 0.0
+
     for idx, name in enumerate(data_indices.model.output.name_to_index.keys()):
-        log_text += f"{name}: {variable_scaling[idx]:.4g}, "
+        value = float(variable_scaling[idx])
+        log_text += f"{name}: {value:.4g}, "
+        scaling_values[name] = value
+        scaling_sum += value
+
+    log_text += f"Total scaling sum: {scaling_sum:.4g}, "
+    scaling_values["total_sum"] = scaling_sum
     LOGGER.debug(log_text)
+
+    return scaling_values
```
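As a quick illustration of the new return value, the sketch below reproduces the aggregation with invented variable names and scaling values; in the real function these come from the loss scaler and `data_indices.model.output.name_to_index`.

```python
# Minimal sketch of the dict shape returned by print_variable_scaling().
# The variable names and scaling values below are invented for illustration.
variable_names = ["2t", "10u", "msl"]        # hypothetical output variables
variable_scaling = [0.025, 0.0125, 0.05]     # hypothetical per-variable scalings

scaling_values, scaling_sum = {}, 0.0
for name, value in zip(variable_names, variable_scaling):
    scaling_values[name] = float(value)
    scaling_sum += float(value)
scaling_values["total_sum"] = scaling_sum

print(scaling_values)
# {'2t': 0.025, '10u': 0.0125, 'msl': 0.05, 'total_sum': 0.0875}
```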

training/src/anemoi/training/train/tasks/base.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 import pytorch_lightning as pl
 import torch
 from hydra.utils import instantiate
+from omegaconf import OmegaConf
 from timm.scheduler import CosineLRScheduler
 from torch.distributed.optim import ZeroRedundancyOptimizer

@@ -212,7 +213,10 @@ def __init__(
             scalers=self.scalers,
             data_indices=self.data_indices,
         )
-        print_variable_scaling(self.loss, data_indices)
+        self._scaling_values_log = print_variable_scaling(
+            self.loss,
+            data_indices,
+        )

         self.metrics = torch.nn.ModuleDict(
             {
@@ -761,3 +765,22 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]
         )

         return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+    def setup(self, stage: str) -> None:
+        """Lightning hook that is called after model is initialized but before training starts."""
+        # The conditions should be separate, but are combined due to pre-commit hook
+        if stage == "fit" and self.trainer.is_global_zero and self.logger is not None:
+            # Log hyperparameters on rank 0
+            hyper_params = OmegaConf.to_container(convert_to_omegaconf(self.config), resolve=True)
+            hyper_params.update({"variable_loss_scaling": self._scaling_values_log})
+            # Expand keys for better visibility
+            expand_keys = OmegaConf.select(
+                convert_to_omegaconf(self.config),
+                "diagnostics.log.mlflow.expand_hyperparams",
+                default=["config"],
+            )
+            # Log hyperparameters
+            self.logger.log_hyperparams(
+                hyper_params,
+                expand_keys=expand_keys,
+            )
```
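To make the effect of the new `setup` hook concrete, here is a rough sketch of the payload handed to `log_hyperparams`; the config entries and scaling values are invented, and the exact expansion behaviour depends on the configured logger (e.g. `AnemoiMLflowLogger`).

```python
# Rough sketch of the hyperparameter payload assembled in setup(); all values invented.
# In the real code the config dict comes from
# OmegaConf.to_container(convert_to_omegaconf(self.config), resolve=True).
hyper_params = {
    "training": {"max_epochs": 10},          # hypothetical config entries
    "diagnostics": {"log": {"mlflow": {"expand_hyperparams": ["config"]}}},
}

# The per-variable scaling returned by print_variable_scaling() is attached
# under its own key, so it appears alongside the config in the run's parameters.
hyper_params["variable_loss_scaling"] = {
    "2t": 0.025,          # hypothetical
    "10u": 0.0125,        # hypothetical
    "total_sum": 0.0375,
}

# expand_keys is read from diagnostics.log.mlflow.expand_hyperparams and falls
# back to ["config"]; it names the top-level keys the logger should expand.
expand_keys = ["config"]
# self.logger.log_hyperparams(hyper_params, expand_keys=expand_keys)
```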

0 commit comments
