Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit adc0363

Browse files
committed
refactor ensemble perturbation
1 parent 7b108e9 commit adc0363

File tree

3 files changed

+50
-67
lines changed

3 files changed

+50
-67
lines changed

ecml_tools/create/functions/ensemble_perturbations.py

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -42,103 +42,86 @@ def normalise_number(number):
4242

4343

4444
def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}):
45-
n_ensembles = len(normalise_number(ensembles["number"]))
45+
number_list = normalise_number(ensembles["number"])
46+
n_numbers = len(number_list)
47+
48+
keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]
4649

4750
print(f"Retrieving ensemble data with {ensembles}")
48-
ensembles = load_source(**ensembles)
51+
ensembles = load_source(**ensembles).order_by(*keys)
4952
print(f"Retrieving center data with {center}")
50-
center = load_source(**center)
53+
center = load_source(**center).order_by(*keys)
5154
print(f"Retrieving mean data with {mean}")
52-
mean = load_source(**mean)
55+
mean = load_source(**mean).order_by(*keys)
5356

54-
assert len(mean) * n_ensembles == len(ensembles), (
57+
assert len(mean) * n_numbers == len(ensembles), (
5558
len(mean),
56-
n_ensembles,
59+
n_numbers,
5760
len(ensembles),
5861
)
59-
assert len(center) * n_ensembles == len(ensembles), (
62+
assert len(center) * n_numbers == len(ensembles), (
6063
len(center),
61-
n_ensembles,
64+
n_numbers,
6265
len(ensembles),
6366
)
6467

68+
# prepare output tmp file so we can read it back
6569
tmp = temp_file()
6670
path = tmp.path
6771
out = new_grib_output(path)
6872

69-
keys = ["param", "level", "valid_datetime", "number", "date", "time", "step"]
70-
71-
ensembles_coords = ensembles.unique_values(*keys)
72-
center_coords = center.unique_values(*keys)
73-
mean_coords = mean.unique_values(*keys)
74-
75-
for k in keys:
76-
if k == "number":
77-
assert len(mean_coords[k]) == 1
78-
assert len(center_coords[k]) == 1
79-
assert len(ensembles_coords[k]) == n_ensembles
80-
continue
81-
assert set(center_coords[k]) == set(ensembles_coords[k]), (
82-
k,
83-
center_coords[k],
84-
ensembles_coords[k],
85-
)
86-
assert set(center_coords[k]) == set(mean_coords[k]), (
87-
k,
88-
center_coords[k],
89-
mean_coords[k],
90-
)
91-
92-
for field in tqdm.tqdm(center):
73+
for i, field in tqdm.tqdm(enumerate(ensembles)):
9374
param = field.metadata("param")
94-
grid = field.metadata("grid")
95-
96-
selection = dict(
97-
valid_datetime=field.metadata("valid_datetime"),
98-
param=field.metadata("param"),
99-
level=field.metadata("level"),
100-
date=field.metadata("date"),
101-
time=field.metadata("time"),
102-
step=field.metadata("step"),
103-
)
104-
mean_field = get_unique_field(mean, selection)
105-
assert mean_field.metadata("grid") == grid, (mean_field.metadata("grid"), grid)
75+
number = field.metadata("number")
76+
ii = i // n_numbers
77+
78+
i_number = number_list.index(number)
79+
assert i == ii * n_numbers + i_number, (i, ii, n_numbers, i_number, number_list)
80+
81+
center_field = center[ii]
82+
mean_field = mean[ii]
83+
84+
for k in keys + ["grid", "shape"]:
85+
if k == "number":
86+
continue
87+
assert center_field.metadata(k) == field.metadata(k), (
88+
k,
89+
center_field.metadata(k),
90+
field.metadata(k),
91+
)
92+
assert mean_field.metadata(k) == field.metadata(k), (
93+
k,
94+
mean_field.metadata(k),
95+
field.metadata(k),
96+
)
10697

98+
e = field.to_numpy()
10799
m = mean_field.to_numpy()
108-
c = field.to_numpy()
100+
c = center_field.to_numpy()
109101
assert m.shape == c.shape, (m.shape, c.shape)
110102

111-
for number in ensembles_coords["number"]:
112-
ensembles_field = get_unique_field(ensembles.sel(number=number), selection)
113-
assert ensembles_field.metadata("grid") == grid, (
114-
ensembles_field.metadata("grid"),
115-
grid,
116-
)
117-
118-
e = ensembles_field.to_numpy()
119-
assert c.shape == e.shape, (c.shape, e.shape)
120-
121-
x = c + m - e
122-
if param == "q":
123-
warnings.warn("Clipping q")
124-
x = np.maximum(x, 0)
103+
#################################
104+
# Actual computation happens here
105+
x = c + m - e
106+
if param == "q":
107+
warnings.warn("Clipping q")
108+
x = np.maximum(x, 0)
109+
#################################
125110

126-
assert x.shape == c.shape, (x.shape, c.shape)
111+
assert x.shape == e.shape, (x.shape, e.shape)
127112

128-
check_data_values(x, name=param)
129-
out.write(x, template=ensembles_field)
113+
check_data_values(x, name=param)
114+
out.write(x, template=field)
130115

131116
out.close()
132117

133118
ds = load_source("file", path)
134-
assert len(ds) == len(ensembles), (len(ds), len(ensembles))
119+
# save a reference to the tmp file so it is deleted
120+
# only when the dataset is not used anymore
135121
ds._tmp = tmp
136122

137-
assert len(mean) * n_ensembles == len(ensembles)
138-
assert len(center) * n_ensembles == len(ensembles)
123+
assert len(ds) == len(ensembles), (len(ds), len(ensembles))
139124

140-
final_coords = ds.unique_values(*keys)
141-
assert len(final_coords["number"]) == n_ensembles, final_coords
142125
return ds
143126

144127

File renamed without changes.

0 commit comments

Comments
 (0)