Skip to content

Commit 5dd88f8

Browse files
committed
feat: bugfix mismatched device arrays
1 parent 2fa888e commit 5dd88f8

File tree

5 files changed

+14
-19
lines changed

5 files changed

+14
-19
lines changed

src/earthkit/meteo/extreme/array/cpf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def _cpf(clim, ens, epsilon=None, from_zero=False):
1414
xp = array_namespace(clim, ens)
1515
clim = xp.asarray(clim)
1616
ens = xp.asarray(ens)
17+
device = xp.device(clim)
1718

1819
nclim, npoints = clim.shape
1920
nens, _ = ens.shape
2021

21-
cpf = xp.zeros(npoints, dtype=xp.float32)
22-
mask = xp.zeros(npoints, dtype=xp.bool)
23-
prim = xp.zeros(npoints, dtype=xp.bool)
22+
cpf = xp.zeros(npoints, dtype=xp.float32, device=device)
23+
mask = xp.zeros(npoints, dtype=xp.bool, device=device)
24+
prim = xp.zeros(npoints, dtype=xp.bool, device=device)
2425

2526
# start scanning ensemble from iq_start
2627
iq_start = 0 if from_zero else nens // 2

src/earthkit/meteo/extreme/array/efi.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def efi(clim, ens, eps=-0.1):
3434
xp = array_namespace(clim, ens)
3535
clim = xp.asarray(clim)
3636
ens = xp.asarray(ens)
37+
device = xp.device(clim)
3738

3839
# locate missing values
3940
missing_mask = xp.logical_or(xp.sum(xp.isnan(clim), axis=0), xp.sum(xp.isnan(ens), axis=0))
@@ -60,10 +61,10 @@ def efi(clim, ens, eps=-0.1):
6061
acoef = (1.0 - 2.0 * p[:-1]) * acosdiff + proddiff
6162

6263
# compute EFI from coefficients
63-
efi = xp.zeros(npoints)
64+
efi = xp.zeros(npoints, device=device)
6465
##################################
6566
if eps > 0:
66-
efimax = xp.zeros(npoints)
67+
efimax = xp.zeros(npoints, device=device)
6768
for icl in range(nclim - 1):
6869
mask = clim[icl + 1, :] > eps
6970
dEFI = xp.where(

src/earthkit/meteo/extreme/array/sot.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,9 @@ def sot_func(qc_tail, qc, qf, eps=-1e-4, lower_bound=-10, upper_bound=10):
3939
qc = xp.asarray(qc)
4040
qf = xp.asarray(qf)
4141

42-
# TODO: check if this is necessary
43-
# NOTE: work for numpy but not for other backends
44-
# avoid divided by zero warning
45-
err = xp.seterr(divide="ignore", invalid="ignore")
46-
4742
min_den = xp.fmax(xp.asarray(eps), xp.asarray(0))
4843
sot = xp.where(xp.abs(qc_tail - qc) > min_den, (qf - qc_tail) / (qc_tail - qc), xp.nan)
4944

50-
# revert to original error state
51-
xp.seterr(**err)
52-
5345
mask_missing = xp.isnan(sot)
5446

5547
# upper and lower bounds

src/earthkit/meteo/stats/array/quantiles.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,13 @@ def iter_quantiles(
4949

5050
xp = array_namespace(arr)
5151
arr = xp.asarray(arr)
52+
device = xp.device(arr)
5253

5354
if isinstance(which, int):
5455
n = which
55-
qs = xp.linspace(0.0, 1.0, n + 1)
56+
qs = xp.linspace(0.0, 1.0, n + 1, device=device)
5657
else:
57-
qs = xp.asarray(which)
58+
qs = xp.asarray(which, device=device)
5859

5960
if method == "numpy_bulk":
6061
quantiles = xp.quantile(arr, qs, axis=axis)

src/earthkit/meteo/wind/array/wind.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,13 @@ def windrose(speed, direction, sectors=16, speed_bins=None, percent=True):
303303
# TODO: atleast_1d is not part of the array API standard
304304
speed = xp.atleast_1d(speed)
305305
direction = xp.atleast_1d(direction)
306+
device = xp.device(speed)
306307

307308
dir_step = 360.0 / sectors
308-
dir_bins = xp.asarray(
309-
xp.linspace(int(-dir_step / 2), int(360 + dir_step / 2), int(360 / dir_step) + 2), dtype=speed.dtype
309+
dir_bins = xp.linspace(
310+
int(-dir_step / 2), int(360 + dir_step / 2), int(360 / dir_step) + 2, dtype=speed.dtype, device=device
310311
)
311-
speed_bins = xp.asarray(speed_bins, dtype=speed.dtype)
312+
speed_bins = xp.asarray(speed_bins, dtype=speed.dtype, device=device)
312313

313314
# NOTE: np.histogram2d is only available in numpy. For other namespaces we use a fallback implementation
314315
# based on histogramdd. (See utils.compute.histogram2d). However, neither histogram2d nor
@@ -317,7 +318,6 @@ def windrose(speed, direction, sectors=16, speed_bins=None, percent=True):
317318
speed,
318319
direction,
319320
bins=[speed_bins, dir_bins],
320-
density=False,
321321
)[0]
322322

323323
# unify the north bins

0 commit comments

Comments
 (0)