Skip to content

Commit cf4b994

Browse files
committed
recenter fields
1 parent e5ee8d3 commit cf4b994

File tree

4 files changed

+233
-100
lines changed

4 files changed

+233
-100
lines changed

src/ai_models/inputs/compute.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,33 @@
77

88
import logging
99

10+
import earthkit.data as ekd
11+
import tqdm
12+
from earthkit.data.core.temporary import temp_file
1013
from earthkit.data.indexing.fieldlist import FieldArray
1114

12-
from .transform import NewDataField
13-
from .transform import NewMetadataField
14-
1515
LOG = logging.getLogger(__name__)
1616

17+
G = 9.80665 # Same a pgen
18+
19+
20+
def make_z_from_gh(ds):
21+
22+
tmp = temp_file()
23+
24+
out = ekd.new_grib_output(tmp.path)
25+
other = []
1726

18-
def make_z_from_gh(previous):
19-
g = 9.80665 # Same a pgen
27+
for f in tqdm.tqdm(ds, delay=0.5, desc="GH to Z", leave=False):
2028

21-
def _proc(ds):
29+
if f.metadata("param") == "gh":
30+
out.write(f.to_numpy() * G, template=f, param="z")
31+
else:
32+
other.append(f)
2233

23-
ds = previous(ds)
34+
out.close()
2435

25-
result = []
26-
for f in ds:
27-
if f.metadata("param") == "gh":
28-
result.append(NewMetadataField(NewDataField(f, f.to_numpy() * g), param="z"))
29-
else:
30-
result.append(f)
31-
return FieldArray(result)
36+
result = FieldArray(other) + ekd.from_source("file", tmp.path)
37+
result._tmp = tmp
3238

33-
return _proc
39+
return result

src/ai_models/inputs/interpolate.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,35 @@
77

88
import logging
99

10+
import earthkit.data as ekd
1011
import earthkit.regrid as ekr
1112
import tqdm
12-
from earthkit.data.indexing.fieldlist import FieldArray
13-
14-
from .transform import NewDataField
13+
from earthkit.data.core.temporary import temp_file
1514

1615
LOG = logging.getLogger(__name__)
1716

1817

1918
class Interpolate:
20-
def __init__(self, grid, source):
19+
def __init__(self, grid, source, metadata):
2120
self.grid = list(grid) if isinstance(grid, tuple) else grid
2221
self.source = list(source) if isinstance(source, tuple) else source
22+
self.metadata = metadata
2323

2424
def __call__(self, ds):
25+
tmp = temp_file()
26+
27+
out = ekd.new_grib_output(tmp.path)
28+
2529
result = []
2630
for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False):
2731
data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid))
28-
result.append(NewDataField(f, data))
32+
out.write(data, template=f, **self.metadata)
33+
34+
out.close()
35+
36+
result = ekd.from_source("file", tmp.path)
37+
result._tmp = tmp
38+
39+
print("Interpolated data", tmp.path)
2940

30-
LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape)
31-
return FieldArray(result)
41+
return result

src/ai_models/inputs/opendata.py

Lines changed: 103 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,23 @@
55
# granted to it by virtue of its status as an intergovernmental organisation
66
# nor does it submit to any jurisdiction.
77

8-
import datetime
98
import itertools
109
import logging
1110
import os
1211

1312
import earthkit.data as ekd
13+
from earthkit.data.core.temporary import temp_file
1414
from earthkit.data.indexing.fieldlist import FieldArray
1515
from multiurl import download
1616

1717
from .base import RequestBasedInput
1818
from .compute import make_z_from_gh
1919
from .interpolate import Interpolate
20+
from .recenter import recenter
2021
from .transform import NewMetadataField
2122

2223
LOG = logging.getLogger(__name__)
2324

24-
2525
CONSTANTS = (
2626
"z",
2727
"sdor",
@@ -30,17 +30,33 @@
3030

3131
CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2"
3232

33+
RESOLS = {
34+
(0.25, 0.25): ("0p25", (0.25, 0.25), False, False, {}),
35+
(0.1, 0.1): (
36+
"0p25",
37+
(0.25, 0.25),
38+
True,
39+
True,
40+
dict(
41+
longitudeOfLastGridPointInDegrees=359.9,
42+
iDirectionIncrementInDegrees=0.1,
43+
jDirectionIncrementInDegrees=0.1,
44+
Ni=3600,
45+
Nj=1801,
46+
),
47+
),
48+
# "N320": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg')),
49+
# "O96": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg', )),
50+
}
51+
52+
53+
def _identity(x):
54+
return x
55+
3356

3457
class OpenDataInput(RequestBasedInput):
3558
WHERE = "OPENDATA"
3659

37-
RESOLS = {
38-
(0.25, 0.25): ("0p25", (0.25, 0.25), False, False),
39-
"N320": ("0p25", (0.25, 0.25), True, False),
40-
"O96": ("0p25", (0.25, 0.25), True, False),
41-
(0.1, 0.1): ("0p25", (0.25, 0.25), True, True),
42-
}
43-
4460
def __init__(self, owner, **kwargs):
4561
self.owner = owner
4662

@@ -56,7 +72,7 @@ def _adjust(self, kwargs):
5672
if isinstance(grid, list):
5773
grid = tuple(grid)
5874

59-
kwargs["resol"], source, interp, oversampling = self.RESOLS[grid]
75+
kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid]
6076
r = dict(**kwargs)
6177
r.update(self.owner.retrieve)
6278

@@ -65,12 +81,15 @@ def _adjust(self, kwargs):
6581
logging.info("Interpolating input data from %s to %s.", source, grid)
6682
if oversampling:
6783
logging.warning("This will oversample the input data.")
68-
return Interpolate(grid, source)
84+
return Interpolate(grid, source, metadata)
6985
else:
70-
return lambda x: x
86+
return _identity
7187

7288
def pl_load_source(self, **kwargs):
73-
pproc = self._adjust(kwargs)
89+
90+
gh_to_z = _identity
91+
interpolate = self._adjust(kwargs)
92+
7493
kwargs["levtype"] = "pl"
7594
request = kwargs.copy()
7695

@@ -84,13 +103,68 @@ def pl_load_source(self, **kwargs):
84103
if "gh" not in param:
85104
param.append("gh")
86105
kwargs["param"] = param
87-
pproc = make_z_from_gh(pproc)
106+
gh_to_z = make_z_from_gh
88107

89108
logging.debug("load source ecmwf-open-data %s", kwargs)
90-
return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request)
109+
110+
opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs))
111+
opendata = gh_to_z(opendata)
112+
opendata = interpolate(opendata)
113+
114+
return self.check_pl(opendata, request)
115+
116+
def constants(self, constant_params, request, kwargs):
117+
if len(constant_params) == 1:
118+
logging.warning(
119+
f"Single level parameter '{constant_params[0]}' is"
120+
" not available in ECMWF open data, using constants.grib2 instead"
121+
)
122+
else:
123+
logging.warning(
124+
f"Single level parameters {constant_params} are"
125+
" not available in ECMWF open data, using constants.grib2 instead"
126+
)
127+
128+
cachedir = os.path.expanduser("~/.cache/ai-models")
129+
constants_url = CONSTANTS_URL.format(resol=request["resol"])
130+
basename = os.path.basename(constants_url)
131+
132+
if not os.path.exists(cachedir):
133+
os.makedirs(cachedir)
134+
135+
path = os.path.join(cachedir, basename)
136+
137+
if not os.path.exists(path):
138+
logging.info("Downloading %s to %s", constants_url, path)
139+
download(constants_url, path + ".tmp")
140+
os.rename(path + ".tmp", path)
141+
142+
ds = ekd.from_source("file", path)
143+
ds = ds.sel(param=constant_params)
144+
145+
tmp = temp_file()
146+
147+
out = ekd.new_grib_output(tmp.path)
148+
149+
for f in ds:
150+
out.write(
151+
f.to_numpy(),
152+
template=f,
153+
date=kwargs["date"],
154+
time=kwargs["time"],
155+
step=kwargs.get("step", 0),
156+
)
157+
158+
out.close()
159+
160+
result = ekd.from_source("file", tmp.path)
161+
result._tmp = tmp
162+
163+
return result
91164

92165
def sfc_load_source(self, **kwargs):
93-
pproc = self._adjust(kwargs)
166+
interpolate = self._adjust(kwargs)
167+
94168
kwargs["levtype"] = "sfc"
95169
request = kwargs.copy()
96170

@@ -104,81 +178,32 @@ def sfc_load_source(self, **kwargs):
104178
param.remove(c)
105179
constant_params.append(c)
106180

107-
constants = ekd.from_source("empty")
108-
109181
if constant_params:
110-
if len(constant_params) == 1:
111-
logging.warning(
112-
f"Single level parameter '{constant_params[0]}' is"
113-
" not available in ECMWF open data, using constants.grib2 instead"
114-
)
115-
else:
116-
logging.warning(
117-
f"Single level parameters {constant_params} are"
118-
" not available in ECMWF open data, using constants.grib2 instead"
119-
)
120-
constants = []
121-
122-
cachedir = os.path.expanduser("~/.cache/ai-models")
123-
constants_url = CONSTANTS_URL.format(resol=request["resol"])
124-
basename = os.path.basename(constants_url)
125-
126-
if not os.path.exists(cachedir):
127-
os.makedirs(cachedir)
128-
129-
path = os.path.join(cachedir, basename)
130-
131-
if not os.path.exists(path):
132-
logging.info("Downloading %s to %s", constants_url, path)
133-
download(constants_url, path + ".tmp")
134-
os.rename(path + ".tmp", path)
135-
136-
ds = ekd.from_source("file", path)
137-
ds = ds.sel(param=constant_params)
138-
139-
date = int(kwargs["date"])
140-
time = int(kwargs["time"])
141-
if time < 100:
142-
time *= 100
143-
step = int(kwargs.get("step", 0))
144-
valid = datetime.datetime(
145-
date // 10000, date // 100 % 100, date % 100, time // 100, time % 100
146-
) + datetime.timedelta(hours=step)
147-
148-
for f in ds:
149-
150-
# assert False, (date, time, step)
151-
constants.append(
152-
NewMetadataField(
153-
f,
154-
valid_datetime=str(valid),
155-
date=date,
156-
time="%4d" % (time,),
157-
step=step,
158-
)
159-
)
160-
161-
constants = FieldArray(constants)
182+
constants = self.constants(constant_params, request, kwargs)
183+
else:
184+
constants = ekd.from_source("empty")
162185

163186
kwargs["param"] = param
164187

165-
logging.debug("load source ecmwf-open-data %s", kwargs)
166-
167-
fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants)
188+
opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs))
189+
opendata = opendata + constants
190+
opendata = interpolate(opendata)
168191

169192
# Fix grib2/eccodes bug
170193

171-
fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields])
194+
opendata = FieldArray([NewMetadataField(f, levelist=None) for f in opendata])
172195

173-
return self.check_sfc(fields, request)
196+
return self.check_sfc(opendata, request)
174197

175198
def ml_load_source(self, **kwargs):
176-
pproc = self._adjust(kwargs)
199+
interpolate = self._adjust(kwargs)
177200
kwargs["levtype"] = "ml"
178201
request = kwargs.copy()
179202

180-
logging.debug("load source ecmwf-open-data %s", kwargs)
181-
return self.check_ml(pproc(ekd.from_source("ecmwf-open-data", kwargs)), request)
203+
opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs))
204+
opendata = interpolate(opendata)
205+
206+
return self.check_ml(opendata, request)
182207

183208
def check_pl(self, ds, request):
184209
self._check(ds, "PL", request, "param", "levelist")

0 commit comments

Comments
 (0)