Skip to content

Commit c51739d

Browse files
committed
Fix bugs in new healpix routines
- zonal average, to_double_pixelization didn't work with torch - add zonal average(dim) argument
1 parent aefb107 commit c51739d

File tree

2 files changed

+67
-18
lines changed

2 files changed

+67
-18
lines changed

earth2grid/healpix.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -608,15 +608,15 @@ def _pixels_to_rings(nside: int, p: ArrayT) -> ArrayT:
608608
npix = 12 * nside * nside
609609
ncap = 2 * nside * (nside - 1)
610610

611-
i_north = xp.floor(0.5 * (1 + np.sqrt(1 + 2 * p)))
611+
i_north = xp.floor(0.5 * (1 + xp.sqrt(1 + 2 * p)))
612612
j_north = p - 2 * (i_north - 1) * i_north
613613

614614
p_eq = p - ncap
615615
i_eq = xp.floor(p_eq / (4 * nside)) + nside - 1
616616
j_eq = p_eq % (4 * nside)
617617

618618
p_south = npix - p - 1
619-
i_south = xp.floor(0.5 * (1 + np.sqrt(1 + 2 * p_south)))
619+
i_south = xp.floor(0.5 * (1 + xp.sqrt(1 + 2 * p_south)))
620620
j_south = p_south - 2 * (i_south - 1) * i_south
621621
length_south = i_south * 4
622622

@@ -628,7 +628,7 @@ def _pixels_to_rings(nside: int, p: ArrayT) -> ArrayT:
628628
j = xp.where(p >= ncap, j_eq, j)
629629
j = xp.where(p >= (npix - ncap), length_south - 1 - j_south, j)
630630

631-
return i.astype(int), j.astype(int)
631+
return _to_int(i), _to_int(j)
632632

633633

634634
def ring_length(nside: int, i: ArrayT) -> ArrayT:
@@ -706,18 +706,45 @@ def to_rotated_pixelization(x, fill_value=math.nan):
706706
return output
707707

708708

709+
def _arange_like(n, like):
710+
if isinstance(like, np.ndarray):
711+
batch = np.arange(n)
712+
else:
713+
batch = torch.arange(n, device=like.device)
714+
return batch
715+
716+
717+
def _to_int(x):
718+
if isinstance(x, np.ndarray):
719+
return x.astype(int)
720+
else:
721+
return x.int()
722+
723+
724+
def _zeros_like(x, shape=None, dtype=None):
725+
if isinstance(x, np.ndarray):
726+
return np.zeros_like(x, shape=shape, dtype=dtype)
727+
else:
728+
return torch.zeros(shape or x.shape, dtype=dtype, device=x.device)
729+
730+
709731
def to_double_pixelization(x: ArrayT, fill_value=0) -> ArrayT:
710732
"""Convert the array x to 2D-image w/ the double pixelization
711733
712734
``x`` must be in RING pixel order
713735
714736
"""
715737
xp = _get_array_library(x)
738+
dtype = xp.float32
716739

717740
n = npix2nside(x.shape[-1])
718-
i, jp = ring2double(n, np.arange(12 * n * n))
719-
out = xp.zeros_like(x, shape=x.shape[:-1] + (4 * n, 8 * n + 1), dtype=xp.float32)
720-
num = xp.zeros_like(out, dtype=xp.int32)
741+
i, jp = ring2double(n, _arange_like(12 * n * n, x))
742+
out = _zeros_like(x, shape=x.shape[:-1] + (4 * n, 8 * n + 1), dtype=dtype)
743+
num = _zeros_like(out, dtype=xp.int32)
744+
745+
if torch.is_tensor(x):
746+
x = x.to(out)
747+
721748
out[i, jp] = x
722749
num[i, jp] += 1
723750

@@ -732,25 +759,26 @@ def to_double_pixelization(x: ArrayT, fill_value=0) -> ArrayT:
732759
return out
733760

734761

735-
def zonal_average(x: ArrayT) -> ArrayT:
762+
def zonal_average(x: ArrayT, dim=-1) -> ArrayT:
736763
"""Compute the zonal average of a map in ring format"""
737764
xp = _get_array_library(x)
738-
if x.ndim != 2:
739-
raise ValueError()
765+
766+
dim = dim % x.ndim
767+
shape = [x.shape[i] for i in range(x.ndim) if i != dim]
768+
x = xp.moveaxis(x, dim, -1)
769+
x = x.reshape([-1, x.shape[-1]])
740770

741771
npix = x.shape[-1]
742772
nside = npix2nside(npix)
743773

744-
iring, _ = _pixels_to_rings(nside, np.arange(npix))
774+
iring, _ = _pixels_to_rings(nside, _arange_like(npix, like=x))
745775
nring = iring.max() + 1
746-
if isinstance(x, np.ndarray):
747-
batch = np.arange(x.shape[0])
748-
else:
749-
batch = torch.arange(x.shape[0], device=x.device)
776+
batch = _arange_like(x.shape[0], x)
750777

751778
i_flat = batch[:, None] * nring + iring
752779
i_flat = i_flat.ravel()
753780
num = xp.bincount(i_flat, weights=x.ravel(), minlength=nring * x.shape[0])
754781
denom = xp.bincount(i_flat, minlength=nring * x.shape[0])
755782
average = num / denom
756-
return average.reshape(x.shape[0], nring)
783+
average = average.reshape((*shape, nring)) # type: ignore
784+
return xp.moveaxis(average, -1, dim)

tests/test_healpix.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def test_grid_visualize():
3030

3131
@pytest.mark.parametrize("origin", list(healpix.Compass))
3232
def test_grid_healpix_orientations(tmp_path, origin):
33-
3433
nest_grid = healpix.Grid(level=4, pixel_order=healpix.PixelOrder.NEST)
3534
grid = healpix.Grid(level=4, pixel_order=healpix.XY(origin=origin))
3635

@@ -108,7 +107,6 @@ def grad_abs(z):
108107
sigma = grad_abs(z)
109108

110109
if sigma_padded > sigma * 1.1:
111-
112110
fig, axs = plt.subplots(3, 4)
113111
axs = axs.ravel()
114112
for i in range(12):
@@ -175,7 +173,12 @@ def test_latlon_cuda_set_device_regression():
175173
torch.set_default_device(default)
176174

177175

178-
def test_zonal_average():
176+
@pytest.mark.parametrize("device,do_torch", [("cpu", True), ("cuda", True), ("cpu", False)])
177+
def test_zonal_average(device, do_torch):
178+
179+
if device == "cuda" and torch.cuda.device_count() == 0:
180+
pytest.skip("no cuda devices available")
181+
179182
# hpx 2 in ring order
180183
x = np.array(
181184
[
@@ -230,7 +233,11 @@ def test_zonal_average():
230233
]
231234
)
232235
x = x[None]
236+
if do_torch:
237+
x = torch.from_numpy(x).to(device)
233238
zonal = healpix.zonal_average(x)
239+
if do_torch:
240+
zonal = zonal.cpu().numpy()
234241
assert zonal.shape == (1, 7)
235242
assert np.all(zonal == np.arange(7))
236243

@@ -241,3 +248,17 @@ def test_to_double_pixelization(regtest):
241248
x = healpix.to_double_pixelization(x)
242249
assert x.dtype == x.dtype
243250
np.savetxt(regtest, x, fmt="%d")
251+
252+
253+
def test_to_double_pixelization_cuda(device="cuda"):
254+
if not torch.cuda.is_available():
255+
pytest.skip()
256+
257+
n = 2
258+
x = np.arange(12 * n * n)
259+
xnp = healpix.to_double_pixelization(x)
260+
261+
x = torch.arange(12 * n * n, device=device)
262+
x = healpix.to_double_pixelization(x)
263+
264+
np.testing.assert_array_equal(xnp, x.cpu().numpy())

0 commit comments

Comments
 (0)