|
8 | 8 | # nor does it submit to any jurisdiction. |
9 | 9 |
|
10 | 10 |
|
| 11 | +from __future__ import annotations |
| 12 | + |
11 | 13 | import logging |
| 14 | +from typing import TYPE_CHECKING |
12 | 15 |
|
13 | | -from anemoi.models.data_indices.collection import IndexCollection |
14 | | -from anemoi.training.losses.base import BaseLoss |
15 | 16 | from anemoi.training.utils.enums import TensorDim |
16 | 17 |
|
| 18 | +if TYPE_CHECKING: |
| 19 | + from anemoi.models.data_indices.collection import IndexCollection |
| 20 | + from anemoi.training.losses.base import BaseLoss |
| 21 | + |
17 | 22 | LOGGER = logging.getLogger(__name__) |
18 | 23 |
|
19 | 24 |
|
20 | | -def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> None: |
21 | | - """Log the final variable scaling for each variable in the model. |
| 25 | +def print_variable_scaling(loss: BaseLoss, data_indices: IndexCollection) -> dict[str, float]: |
| 26 | + """Log the final variable scaling for each variable in the model and return the scaling values. |
22 | 27 |
|
23 | 28 | Parameters |
24 | 29 | ---------- |
25 | 30 | loss : BaseLoss |
26 | 31 | Loss function to get the variable scaling from. |
27 | 32 | data_indices : IndexCollection |
28 | 33 | Index collection to get the variable names from. |
| 34 | +
|
| 35 | + Returns |
| 36 | + ------- |
| 37 | + Dict[str, float] |
| 38 | + Dictionary mapping variable names to their scaling values. If max_variables is specified, |
| 39 | + only the top N variables plus 'total_sum' will be included. |
29 | 40 | """ |
30 | 41 | variable_scaling = loss.scaler.subset_by_dim(TensorDim.VARIABLE.value).get_scaler(len(TensorDim)).reshape(-1) |
31 | 42 | log_text = "Final Variable Scaling: " |
| 43 | + scaling_values, scaling_sum = {}, 0.0 |
| 44 | + |
32 | 45 | for idx, name in enumerate(data_indices.model.output.name_to_index.keys()): |
33 | | - log_text += f"{name}: {variable_scaling[idx]:.4g}, " |
| 46 | + value = float(variable_scaling[idx]) |
| 47 | + log_text += f"{name}: {value:.4g}, " |
| 48 | + scaling_values[name] = value |
| 49 | + scaling_sum += value |
| 50 | + |
| 51 | + log_text += f"Total scaling sum: {scaling_sum:.4g}, " |
| 52 | + scaling_values["total_sum"] = scaling_sum |
34 | 53 | LOGGER.debug(log_text) |
| 54 | + |
| 55 | + return scaling_values |
0 commit comments