Skip to content

Commit bfd2b94

Browse files
authored
Merge pull request #57 from ecmwf-lab/feature/opendata
Feature/opendata
2 parents 78c7a74 + 889598d commit bfd2b94

File tree

17 files changed

+832
-242
lines changed

17 files changed

+832
-242
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,4 @@ bar
177177
dev/
178178
*.out
179179
_version.py
180+
*.tar

pyproject.toml

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,26 @@ classifiers = [
4040
]
4141

4242
dependencies = [
43+
"cdsapi",
44+
"earthkit-data>=0.10.3",
45+
"earthkit-meteo",
46+
"earthkit-regrid",
47+
"eccodes>=2.37",
48+
"ecmwf-api-client",
49+
"ecmwf-opendata",
4350
"entrypoints",
44-
"requests",
45-
"climetlab>=0.23.0",
46-
"multiurl",
47-
"ecmwflibs>=0.6.1",
4851
"gputil",
49-
"earthkit-meteo",
52+
"multiurl",
5053
"pyyaml",
54+
"requests",
5155
"tqdm",
5256
]
5357

5458

5559
[project.urls]
56-
Homepage = "https://github.com/ecmwf/ai-models/"
57-
Repository = "https://github.com/ecmwf/ai-models/"
58-
Issues = "https://github.com/ecmwf/ai-models/issues"
60+
Homepage = "https://github.com/ecmwf-lab/ai-models/"
61+
Repository = "https://github.com/ecmwf-lab/ai-models/"
62+
Issues = "https://github.com/ecmwf-lab/ai-models/issues"
5963

6064
[project.scripts]
6165
ai-models = "ai_models.__main__:main"
@@ -64,10 +68,11 @@ ai-models = "ai_models.__main__:main"
6468
version_file = "src/ai_models/_version.py"
6569

6670
[project.entry-points."ai_models.input"]
67-
file = "ai_models.inputs:FileInput"
68-
mars = "ai_models.inputs:MarsInput"
69-
cds = "ai_models.inputs:CdsInput"
70-
opendata = "ai_models.inputs:OpenDataInput"
71+
file = "ai_models.inputs.file:FileInput"
72+
mars = "ai_models.inputs.mars:MarsInput"
73+
cds = "ai_models.inputs.cds:CdsInput"
74+
ecmwf-open-data = "ai_models.inputs.opendata:OpenDataInput"
75+
opendata = "ai_models.inputs.opendata:OpenDataInput"
7176

7277
[project.entry-points."ai_models.output"]
7378
file = "ai_models.outputs:FileOutput"

src/ai_models/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@
1111
import shlex
1212
import sys
1313

14+
import earthkit.data as ekd
15+
1416
from .inputs import available_inputs
1517
from .model import Timer
1618
from .model import available_models
1719
from .model import load_model
1820
from .outputs import available_outputs
1921

22+
ekd.settings.set("cache-policy", "user")
23+
2024
LOG = logging.getLogger(__name__)
2125

2226

src/ai_models/inputs/__init__.py

Lines changed: 3 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -8,198 +8,14 @@
88
import logging
99
from functools import cached_property
1010

11-
import climetlab as cml
11+
import earthkit.data as ekd
12+
import earthkit.regrid as ekr
1213
import entrypoints
14+
from earthkit.data.indexing.fieldlist import FieldArray
1315

1416
LOG = logging.getLogger(__name__)
1517

1618

17-
class RequestBasedInput:
18-
def __init__(self, owner, **kwargs):
19-
self.owner = owner
20-
21-
def _patch(self, **kargs):
22-
r = dict(**kargs)
23-
self.owner.patch_retrieve_request(r)
24-
return r
25-
26-
@cached_property
27-
def fields_sfc(self):
28-
param = self.owner.param_sfc
29-
if not param:
30-
return cml.load_source("empty")
31-
32-
LOG.info(f"Loading surface fields from {self.WHERE}")
33-
return cml.load_source(
34-
"multi",
35-
[
36-
self.sfc_load_source(
37-
**self._patch(
38-
date=date,
39-
time=time,
40-
param=param,
41-
grid=self.owner.grid,
42-
area=self.owner.area,
43-
**self.owner.retrieve,
44-
)
45-
)
46-
for date, time in self.owner.datetimes()
47-
],
48-
)
49-
50-
@cached_property
51-
def fields_pl(self):
52-
param, level = self.owner.param_level_pl
53-
if not (param and level):
54-
return cml.load_source("empty")
55-
56-
LOG.info(f"Loading pressure fields from {self.WHERE}")
57-
return cml.load_source(
58-
"multi",
59-
[
60-
self.pl_load_source(
61-
**self._patch(
62-
date=date,
63-
time=time,
64-
param=param,
65-
level=level,
66-
grid=self.owner.grid,
67-
area=self.owner.area,
68-
)
69-
)
70-
for date, time in self.owner.datetimes()
71-
],
72-
)
73-
74-
@cached_property
75-
def fields_ml(self):
76-
param, level = self.owner.param_level_ml
77-
if not (param and level):
78-
return cml.load_source("empty")
79-
80-
LOG.info(f"Loading model fields from {self.WHERE}")
81-
return cml.load_source(
82-
"multi",
83-
[
84-
self.ml_load_source(
85-
**self._patch(
86-
date=date,
87-
time=time,
88-
param=param,
89-
level=level,
90-
grid=self.owner.grid,
91-
area=self.owner.area,
92-
)
93-
)
94-
for date, time in self.owner.datetimes()
95-
],
96-
)
97-
98-
@cached_property
99-
def all_fields(self):
100-
return self.fields_sfc + self.fields_pl + self.fields_ml
101-
102-
103-
class MarsInput(RequestBasedInput):
104-
WHERE = "MARS"
105-
106-
def __init__(self, owner, **kwargs):
107-
self.owner = owner
108-
109-
def pl_load_source(self, **kwargs):
110-
kwargs["levtype"] = "pl"
111-
logging.debug("load source mars %s", kwargs)
112-
return cml.load_source("mars", kwargs)
113-
114-
def sfc_load_source(self, **kwargs):
115-
kwargs["levtype"] = "sfc"
116-
logging.debug("load source mars %s", kwargs)
117-
return cml.load_source("mars", kwargs)
118-
119-
def ml_load_source(self, **kwargs):
120-
kwargs["levtype"] = "ml"
121-
logging.debug("load source mars %s", kwargs)
122-
return cml.load_source("mars", kwargs)
123-
124-
125-
class CdsInput(RequestBasedInput):
126-
WHERE = "CDS"
127-
128-
def pl_load_source(self, **kwargs):
129-
kwargs["product_type"] = "reanalysis"
130-
return cml.load_source("cds", "reanalysis-era5-pressure-levels", kwargs)
131-
132-
def sfc_load_source(self, **kwargs):
133-
kwargs["product_type"] = "reanalysis"
134-
return cml.load_source("cds", "reanalysis-era5-single-levels", kwargs)
135-
136-
def ml_load_source(self, **kwargs):
137-
raise NotImplementedError("CDS does not support model levels")
138-
139-
140-
class OpenDataInput(RequestBasedInput):
141-
WHERE = "OPENDATA"
142-
143-
RESOLS = {(0.25, 0.25): "0p25"}
144-
145-
def __init__(self, owner, **kwargs):
146-
self.owner = owner
147-
148-
def _adjust(self, kwargs):
149-
if "level" in kwargs:
150-
# OpenData uses levelist instead of level
151-
kwargs["levelist"] = kwargs.pop("level")
152-
153-
grid = kwargs.pop("grid")
154-
if isinstance(grid, list):
155-
grid = tuple(grid)
156-
157-
kwargs["resol"] = self.RESOLS[grid]
158-
r = dict(**kwargs)
159-
r.update(self.owner.retrieve)
160-
return r
161-
162-
def pl_load_source(self, **kwargs):
163-
self._adjust(kwargs)
164-
kwargs["levtype"] = "pl"
165-
logging.debug("load source ecmwf-open-data %s", kwargs)
166-
return cml.load_source("ecmwf-open-data", **kwargs)
167-
168-
def sfc_load_source(self, **kwargs):
169-
self._adjust(kwargs)
170-
kwargs["levtype"] = "sfc"
171-
logging.debug("load source ecmwf-open-data %s", kwargs)
172-
return cml.load_source("ecmwf-open-data", **kwargs)
173-
174-
def ml_load_source(self, **kwargs):
175-
self._adjust(kwargs)
176-
kwargs["levtype"] = "ml"
177-
logging.debug("load source ecmwf-open-data %s", kwargs)
178-
return cml.load_source("ecmwf-open-data", **kwargs)
179-
180-
181-
class FileInput:
182-
def __init__(self, owner, file, **kwargs):
183-
self.file = file
184-
self.owner = owner
185-
186-
@cached_property
187-
def fields_sfc(self):
188-
return self.all_fields.sel(levtype="sfc")
189-
190-
@cached_property
191-
def fields_pl(self):
192-
return self.all_fields.sel(levtype="pl")
193-
194-
@cached_property
195-
def fields_ml(self):
196-
return self.all_fields.sel(levtype="ml")
197-
198-
@cached_property
199-
def all_fields(self):
200-
return cml.load_source("file", self.file)
201-
202-
20319
def get_input(name, *args, **kwargs):
20420
return available_inputs()[name].load()(*args, **kwargs)
20521

src/ai_models/inputs/base.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
2+
# This software is licensed under the terms of the Apache Licence Version 2.0
3+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4+
# In applying this licence, ECMWF does not waive the privileges and immunities
5+
# granted to it by virtue of its status as an intergovernmental organisation
6+
# nor does it submit to any jurisdiction.
7+
8+
import logging
9+
from functools import cached_property
10+
11+
import earthkit.data as ekd
12+
13+
LOG = logging.getLogger(__name__)
14+
15+
16+
class RequestBasedInput:
17+
def __init__(self, owner, **kwargs):
18+
self.owner = owner
19+
20+
def _patch(self, **kargs):
21+
r = dict(**kargs)
22+
self.owner.patch_retrieve_request(r)
23+
return r
24+
25+
@cached_property
26+
def fields_sfc(self):
27+
param = self.owner.param_sfc
28+
if not param:
29+
return ekd.from_source("empty")
30+
31+
LOG.info(f"Loading surface fields from {self.WHERE}")
32+
33+
return ekd.from_source(
34+
"multi",
35+
[
36+
self.sfc_load_source(
37+
**self._patch(
38+
date=date,
39+
time=time,
40+
param=param,
41+
grid=self.owner.grid,
42+
area=self.owner.area,
43+
**self.owner.retrieve,
44+
)
45+
)
46+
for date, time in self.owner.datetimes()
47+
],
48+
)
49+
50+
@cached_property
51+
def fields_pl(self):
52+
param, level = self.owner.param_level_pl
53+
if not (param and level):
54+
return ekd.from_source("empty")
55+
56+
LOG.info(f"Loading pressure fields from {self.WHERE}")
57+
return ekd.from_source(
58+
"multi",
59+
[
60+
self.pl_load_source(
61+
**self._patch(
62+
date=date,
63+
time=time,
64+
param=param,
65+
level=level,
66+
grid=self.owner.grid,
67+
area=self.owner.area,
68+
)
69+
)
70+
for date, time in self.owner.datetimes()
71+
],
72+
)
73+
74+
@cached_property
75+
def fields_ml(self):
76+
param, level = self.owner.param_level_ml
77+
if not (param and level):
78+
return ekd.from_source("empty")
79+
80+
LOG.info(f"Loading model fields from {self.WHERE}")
81+
return ekd.from_source(
82+
"multi",
83+
[
84+
self.ml_load_source(
85+
**self._patch(
86+
date=date,
87+
time=time,
88+
param=param,
89+
level=level,
90+
grid=self.owner.grid,
91+
area=self.owner.area,
92+
)
93+
)
94+
for date, time in self.owner.datetimes()
95+
],
96+
)
97+
98+
@cached_property
99+
def all_fields(self):
100+
return self.fields_sfc + self.fields_pl + self.fields_ml

0 commit comments

Comments
 (0)