Skip to content

Commit 2fa888e

Browse files
committed
feat: refactor tests to test device and use new eku format
1 parent ffa6fcc commit 2fa888e

File tree

9 files changed

+489
-402
lines changed

9 files changed

+489
-402
lines changed

tests/extreme/_data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
273.621521,
5656
]
5757
)
58-
ens = ens[:, np.newaxis]
58+
ens = ens[:, np.newaxis].astype(np.float32)
5959

6060
ens_eps = np.array(
6161
[
@@ -112,7 +112,7 @@
112112
0.005218505859375,
113113
]
114114
)
115-
ens_eps = ens_eps[:, np.newaxis]
115+
ens_eps = ens_eps[:, np.newaxis].astype(np.float32)
116116

117117
clim = np.array(
118118
[
@@ -219,7 +219,7 @@
219219
274.55859375,
220220
]
221221
)
222-
clim = clim[:, np.newaxis]
222+
clim = clim[:, np.newaxis].astype(np.float32)
223223

224224
clim_eps = np.array(
225225
[
@@ -326,7 +326,7 @@
326326
0.02294921875,
327327
]
328328
)
329-
clim_eps = clim_eps[:, np.newaxis]
329+
clim_eps = clim_eps[:, np.newaxis].astype(np.float32)
330330

331331
ens_eps2 = np.array(
332332
[
@@ -383,7 +383,7 @@
383383
1.5535354614257812,
384384
]
385385
)
386-
ens_eps2 = ens_eps2[:, np.newaxis]
386+
ens_eps2 = ens_eps2[:, np.newaxis].astype(np.float32)
387387

388388
clim_eps2 = np.array(
389389
[
@@ -490,4 +490,4 @@
490490
0.48828125,
491491
]
492492
)
493-
clim_eps2 = clim_eps2[:, np.newaxis]
493+
clim_eps2 = clim_eps2[:, np.newaxis].astype(np.float32)

tests/extreme/test_extreme.py

Lines changed: 68 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,28 @@
1212

1313
import numpy as np
1414
import pytest
15+
from earthkit.utils.array.testing import NAMESPACE_DEVICES
1516

1617
from earthkit.meteo import extreme
17-
from earthkit.meteo.utils.testing import ARRAY_BACKENDS
1818

1919
here = os.path.dirname(__file__)
2020
sys.path.insert(0, here)
2121
import _cpf # noqa
2222
import _data # noqa
2323

2424

25-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
25+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
2626
@pytest.mark.parametrize("clim,ens,v_ref", [(_data.clim, _data.ens, [-0.1838425040642013])])
27-
def test_highlevel_efi(clim, ens, v_ref, array_backend):
28-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
27+
def test_highlevel_efi(xp, device, clim, ens, v_ref):
28+
clim = xp.asarray(clim, device=device)
29+
ens = xp.asarray(ens, device=device)
30+
v_ref = xp.asarray(v_ref, device=device)
31+
# clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
2932
efi = extreme.efi(clim, ens)
30-
assert array_backend.isclose(efi[0], v_ref[0])
33+
assert xp.isclose(efi[0], v_ref[0])
3134

3235

33-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
36+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
3437
@pytest.mark.parametrize(
3538
"clim,ens,kwargs,v_ref",
3639
[
@@ -45,32 +48,34 @@ def test_highlevel_efi(clim, ens, v_ref, array_backend):
4548
(_data.clim_eps2, _data.ens_eps2, dict(eps=1e-4), 0.6330071575726789),
4649
],
4750
)
48-
def test_efi_core(clim, ens, kwargs, v_ref, array_backend):
49-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
51+
def test_efi_core(xp, device, clim, ens, kwargs, v_ref):
52+
clim = xp.asarray(clim, device=device)
53+
ens = xp.asarray(ens, device=device)
54+
v_ref = xp.asarray(v_ref, device=device)
5055
efi = extreme.array.efi(clim, ens, **kwargs)
51-
assert array_backend.allclose(efi[0], v_ref, rtol=1e-4)
56+
assert xp.allclose(efi[0], v_ref, rtol=1e-4)
5257

5358

54-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
59+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
5560
@pytest.mark.parametrize("clim,ens,v_ref", [(_data.clim, _data.ens, -0.18384250406420133)])
56-
def test_efi_sorted(clim, ens, v_ref, array_backend):
57-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
61+
def test_efi_sorted(xp, device, clim, ens, v_ref):
62+
clim = xp.asarray(clim, device=device)
63+
ens = xp.asarray(ens, device=device)
64+
v_ref = xp.asarray(v_ref, device=device)
5865

5966
# ensures the algorithm is the same if we sort the data or not
60-
ens_perc = array_backend.namespace.sort(ens)
67+
ens_perc = xp.sort(ens)
6168

6269
efi = extreme.array.efi(clim, ens_perc)
6370

64-
assert array_backend.isclose(efi[0], v_ref)
71+
assert xp.isclose(efi[0], v_ref)
6572

6673

67-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
68-
def test_efi_nan(array_backend):
69-
xp = array_backend.namespace
70-
71-
clim_nan = xp.empty((101, 1))
74+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
75+
def test_efi_nan(xp, device):
76+
clim_nan = xp.empty((101, 1), device=device)
7277
clim_nan[:] = xp.nan
73-
ens_nan = xp.empty((51, 1))
78+
ens_nan = xp.empty((51, 1), device=device)
7479
ens_nan[:] = xp.nan
7580
# print(clim_nan)
7681
# print(ens_nan)
@@ -81,57 +86,59 @@ def test_efi_nan(array_backend):
8186
assert xp.isnan(efi[0])
8287

8388

84-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
89+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
8590
@pytest.mark.parametrize("clim,ens,v_ref", [(_data.clim, _data.ens, [-2.14617638, -1.3086723])])
86-
def test_sot_highlevel(clim, ens, v_ref, array_backend):
87-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
91+
def test_sot_highlevel(xp, device, clim, ens, v_ref):
92+
clim = xp.asarray(clim, device=device)
93+
ens = xp.asarray(ens, device=device)
94+
v_ref = xp.asarray(v_ref, device=device)
8895

8996
sot_upper = extreme.sot(clim, ens, 90)
9097
sot_lower = extreme.sot(clim, ens, 10)
9198

92-
v_ref = array_backend.asarray(v_ref, dtype=sot_upper.dtype)
99+
v_ref = xp.asarray(v_ref, dtype=sot_upper.dtype)
93100

94-
assert array_backend.allclose(sot_upper[0], v_ref[0])
95-
assert array_backend.allclose(sot_lower[0], v_ref[1])
101+
assert xp.allclose(sot_upper[0], v_ref[0])
102+
assert xp.allclose(sot_lower[0], v_ref[1])
96103

97104

98-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
99-
# @pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
105+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
100106
@pytest.mark.parametrize(
101107
"clim,ens,v_ref",
102108
[
103109
(_data.clim, _data.ens, [-2.14617638, -1.3086723]),
104110
],
105111
)
106-
def test_sot_core(clim, ens, v_ref, array_backend):
107-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
112+
def test_sot_core(xp, device, clim, ens, v_ref):
113+
clim = xp.asarray(clim, device=device)
114+
ens = xp.asarray(ens, device=device)
115+
v_ref = xp.asarray(v_ref, device=device)
108116

109117
sot_upper = extreme.array.sot(clim, ens, 90)
110118
sot_lower = extreme.array.sot(clim, ens, 10)
111119

112-
v_ref = array_backend.asarray(v_ref, dtype=sot_upper.dtype)
113-
114-
# print(sot_upper)
115-
# print(sot_lower)
120+
v_ref = xp.asarray(v_ref, dtype=sot_upper.dtype)
116121

117-
assert array_backend.allclose(sot_upper[0], v_ref[0])
118-
assert array_backend.allclose(sot_lower[0], v_ref[1])
122+
assert xp.allclose(sot_upper[0], v_ref[0])
123+
assert xp.allclose(sot_lower[0], v_ref[1])
119124

120125

121-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
126+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
122127
# @pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
123128
@pytest.mark.parametrize("clim,ens,v_ref", [(_data.clim_eps2, _data.ens_eps2, [np.nan])])
124-
def test_sot_perc(clim, ens, v_ref, array_backend):
125-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
129+
def test_sot_perc(xp, device, clim, ens, v_ref):
130+
clim = xp.asarray(clim, device=device)
131+
ens = xp.asarray(ens, device=device)
132+
v_ref = xp.asarray(v_ref, device=device)
126133

127134
sot = extreme.array.sot(clim, ens, 90, eps=1e4)
128135

129-
v_ref = array_backend.asarray(v_ref, dtype=sot.dtype)
136+
v_ref = xp.asarray(v_ref, dtype=sot.dtype)
130137

131-
assert array_backend.allclose(sot[0], v_ref[0], equal_nan=True)
138+
assert xp.allclose(sot[0], v_ref[0], equal_nan=True)
132139

133140

134-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
141+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
135142
# @pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
136143
@pytest.mark.parametrize(
137144
"qc_tail,qc,qf,kwargs,v_ref",
@@ -163,27 +170,32 @@ def test_sot_perc(clim, ens, v_ref, array_backend):
163170
([np.nan], [0.1], [0.2], {}, [np.nan]), # nan
164171
],
165172
)
166-
def test_sot_func(qc_tail, qc, qf, kwargs, v_ref, array_backend):
167-
qc_tail, qc, qf, v_ref = array_backend.asarray(qc_tail, qc, qf, v_ref)
173+
def test_sot_func(xp, device, qc_tail, qc, qf, kwargs, v_ref):
174+
qc_tail = xp.asarray(qc_tail, device=device)
175+
qc = xp.asarray(qc, device=device)
176+
qf = xp.asarray(qf, device=device)
177+
v_ref = xp.asarray(v_ref, device=device)
168178

169179
sot = extreme.array.sot_func(qc_tail, qc, qf, **kwargs)
170180

171-
v_ref = array_backend.asarray(v_ref, dtype=sot.dtype)
181+
v_ref = xp.asarray(v_ref, dtype=sot.dtype)
172182

173-
assert array_backend.allclose(sot, v_ref, equal_nan=True)
183+
assert xp.allclose(sot, v_ref, equal_nan=True)
174184

175185

176-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
186+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
177187
@pytest.mark.parametrize("clim,ens,v_ref", [(_cpf.cpf_clim, _cpf.cpf_ens, _cpf.cpf_val)])
178-
def test_cpf_highlevel(clim, ens, v_ref, array_backend):
179-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
188+
def test_cpf_highlevel(xp, device, clim, ens, v_ref):
189+
clim = xp.asarray(clim, device=device)
190+
ens = xp.asarray(ens, device=device)
191+
v_ref = xp.asarray(v_ref, device=device)
180192

181193
cpf = extreme.cpf(clim, ens, sort_clim=True)
182194

183-
assert array_backend.allclose(cpf, v_ref)
195+
assert xp.allclose(cpf, v_ref)
184196

185197

186-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
198+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
187199
@pytest.mark.parametrize(
188200
"clim,ens,kwargs,v_ref",
189201
[
@@ -193,8 +205,10 @@ def test_cpf_highlevel(clim, ens, v_ref, array_backend):
193205
(_cpf.cpf_clim, _cpf.cpf_ens, dict(sort_clim=True, from_zero=True), _cpf.cpf_val_fromzero),
194206
],
195207
)
196-
def test_cpf_core(clim, ens, v_ref, kwargs, array_backend):
197-
clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
208+
def test_cpf_core(xp, device, clim, ens, v_ref, kwargs):
209+
clim = xp.asarray(clim, device=device)
210+
ens = xp.asarray(ens, device=device)
211+
v_ref = xp.asarray(v_ref, device=device)
198212

199213
cpf = extreme.array.cpf(clim, ens, **kwargs)
200-
assert array_backend.allclose(cpf, v_ref)
214+
assert xp.allclose(cpf, v_ref)

tests/score/test_score.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
import numpy as np
1414
import pytest
15+
from earthkit.utils.array.namespace import _NUMPY_NAMESPACE
16+
from earthkit.utils.array.testing import NAMESPACE_DEVICES
1517

1618
from earthkit.meteo import score
17-
from earthkit.meteo.utils.testing import ARRAY_BACKENDS
18-
from earthkit.meteo.utils.testing import get_array_backend
1919

2020

2121
def crps_quaver2(x, y):
@@ -81,27 +81,29 @@ def _get_crps_data():
8181
return obs, ens, v_ref
8282

8383

84-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
84+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
8585
@pytest.mark.parametrize("nan_policy", ["raise", "propagate", "omit"])
8686
@pytest.mark.parametrize("obs,ens,v_ref", [_get_crps_data()])
87-
def test_crps_meteo(obs, ens, v_ref, array_backend, nan_policy):
88-
obs, ens, v_ref = array_backend.asarray(obs, ens, v_ref)
89-
xp = array_backend.namespace
87+
def test_crps_meteo(xp, device, obs, ens, v_ref, nan_policy):
88+
obs = xp.asarray(obs, device=device)
89+
ens = xp.asarray(ens, device=device)
90+
v_ref = xp.asarray(v_ref, device=device)
9091

9192
c = score.crps(ens.T, obs[0], nan_policy)
9293

9394
for i in range(ens.shape[0]):
94-
assert array_backend.isclose(c[i], v_ref[i]), f"i={i}"
95+
assert xp.isclose(c[i], v_ref[i]), f"i={i}"
9596

96-
assert array_backend.isclose(xp.mean(c), xp.mean(v_ref))
97+
assert xp.isclose(xp.mean(c), xp.mean(v_ref))
9798

9899

99-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
100+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
100101
@pytest.mark.parametrize("nan_policy", ["raise", "propagate", "omit"])
101102
@pytest.mark.parametrize("obs,ens,v_ref", [_get_crps_data()])
102-
def test_crps_meteo_missing(obs, ens, v_ref, array_backend, nan_policy):
103-
obs, ens, v_ref = array_backend.asarray(obs, ens, v_ref)
104-
xp = array_backend.namespace
103+
def test_crps_meteo_missing(xp, device, obs, ens, v_ref, nan_policy):
104+
obs = xp.asarray(obs, device=device)
105+
ens = xp.asarray(ens, device=device)
106+
v_ref = xp.asarray(v_ref, device=device)
105107

106108
ens = ens.T
107109
obs = obs[0]
@@ -119,32 +121,33 @@ def test_crps_meteo_missing(obs, ens, v_ref, array_backend, nan_policy):
119121

120122
if nan_policy == "omit":
121123
for i in range(c_all.shape[0]):
122-
assert array_backend.isclose(c_all[i], c_non_missing[i])
124+
assert xp.isclose(c_all[i], c_non_missing[i])
123125
elif nan_policy == "propagate":
124126
j = 0
125127
for i in range(c_all.shape[0]):
126128
if nan_mask[i]:
127129
assert xp.isnan(c_all[i])
128130
else:
129-
assert array_backend.isclose(c_all[i], c_non_missing[j])
131+
assert xp.isclose(c_all[i], c_non_missing[j])
130132
j += 1
131133

132134
non_missing_crps = c_all[~xp.isnan(c_all)]
133-
assert array_backend.isclose(xp.mean(non_missing_crps), xp.mean(c_non_missing))
135+
assert xp.isclose(xp.mean(non_missing_crps), xp.mean(c_non_missing))
134136

135137

136-
@pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
138+
@pytest.mark.parametrize("xp", [_NUMPY_NAMESPACE])
137139
@pytest.mark.parametrize("obs,ens,v_ref", [_get_crps_data()])
138-
def test_crps_quaver2(obs, ens, v_ref, array_backend):
139-
obs, ens, v_ref = array_backend.asarray(obs, ens, v_ref)
140-
xp = array_backend.namespace
140+
def test_crps_quaver2(xp, obs, ens, v_ref):
141+
obs = xp.asarray(obs)
142+
ens = xp.asarray(ens)
143+
v_ref = xp.asarray(v_ref)
141144

142145
c = crps_quaver2(ens.T, obs[0])
143146

144147
for i in range(ens.shape[0]):
145-
assert array_backend.isclose(c[i], v_ref[i]), f"i={i}"
148+
assert xp.isclose(c[i], v_ref[i]), f"i={i}"
146149

147-
assert array_backend.isclose(xp.mean(c), xp.mean(v_ref))
150+
assert xp.isclose(xp.mean(c), xp.mean(v_ref))
148151

149152

150153
def _get_pearson_data():
@@ -175,10 +178,12 @@ def _get_pearson_data():
175178
return x.tolist(), y.tolist(), rs.tolist()
176179

177180

178-
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
181+
@pytest.mark.parametrize("xp, device", NAMESPACE_DEVICES)
179182
@pytest.mark.parametrize("x, y, v_ref", [_get_pearson_data()])
180-
def test_pearson(x, y, v_ref, array_backend):
181-
x, y, v_ref = array_backend.asarray(x, y, v_ref)
183+
def test_pearson(xp, device, x, y, v_ref):
184+
x = xp.asarray(x, device=device)
185+
y = xp.asarray(y, device=device)
186+
v_ref = xp.asarray(v_ref, device=device)
182187

183188
r = score.pearson(x, y, axis=1)
184-
np.testing.assert_allclose(r, v_ref, atol=1e-7)
189+
assert xp.allclose(r, v_ref, atol=1e-7, equal_nan=True)

0 commit comments

Comments
 (0)