Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 2f1c664

Browse files
committed
fix and more asserts
1 parent 8be7633 commit 2f1c664

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

ecml_tools/create/functions/ensemble_perturbations.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# nor does it submit to any jurisdiction.
88
#
99

10+
import warnings
1011
import climetlab as cml
1112
import numpy as np
1213
import tqdm
@@ -42,8 +43,11 @@ def normalise_number(number):
4243
def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}):
4344
n_ensembles = len(normalise_number(ensembles["number"]))
4445

46+
print(f"Retrieving ensemble data with {ensembles}")
4547
ensembles = cml.load_source(**ensembles)
48+
print(f"Retrieving center data with {center}")
4649
center = cml.load_source(**center)
50+
print(f"Retrieving mean data with {mean}")
4751
mean = cml.load_source(**mean)
4852

4953
assert len(mean) * n_ensembles == len(ensembles), (
@@ -82,6 +86,8 @@ def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}):
8286

8387
for field in tqdm.tqdm(center):
8488
param = field.metadata("param")
89+
grid = field.metadata("grid")
90+
8591
selection = dict(
8692
valid_datetime=field.metadata("valid_datetime"),
8793
param=field.metadata("param"),
@@ -91,20 +97,25 @@ def ensembles_perturbations(ensembles, center, mean, remapping={}, patches={}):
9197
step=field.metadata("step"),
9298
)
9399
mean_field = get_unique_field(mean, selection)
100+
assert mean_field.metadata("grid") == grid, (mean_field.metadata("grid"), grid)
94101

95102
m = mean_field.to_numpy()
96103
c = field.to_numpy()
97104
assert m.shape == c.shape, (m.shape, c.shape)
98105

99106
for number in ensembles_coords["number"]:
100107
ensembles_field = get_unique_field(ensembles.sel(number=number), selection)
108+
assert ensembles_field.metadata("grid") == grid, (ensembles_field.metadata("grid"), grid)
109+
101110
e = ensembles_field.to_numpy()
102111
assert c.shape == e.shape, (c.shape, e.shape)
103112

104113
x = c + m - e
105114
if param == "q":
106-
print("Clipping q")
107-
x = np.max(x, 0)
115+
warnings.warn("Clipping q")
116+
x = np.maximum(x, 0)
117+
118+
assert x.shape == c.shape, (x.shape, c.shape)
108119

109120
check_data_values(x, name=param)
110121
out.write(x, template=ensembles_field)

0 commit comments

Comments
 (0)