Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 183e059

Browse files
committed
ensemble perturbations and clean up
1 parent d0bbee9 commit 183e059

File tree

7 files changed

+390
-72
lines changed

7 files changed

+390
-72
lines changed

ecml_tools/create/functions/__init__.py

Whitespace-only changes.
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# (C) Copyright 2020 ECMWF.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
# In applying this licence, ECMWF does not waive the privileges and immunities
6+
# granted to it by virtue of its status as an intergovernmental organisation
7+
# nor does it submit to any jurisdiction.
8+
#
9+
10+
import warnings
11+
12+
import numpy as np
13+
import tqdm
14+
from climetlab import load_source
15+
from climetlab.core.temporary import temp_file
16+
from climetlab.readers.grib.output import new_grib_output
17+
18+
from ecml_tools.create.check import check_data_values
19+
20+
21+
def get_unique_field(ds, selection):
    """Return the one field of ``ds`` that matches ``selection``.

    ``selection`` is a mapping that is expanded into keyword arguments for
    ``ds.sel``. The narrowed dataset must contain exactly one field;
    otherwise an ``AssertionError`` carrying the narrowed dataset and the
    selection is raised.
    """
    matches = ds.sel(**selection)
    assert len(matches) == 1, (matches, selection)
    return matches[0]
25+
26+
27+
def normalise_number(number):
    """Expand a MARS-style ``number`` specification.

    * A tuple, list or int is returned unchanged (an int is *not* wrapped
      in a list).
    * A string is split on ``"/"``:

      - ``"a/to/b/by/s"`` becomes ``list(range(a, b + 1, s))``;
      - ``"a/to/b"`` becomes ``list(range(a, b + 1))``;
      - anything else is returned as the list of split strings (the
        items are not converted to integers).

    Any other type triggers an ``AssertionError``.
    """
    if not isinstance(number, (tuple, list, int)):
        assert isinstance(number, str), (type(number), number)

        parts = number.split("/")
        is_range = len(parts) > 2 and parts[1] == "to"

        if is_range and len(parts) > 4 and parts[3] == "by":
            first, last, stride = int(parts[0]), int(parts[2]), int(parts[4])
            return list(range(first, last + 1, stride))

        if is_range:
            first, last = int(parts[0]), int(parts[2])
            return list(range(first, last + 1))

        return parts

    return number
42+
43+
44+
def ensembles_perturbations(ensembles, center, mean, remapping=None, patches=None):
    """Re-centre ensemble members on a new centre and return them as a source.

    For every field of ``center`` the matching ``mean`` field and each
    ensemble member are looked up, and the member is rewritten as
    ``center - mean + member``: its perturbation around the previous centre
    (``mean``) is kept and applied to the new centre (``center``). The
    results are written to a temporary GRIB file which is then loaded and
    returned.

    Parameters
    ----------
    ensembles, center, mean : dict
        Keyword arguments forwarded to ``climetlab.load_source`` to retrieve
        the ensemble members, the new centre and the previous centre.
        ``ensembles`` must contain a ``"number"`` entry (list, tuple, int or
        a ``"a/to/b[/by/s]"`` string).
    remapping, patches : optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    The source loaded from the temporary GRIB file. The temporary-file
    handle is attached to it as ``_tmp`` (presumably so the file outlives
    this call for as long as the returned source is in use).

    Raises
    ------
    AssertionError
        If the three inputs do not cover the same parameters, levels, dates,
        times and steps, if grids or array shapes differ, or if the number
        of members is not the expected one.
    """
    numbers = normalise_number(ensembles["number"])
    if isinstance(numbers, int):
        # normalise_number() hands a bare int back unchanged; len() would
        # fail on it, so treat it as a single-member ensemble.
        numbers = [numbers]
    n_ensembles = len(numbers)

    print(f"Retrieving ensemble data with {ensembles}")
    ensembles = load_source(**ensembles)
    print(f"Retrieving center data with {center}")
    center = load_source(**center)
    print(f"Retrieving mean data with {mean}")
    mean = load_source(**mean)

    # There must be exactly one mean field and one center field for each
    # group of n_ensembles member fields.
    assert len(mean) * n_ensembles == len(ensembles), (
        len(mean),
        n_ensembles,
        len(ensembles),
    )
    assert len(center) * n_ensembles == len(ensembles), (
        len(center),
        n_ensembles,
        len(ensembles),
    )

    keys = ["param", "level", "valid_datetime", "number", "date", "time", "step"]

    ensembles_coords = ensembles.unique_values(*keys)
    center_coords = center.unique_values(*keys)
    mean_coords = mean.unique_values(*keys)

    for k in keys:
        if k == "number":
            # Only the ensemble data carries more than one member number.
            assert len(mean_coords[k]) == 1
            assert len(center_coords[k]) == 1
            assert len(ensembles_coords[k]) == n_ensembles
            continue
        assert set(center_coords[k]) == set(ensembles_coords[k]), (
            k,
            center_coords[k],
            ensembles_coords[k],
        )
        assert set(center_coords[k]) == set(mean_coords[k]), (
            k,
            center_coords[k],
            mean_coords[k],
        )

    tmp = temp_file()
    path = tmp.path
    out = new_grib_output(path)
    try:
        for field in tqdm.tqdm(center):
            param = field.metadata("param")
            grid = field.metadata("grid")

            selection = dict(
                valid_datetime=field.metadata("valid_datetime"),
                param=param,
                level=field.metadata("level"),
                date=field.metadata("date"),
                time=field.metadata("time"),
                step=field.metadata("step"),
            )
            mean_field = get_unique_field(mean, selection)
            assert mean_field.metadata("grid") == grid, (
                mean_field.metadata("grid"),
                grid,
            )

            m = mean_field.to_numpy()
            c = field.to_numpy()
            assert m.shape == c.shape, (m.shape, c.shape)

            for number in ensembles_coords["number"]:
                ensembles_field = get_unique_field(
                    ensembles.sel(number=number), selection
                )
                assert ensembles_field.metadata("grid") == grid, (
                    ensembles_field.metadata("grid"),
                    grid,
                )

                e = ensembles_field.to_numpy()
                assert c.shape == e.shape, (c.shape, e.shape)

                # Keep the member's perturbation around the previous centre
                # (e - m) and apply it to the new centre, so that members
                # are left untouched when center == mean.
                x = c - m + e
                if param == "q":
                    # Re-centring can push q below zero; clip it back.
                    warnings.warn("Clipping q")
                    x = np.maximum(x, 0)

                assert x.shape == c.shape, (x.shape, c.shape)

                check_data_values(x, name=param)
                out.write(x, template=ensembles_field)
    finally:
        # Always release the GRIB output, even when a check above fails.
        out.close()

    ds = load_source("file", path)
    assert len(ds) == len(ensembles), (len(ds), len(ensembles))
    ds._tmp = tmp

    final_coords = ds.unique_values(*keys)
    assert len(final_coords["number"]) == n_ensembles, final_coords
    return ds
143+
144+
145+
execute = ensembles_perturbations
146+
147+
if __name__ == "__main__":
148+
import yaml
149+
150+
config = yaml.safe_load(
151+
"""
152+
153+
common: &common
154+
name: mars
155+
# marser is the MARS containing ERA5 reanalysis dataset, avoid hitting the FDB server for nothing
156+
database: marser
157+
class: ea
158+
# date: $datetime_format($dates,%Y%m%d)
159+
# time: $datetime_format($dates,%H%M)
160+
date: 20221230/to/20230103
161+
time: '0000/1200'
162+
expver: '0001'
163+
grid: 20.0/20.0
164+
levtype: sfc
165+
param: [2t]
166+
# levtype: pl
167+
# param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z]
168+
169+
config:
170+
ensembles: # the ensemble data has one additional dimension
171+
<<: *common
172+
stream: enda
173+
type: an
174+
number: [0, 1]
175+
# number: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
176+
177+
center: # the new center of the data
178+
<<: *common
179+
stream: oper
180+
type: an
181+
182+
mean: # the previous center of the data
183+
<<: *common
184+
stream: enda
185+
type: em
186+
187+
"""
188+
)["config"]
189+
for k, v in config.items():
190+
print(k, v)
191+
192+
for f in ensembles_perturbations(**config):
193+
print(f, f.to_numpy().mean())

ecml_tools/create/input.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
# nor does it submit to any jurisdiction.
88
#
99
import datetime
10+
import importlib
1011
import logging
12+
import os
1113
import time
1214
from collections import defaultdict
1315
from copy import deepcopy
@@ -334,9 +336,10 @@ def __getitem__(self, key):
334336
raise KeyError(key)
335337

336338

337-
class SourceResult(Result):
339+
class FunctionResult(Result):
338340
def __init__(self, context, dates, action, previous_sibling=None):
339341
super().__init__(context, dates)
342+
assert isinstance(action, Action), type(action)
340343
self.action = action
341344

342345
_args = self.action.args
@@ -349,17 +352,19 @@ def __init__(self, context, dates, action, previous_sibling=None):
349352

350353
@cached_property
351354
def datasource(self):
352-
from climetlab import load_source
353-
354355
print(f"loading source with {self.args} {self.kwargs}")
355-
return load_source(*self.args, **self.kwargs)
356+
return self.action.function(*self.args, **self.kwargs)
356357

357358
def __repr__(self):
358359
content = " ".join([f"{v}" for v in self.args])
359360
content += " ".join([f"{k}={v}" for k, v in self.kwargs.items()])
360361

361362
return super().__repr__(content)
362363

364+
@property
365+
def function(self):
366+
raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
367+
363368

364369
class JoinResult(Result):
365370
def __init__(self, context, dates, results, **kwargs):
@@ -396,7 +401,7 @@ def __repr__(self):
396401
return super().__repr__(_inline_=self.name, _indent_=" ")
397402

398403

399-
class SourceAction(Action):
404+
class BaseFunctionAction(Action):
400405
def __repr__(self):
401406
content = ""
402407
content += ",".join([self._short_str(a) for a in self.args])
@@ -407,7 +412,32 @@ def __repr__(self):
407412
return super().__repr__(_inline_=content, _indent_=" ")
408413

409414
def select(self, dates):
410-
return SourceResult(self.context, dates, action=self)
415+
return FunctionResult(self.context, dates, action=self)
416+
417+
418+
class SourceAction(BaseFunctionAction):
    """Function action whose callable is climetlab's ``load_source``."""

    @property
    def function(self):
        """Return ``climetlab.load_source`` (climetlab is imported lazily)."""
        import climetlab

        return climetlab.load_source
424+
425+
426+
class FunctionAction(BaseFunctionAction):
    """Function action whose callable lives in ``functions/<name>.py``.

    The module is looked up next to this file, in the ``functions``
    directory, and its ``execute`` attribute is used as the callable.
    """

    def __init__(self, context, name, **kwargs):
        """Store ``name`` and forward the remaining arguments to the base class."""
        super().__init__(context, **kwargs)
        self.name = name

    @property
    def function(self):
        """Load ``functions/<name>.py`` and return its ``execute`` callable.

        Raises ``FileNotFoundError`` if the module file does not exist and
        ``AttributeError`` if it does not define ``execute``.
        """
        # A bare ``import importlib`` does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly here.
        import importlib.util

        here = os.path.dirname(__file__)
        path = os.path.join(here, "functions", f"{self.name}.py")

        spec = importlib.util.spec_from_file_location(self.name, path)
        # module_from_spec() + exec_module() is the supported replacement for
        # the deprecated Loader.load_module(); it also avoids registering the
        # module in sys.modules under the bare function name.
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        return module.execute
411441

412442

413443
class ConcatResult(Result):
@@ -610,6 +640,7 @@ def action_factory(config, context):
610640
label=LabelAction,
611641
pipe=PipeAction,
612642
source=SourceAction,
643+
function=FunctionAction,
613644
dates=DateAction,
614645
)[key]
615646

@@ -688,5 +719,8 @@ def select(self, dates):
688719
"""This changes the context."""
689720
return self._action.select(dates)
690721

722+
def __repr__(self):
723+
return repr(self._action)
724+
691725

692726
build_input = InputBuilder

ecml_tools/create/statistics.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from collections import defaultdict
1717

1818
import numpy as np
19-
from prepml.utils.text import table
2019

2120
from ecml_tools.provenance import gather_provenance_info
2221

@@ -274,19 +273,12 @@ def check(self):
274273
raise
275274

276275
def __str__(self):
277-
stats = [self[name] for name in self.STATS_NAMES]
278-
279-
rows = []
276+
header = ["Variables"] + [self[name] for name in self.STATS_NAMES]
277+
out = " ".join(header)
280278

281279
for i, v in enumerate(self["variables_names"]):
282-
rows.append([i, v] + [x[i] for x in stats])
283-
284-
return table(
285-
rows,
286-
header=["Index", "Variable", "Min", "Max", "Mean", "Stdev"],
287-
align=[">", "<", ">", ">", ">", ">"],
288-
margin=3,
289-
)
280+
out += " ".join([v] + [f"{x[i]:.2f}" for x in self.values()])
281+
return out
290282

291283
def save(self, filename, provenance=None):
292284
assert filename.endswith(".json"), filename

0 commit comments

Comments
 (0)