Skip to content

Commit 19422b2

Browse files
bdvllrspre-commit-ci[bot]gmertes
authored
feat: Resolve config with omegaconf (ecmwf#252)
## Description Uses omegaconf to resolve the configuration. This allows to use interpolations in the configuration, similar to the anemoi-core packages. ## What problem does this change solve? <!-- Describe if it's a bugfix, new feature, doc update, or breaking change --> ## What issue or task does this change relate to? Fixes ecmwf#243. ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** <!-- readthedocs-preview anemoi-inference start --> ---- 📚 Documentation preview 📚: https://anemoi-inference--252.org.readthedocs.build/en/252/ <!-- readthedocs-preview anemoi-inference end --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Gert Mertes <13658335+gmertes@users.noreply.github.com>
1 parent eb7afd1 commit 19422b2

File tree

12 files changed

+201
-72
lines changed

12 files changed

+201
-72
lines changed

docs/inference/apis/level3.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ You can also override values by providing them on the command line:
2626
.. literalinclude:: code/level3_2.sh
2727
:language: bash
2828

29+
Overrides are parsed as an `OmegaConf
30+
<https://omegaconf.readthedocs.io/en/2.2_branch/usage.html#from-a-dot-list>`_
31+
dotlist, so list items can be accessed with ``list.index`` or
32+
``list[index]``.
33+
2934
You can also run entirely from the command line without a config file,
3035
by passing all required options as an override:
3136

@@ -39,9 +44,9 @@ that was used to train the model, by setting ``dataset`` entry to
3944
.. literalinclude:: code/level3_2.yaml
4045
:language: yaml
4146

42-
It is also possible to override list entries and append to lists on the
43-
command line by using the list indices as key. Running inference with
44-
following command:
47+
Below is an example of how to override list entries and append to lists
48+
on the command line by using the dotlist notation. Running inference
49+
with following command:
4550

4651
.. literalinclude:: code/level3_4.sh
4752
:language: bash

docs/inference/configs/introduction.rst

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ This document provides an overview of the configuration to provide to
88
the :ref:`anemoi-inference run <run_command>` command line tool.
99

1010
The configuration file is a YAML file that specifies various options. It
11-
is composed of :ref:`top level <top-level>` options which are usually
12-
simple values such as strings, number or booleans. The configuration
13-
also provide ways to specify which internal classes to use for the
14-
:ref:`inputs <inputs>` and :ref:`outputs <outputs>`, and how to
11+
is extended by `OmegaConf <https://github.com/omry/omegaconf>`_ such
12+
that `interpolations
13+
<https://omegaconf.readthedocs.io/en/2.2_branch/usage.html#variable-interpolation>`_
14+
can be used. It is composed of :ref:`top level <top-level>` options
15+
which are usually simple values such as strings, number or booleans. The
16+
configuration also provide ways to specify which internal classes to use
17+
for the :ref:`inputs <inputs>` and :ref:`outputs <outputs>`, and how to
1518
configure them.
1619

1720
In that case, the general format is shown below. The first entry

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dependencies = [
5151
"earthkit-data>=0.12.4",
5252
"eccodes>=2.38.3",
5353
"numpy",
54+
"omegaconf>=2.2,<2.4",
5455
"packaging",
5556
"pydantic",
5657
"pyyaml",

src/anemoi/inference/config/__init__.py

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import logging
1313
import os
14-
from copy import deepcopy
14+
from datetime import datetime
1515
from typing import Any
1616
from typing import Dict
1717
from typing import List
@@ -20,9 +20,13 @@
2020
from typing import TypeVar
2121
from typing import Union
2222

23-
import yaml
23+
from earthkit.data.utils.dates import to_datetime
24+
from omegaconf import DictConfig
25+
from omegaconf import ListConfig
26+
from omegaconf import OmegaConf
2427
from pydantic import BaseModel
2528
from pydantic import ConfigDict
29+
from pydantic import field_validator
2630

2731
LOG = logging.getLogger(__name__)
2832

@@ -34,6 +38,15 @@ class Configuration(BaseModel):
3438

3539
model_config = ConfigDict(extra="forbid")
3640

41+
date: Union[datetime, None] = None
42+
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`earthkit.data.utils.dates`."""
43+
44+
@field_validator("date", mode="before")
45+
@classmethod
46+
def to_datetime(cls, date: Union[str, int, datetime, None]) -> Optional[datetime]:
47+
if date is not None:
48+
return to_datetime(date)
49+
3750
@classmethod
3851
def load(
3952
cls: Type[T],
@@ -58,56 +71,47 @@ def load(
5871
The loaded configuration.
5972
"""
6073

61-
config = {}
74+
configs: List[Union[DictConfig, ListConfig]] = []
6275

6376
# Set default values
6477
if defaults is not None:
6578
if not isinstance(defaults, list):
6679
defaults = [defaults]
6780
for d in defaults:
6881
if isinstance(d, str):
69-
with open(d) as f:
70-
d = yaml.safe_load(f)
71-
config.update(d)
82+
configs.append(OmegaConf.load(d))
83+
continue
84+
configs.append(OmegaConf.create(d))
7285

7386
# Load the user configuration
7487
if isinstance(path, dict):
75-
user_config = deepcopy(path)
88+
configs.append(OmegaConf.create(path))
7689
else:
77-
with open(path) as f:
78-
user_config = yaml.safe_load(f)
79-
80-
cls._merge_configs(config, user_config)
90+
configs.append(OmegaConf.load(path))
8191

82-
# Apply overrides
8392
if not isinstance(overrides, list):
8493
overrides = [overrides]
8594

95+
# unsafe merge should be fine as we don't re-use the original configs
96+
oc_config = OmegaConf.unsafe_merge(*configs)
97+
8698
for override in overrides:
8799
if isinstance(override, dict):
88-
cls._merge_configs(config, override)
100+
oc_config = OmegaConf.unsafe_merge(oc_config, OmegaConf.create(override))
89101
else:
90-
path = config
91-
key, value = override.split("=")
92-
keys = key.split(".")
93-
for key in keys[:-1]:
94-
if key.isdigit() and isinstance(path, list):
95-
index = int(key)
96-
if index < len(path):
97-
LOG.debug(f"key {key} is used as list index in list{path}")
98-
path = path[index]
99-
elif index == len(path):
100-
LOG.debug(f"key {key} is used to append to list {path}")
101-
path.append({})
102-
path = path[index]
103-
else:
104-
raise IndexError(f"Index {index} out of range for list {path} of length {len(path)}")
105-
else:
106-
path = path.setdefault(key, {})
107-
path[keys[-1]] = value
108-
109-
# Validate the configuration
110-
config = cls(**config)
102+
# use from_dotlist to use OmegaConf split
103+
# which allows for "param.val" or "param[val]".
104+
override_conf = OmegaConf.from_dotlist([override])
105+
# We can't directly merge reconstructed with the config because
106+
# omegaconf prefers parsing digits (like 0 in key.0) into dict keys
107+
# rather than lists.
108+
# Instead, we provide a reference config and we try to merge the override
109+
# into the reference and keep types provided by the reference.
110+
oc_config = OmegaConf.unsafe_merge(_merge_configs(oc_config, override_conf))
111+
112+
resolved_config = OmegaConf.to_container(oc_config, resolve=True)
113+
114+
config = cls.model_validate(resolved_config)
111115

112116
# Set environment variables found in the configuration
113117
# as soon as possible
@@ -116,19 +120,47 @@ def load(
116120

117121
return config
118122

119-
@classmethod
120-
def _merge_configs(cls, a: Dict[Any, Any], b: Dict[Any, Any]) -> None:
121-
"""Merge two configurations.
122123

123-
Parameters
124-
----------
125-
a : Dict[Any, Any]
126-
The first configuration.
127-
b : Dict[Any, Any]
128-
The second configuration.
129-
"""
130-
for key, value in b.items():
131-
if key in a and isinstance(a[key], dict) and isinstance(value, dict):
132-
cls._merge_configs(a[key], value)
133-
else:
134-
a[key] = value
124+
def _merge_configs(ref_conf: Any, new_conf: Any) -> Any:
125+
"""Recursively merges a new OmegaConf object into a reference OmegaConf object
126+
127+
Parameters
128+
----------
129+
ref_conf : Any
130+
reference OmegaConf object. Should be a DictConfig or ListConfig.
131+
new_conf : Any
132+
new OmegaConf object.
133+
134+
Returns
135+
-------
136+
Any
137+
The merged OmegaConf config
138+
"""
139+
if isinstance(new_conf, DictConfig) and len(new_conf):
140+
key, rest = next(iter(new_conf.items()))
141+
key = str(key)
142+
elif isinstance(new_conf, ListConfig) and len(new_conf):
143+
key, rest = 0, new_conf[0]
144+
else:
145+
return new_conf
146+
if isinstance(ref_conf, ListConfig):
147+
if isinstance(key, str) and not key.isdigit():
148+
raise ValueError(f"Expected int key, got {key}")
149+
index = int(key)
150+
if index < len(ref_conf):
151+
LOG.debug(f"key {key} is used as list key in list{ref_conf}")
152+
ref_conf[index] = _merge_configs(ref_conf[index], rest)
153+
elif index == len(ref_conf):
154+
LOG.debug(f"key {key} is used to append to list {ref_conf}")
155+
ref_conf.append(rest)
156+
else:
157+
raise IndexError(f"key {key} out of range for list {ref_conf} of length {len(ref_conf)}")
158+
return ref_conf
159+
elif isinstance(ref_conf, DictConfig) and key in ref_conf:
160+
ref_conf[key] = _merge_configs(ref_conf[key], rest)
161+
return ref_conf
162+
elif isinstance(ref_conf, DictConfig):
163+
ref_conf[key] = rest
164+
return ref_conf
165+
else:
166+
raise ValueError(f"ref is of unexpected type {type(ref_conf)}. Should be ListConfig or DictConfig")

src/anemoi/inference/config/couple.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from typing import Tuple
1919

2020
from anemoi.inference.config import Configuration
21-
from anemoi.inference.types import Date
2221

2322
LOG = logging.getLogger(__name__)
2423

@@ -28,9 +27,6 @@ class CoupleConfiguration(Configuration):
2827

2928
description: Optional[str] = None
3029

31-
date: Optional[Date] = None
32-
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`."""
33-
3430
lead_time: Optional[Tuple[str, int, datetime.timedelta]] = None
3531
"""The lead time for the forecast. This can be a string, an integer or a timedelta object.
3632
If an integer, it represents a number of hours. Otherwise, it is parsed by :func:`anemoi.utils.dates.as_timedelta`.

src/anemoi/inference/config/run.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@ class RunConfiguration(Configuration):
3434
runner: Union[str, Dict[str, Any]] = "default"
3535
"""The runner to use."""
3636

37-
date: Union[str, int, datetime.datetime, None] = None
38-
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`."""
39-
4037
lead_time: Union[str, int, datetime.timedelta] = "10d"
4138
"""The lead time for the forecast. This can be a string, an integer or a timedelta object.
4239
If an integer, it represents a number of hours. Otherwise, it is parsed by :func:`anemoi.utils.dates.as_timedelta`.

src/anemoi/inference/inputs/cds.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ def create_input_state(self, *, date: Optional[Date]) -> State:
151151
date = to_datetime(-1)
152152
LOG.warning("CDSInput: `date` parameter not provided, using yesterday's date: %s", date)
153153

154-
date = to_datetime(date)
155-
156154
return self._create_input_state(
157155
self.retrieve(
158156
self.variables,

src/anemoi/inference/inputs/dataset.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from typing import Tuple
1919

2020
import numpy as np
21-
from earthkit.data.utils.dates import to_datetime
2221

2322
from anemoi.inference.context import Context
2423
from anemoi.inference.types import Date
@@ -103,7 +102,6 @@ def create_input_state(self, *, date: Optional[Date] = None) -> State:
103102
if date is None:
104103
raise ValueError("`date` must be provided")
105104

106-
date = to_datetime(date)
107105
latitudes = self.ds.latitudes
108106
longitudes = self.ds.longitudes
109107

src/anemoi/inference/inputs/ekd.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ def _create_state(
280280

281281
n_points = fields[0].to_numpy(dtype=dtype, flatten=flatten).size
282282
for field in fields:
283-
284283
name, valid_datetime = field.metadata("name"), field.metadata("valid_datetime")
285284
if name not in state_fields:
286285
state_fields[name] = np.full(
@@ -368,8 +367,6 @@ def _create_input_state(
368367
"%s: `date` not provided, using the most recent date: %s", self.__class__.__name__, date.isoformat()
369368
)
370369

371-
# TODO: where we do this might change in the future
372-
date = to_datetime(date)
373370
dates = [date + h for h in self.checkpoint.lagged]
374371

375372
return self._create_state(

src/anemoi/inference/inputs/mars.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,6 @@ def create_input_state(self, *, date: Optional[Date]) -> State:
247247
date = to_datetime(-1)
248248
LOG.warning("MarsInput: `date` parameter not provided, using yesterday's date: %s", date)
249249

250-
date = to_datetime(date)
251-
252250
return self._create_input_state(
253251
self.retrieve(
254252
self.variables,

0 commit comments

Comments
 (0)