Skip to content

Commit defa84e

Browse files
committed
fix: bugfix device handling in thermo
1 parent 65415f7 commit defa84e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/earthkit/meteo/thermo/array/es_comp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _es_mixed(t, xp):
148148
# svp = alpha * es_water + (1.0 - alpha) * es_ice
149149

150150
t = xp.asarray(t)
151-
svp = xp.zeros(t.shape, dtype=t.dtype)
151+
svp = xp.zeros(t.shape, dtype=t.dtype, device=xp.device(t))
152152

153153
# ice range
154154
i_mask = t <= TI
@@ -176,7 +176,7 @@ def _es_ice_slope(t, xp):
176176

177177
def _es_mixed_slope(t, xp):
178178
t = xp.asarray(t)
179-
d_svp = xp.zeros(t.shape, dtype=t.dtype)
179+
d_svp = xp.zeros(t.shape, dtype=t.dtype, device=xp.device(t))
180180

181181
# ice range
182182
i_mask = t <= TI

src/earthkit/meteo/thermo/array/thermo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,7 @@ def compute_t_on_ma_stipanuk(self, ept, p):
10581058
p = xp.asarray(p)
10591059

10601060
size = xp.size(ept) if xp.size(ept) > xp.size(p) else xp.size(p)
1061-
t = xp.full(size, constants.T0 - 20, dtype=ept.dtype)
1061+
t = xp.full(size, constants.T0 - 20, dtype=ept.dtype, device=xp.device(ept))
10621062

10631063
# if isinstance(p, np.ndarray):
10641064
# t = np.full(p.shape, constants.T0 - 20)
@@ -1083,7 +1083,7 @@ def compute_t_on_ma_davies(self, ept, p):
10831083
ept = xp.asarray(ept)
10841084
p = xp.asarray(p)
10851085
if xp.size(ept) > xp.size(p):
1086-
p = xp.full(xp.size(ept), p, dtype=ept.dtype)
1086+
p = xp.full(xp.size(ept), p, dtype=ept.dtype, device=xp.device(ept))
10871087

10881088
# p = xp.full(ept.shape, p, dtype=ept.dtype)
10891089

0 commit comments

Comments
 (0)