1212
1313import numpy as np
1414import pytest
15+ from earthkit .utils .array .testing import NAMESPACE_DEVICES
1516
1617from earthkit .meteo import extreme
17- from earthkit .meteo .utils .testing import ARRAY_BACKENDS
1818
1919here = os .path .dirname (__file__ )
2020sys .path .insert (0 , here )
2121import _cpf # noqa
2222import _data # noqa
2323
2424
25- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
25+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
2626@pytest .mark .parametrize ("clim,ens,v_ref" , [(_data .clim , _data .ens , [- 0.1838425040642013 ])])
27- def test_highlevel_efi (clim , ens , v_ref , array_backend ):
28- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
27+ def test_highlevel_efi (xp , device , clim , ens , v_ref ):
28+ clim = xp .asarray (clim , device = device )
29+ ens = xp .asarray (ens , device = device )
30+ v_ref = xp .asarray (v_ref , device = device )
31+ # clim, ens, v_ref = array_backend.asarray(clim, ens, v_ref)
2932 efi = extreme .efi (clim , ens )
30- assert array_backend .isclose (efi [0 ], v_ref [0 ])
33+ assert xp .isclose (efi [0 ], v_ref [0 ])
3134
3235
33- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
36+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
3437@pytest .mark .parametrize (
3538 "clim,ens,kwargs,v_ref" ,
3639 [
@@ -45,32 +48,34 @@ def test_highlevel_efi(clim, ens, v_ref, array_backend):
4548 (_data .clim_eps2 , _data .ens_eps2 , dict (eps = 1e-4 ), 0.6330071575726789 ),
4649 ],
4750)
48- def test_efi_core (clim , ens , kwargs , v_ref , array_backend ):
49- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
51+ def test_efi_core (xp , device , clim , ens , kwargs , v_ref ):
52+ clim = xp .asarray (clim , device = device )
53+ ens = xp .asarray (ens , device = device )
54+ v_ref = xp .asarray (v_ref , device = device )
5055 efi = extreme .array .efi (clim , ens , ** kwargs )
51- assert array_backend .allclose (efi [0 ], v_ref , rtol = 1e-4 )
56+ assert xp .allclose (efi [0 ], v_ref , rtol = 1e-4 )
5257
5358
54- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
59+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
5560@pytest .mark .parametrize ("clim,ens,v_ref" , [(_data .clim , _data .ens , - 0.18384250406420133 )])
56- def test_efi_sorted (clim , ens , v_ref , array_backend ):
57- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
61+ def test_efi_sorted (xp , device , clim , ens , v_ref ):
62+ clim = xp .asarray (clim , device = device )
63+ ens = xp .asarray (ens , device = device )
64+ v_ref = xp .asarray (v_ref , device = device )
5865
5966 # ensures the algorithm is the same if we sort the data or not
60- ens_perc = array_backend . namespace .sort (ens )
67+ ens_perc = xp .sort (ens )
6168
6269 efi = extreme .array .efi (clim , ens_perc )
6370
64- assert array_backend .isclose (efi [0 ], v_ref )
71+ assert xp .isclose (efi [0 ], v_ref )
6572
6673
67- @pytest .mark .parametrize ("array_backend" , ARRAY_BACKENDS )
68- def test_efi_nan (array_backend ):
69- xp = array_backend .namespace
70-
71- clim_nan = xp .empty ((101 , 1 ))
74+ @pytest .mark .parametrize ("xp, device" , NAMESPACE_DEVICES )
75+ def test_efi_nan (xp , device ):
76+ clim_nan = xp .empty ((101 , 1 ), device = device )
7277 clim_nan [:] = xp .nan
73- ens_nan = xp .empty ((51 , 1 ))
78+ ens_nan = xp .empty ((51 , 1 ), device = device )
7479 ens_nan [:] = xp .nan
7580 # print(clim_nan)
7681 # print(ens_nan)
@@ -81,57 +86,59 @@ def test_efi_nan(array_backend):
8186 assert xp .isnan (efi [0 ])
8287
8388
84- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
89+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
8590@pytest .mark .parametrize ("clim,ens,v_ref" , [(_data .clim , _data .ens , [- 2.14617638 , - 1.3086723 ])])
86- def test_sot_highlevel (clim , ens , v_ref , array_backend ):
87- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
91+ def test_sot_highlevel (xp , device , clim , ens , v_ref ):
92+ clim = xp .asarray (clim , device = device )
93+ ens = xp .asarray (ens , device = device )
94+ v_ref = xp .asarray (v_ref , device = device )
8895
8996 sot_upper = extreme .sot (clim , ens , 90 )
9097 sot_lower = extreme .sot (clim , ens , 10 )
9198
92- v_ref = array_backend .asarray (v_ref , dtype = sot_upper .dtype )
99+ v_ref = xp .asarray (v_ref , dtype = sot_upper .dtype )
93100
94- assert array_backend .allclose (sot_upper [0 ], v_ref [0 ])
95- assert array_backend .allclose (sot_lower [0 ], v_ref [1 ])
101+ assert xp .allclose (sot_upper [0 ], v_ref [0 ])
102+ assert xp .allclose (sot_lower [0 ], v_ref [1 ])
96103
97104
98- @pytest .mark .parametrize ("array_backend" , ARRAY_BACKENDS )
99- # @pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
105+ @pytest .mark .parametrize ("xp, device" , NAMESPACE_DEVICES )
100106@pytest .mark .parametrize (
101107 "clim,ens,v_ref" ,
102108 [
103109 (_data .clim , _data .ens , [- 2.14617638 , - 1.3086723 ]),
104110 ],
105111)
106- def test_sot_core (clim , ens , v_ref , array_backend ):
107- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
112+ def test_sot_core (xp , device , clim , ens , v_ref ):
113+ clim = xp .asarray (clim , device = device )
114+ ens = xp .asarray (ens , device = device )
115+ v_ref = xp .asarray (v_ref , device = device )
108116
109117 sot_upper = extreme .array .sot (clim , ens , 90 )
110118 sot_lower = extreme .array .sot (clim , ens , 10 )
111119
112- v_ref = array_backend .asarray (v_ref , dtype = sot_upper .dtype )
113-
114- # print(sot_upper)
115- # print(sot_lower)
120+ v_ref = xp .asarray (v_ref , dtype = sot_upper .dtype )
116121
117- assert array_backend .allclose (sot_upper [0 ], v_ref [0 ])
118- assert array_backend .allclose (sot_lower [0 ], v_ref [1 ])
122+ assert xp .allclose (sot_upper [0 ], v_ref [0 ])
123+ assert xp .allclose (sot_lower [0 ], v_ref [1 ])
119124
120125
121- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
126+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
122127# @pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
123128@pytest .mark .parametrize ("clim,ens,v_ref" , [(_data .clim_eps2 , _data .ens_eps2 , [np .nan ])])
124- def test_sot_perc (clim , ens , v_ref , array_backend ):
125- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
129+ def test_sot_perc (xp , device , clim , ens , v_ref ):
130+ clim = xp .asarray (clim , device = device )
131+ ens = xp .asarray (ens , device = device )
132+ v_ref = xp .asarray (v_ref , device = device )
126133
127134 sot = extreme .array .sot (clim , ens , 90 , eps = 1e4 )
128135
129- v_ref = array_backend .asarray (v_ref , dtype = sot .dtype )
136+ v_ref = xp .asarray (v_ref , dtype = sot .dtype )
130137
131- assert array_backend .allclose (sot [0 ], v_ref [0 ], equal_nan = True )
138+ assert xp .allclose (sot [0 ], v_ref [0 ], equal_nan = True )
132139
133140
134- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
141+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
135142# @pytest.mark.parametrize("array_backend", get_array_backend(["numpy"]))
136143@pytest .mark .parametrize (
137144 "qc_tail,qc,qf,kwargs,v_ref" ,
@@ -163,27 +170,32 @@ def test_sot_perc(clim, ens, v_ref, array_backend):
163170 ([np .nan ], [0.1 ], [0.2 ], {}, [np .nan ]), # nan
164171 ],
165172)
166- def test_sot_func (qc_tail , qc , qf , kwargs , v_ref , array_backend ):
167- qc_tail , qc , qf , v_ref = array_backend .asarray (qc_tail , qc , qf , v_ref )
173+ def test_sot_func (xp , device , qc_tail , qc , qf , kwargs , v_ref ):
174+ qc_tail = xp .asarray (qc_tail , device = device )
175+ qc = xp .asarray (qc , device = device )
176+ qf = xp .asarray (qf , device = device )
177+ v_ref = xp .asarray (v_ref , device = device )
168178
169179 sot = extreme .array .sot_func (qc_tail , qc , qf , ** kwargs )
170180
171- v_ref = array_backend .asarray (v_ref , dtype = sot .dtype )
181+ v_ref = xp .asarray (v_ref , dtype = sot .dtype )
172182
173- assert array_backend .allclose (sot , v_ref , equal_nan = True )
183+ assert xp .allclose (sot , v_ref , equal_nan = True )
174184
175185
176- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
186+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
177187@pytest .mark .parametrize ("clim,ens,v_ref" , [(_cpf .cpf_clim , _cpf .cpf_ens , _cpf .cpf_val )])
178- def test_cpf_highlevel (clim , ens , v_ref , array_backend ):
179- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
188+ def test_cpf_highlevel (xp , device , clim , ens , v_ref ):
189+ clim = xp .asarray (clim , device = device )
190+ ens = xp .asarray (ens , device = device )
191+ v_ref = xp .asarray (v_ref , device = device )
180192
181193 cpf = extreme .cpf (clim , ens , sort_clim = True )
182194
183- assert array_backend .allclose (cpf , v_ref )
195+ assert xp .allclose (cpf , v_ref )
184196
185197
186- @pytest .mark .parametrize ("array_backend " , ARRAY_BACKENDS )
198+ @pytest .mark .parametrize ("xp, device " , NAMESPACE_DEVICES )
187199@pytest .mark .parametrize (
188200 "clim,ens,kwargs,v_ref" ,
189201 [
@@ -193,8 +205,10 @@ def test_cpf_highlevel(clim, ens, v_ref, array_backend):
193205 (_cpf .cpf_clim , _cpf .cpf_ens , dict (sort_clim = True , from_zero = True ), _cpf .cpf_val_fromzero ),
194206 ],
195207)
196- def test_cpf_core (clim , ens , v_ref , kwargs , array_backend ):
197- clim , ens , v_ref = array_backend .asarray (clim , ens , v_ref )
208+ def test_cpf_core (xp , device , clim , ens , v_ref , kwargs ):
209+ clim = xp .asarray (clim , device = device )
210+ ens = xp .asarray (ens , device = device )
211+ v_ref = xp .asarray (v_ref , device = device )
198212
199213 cpf = extreme .array .cpf (clim , ens , ** kwargs )
200- assert array_backend .allclose (cpf , v_ref )
214+ assert xp .allclose (cpf , v_ref )
0 commit comments