Skip to content

Commit aefb107

Browse files
authored
Merge pull request #29 from NVlabs/nb/hpx-utils
More healpix utilities
2 parents baa00d2 + 41a9e95 commit aefb107

File tree

5 files changed

+305
-15
lines changed

5 files changed

+305
-15
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# Changelog
22

3+
## Latest
4+
5+
New APIs
6+
- earth2grid.healpix
7+
- zonal_average
8+
- ring2double
9+
- to_rotated_pixelization
10+
- to_double_pixelization
11+
312
## 2025.4.1
413

514
Breaking changes:

earth2grid/healpix.py

Lines changed: 167 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from dataclasses import dataclass
3838
from enum import Enum
3939
from functools import lru_cache
40-
from typing import Union
40+
from typing import TypeVar, Union
4141

4242
import einops
4343
import numpy as np
@@ -72,6 +72,16 @@
7272
__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder", "ang2pix"]
7373

7474

75+
def _get_array_library(x):
76+
if isinstance(x, np.ndarray):
77+
return np
78+
else:
79+
return torch
80+
81+
82+
ArrayT = TypeVar("ArrayT", np.ndarray, torch.Tensor)
83+
84+
7585
def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
7686
"""
7787
Pad each face consistently with its according neighbors in the HEALPix
@@ -412,20 +422,7 @@ def to_image(self, x: torch.Tensor, fill_value=torch.nan) -> torch.Tensor:
412422
"""Use the 45 degree rotated grid pixelation
413423
i points to SE, j point to NE
414424
"""
415-
grid = [[6, 9, -1, -1, -1], [1, 5, 8, -1, -1], [-1, 0, 4, 11, -1], [-1, -1, 3, 7, 10], [-1, -1, -1, 2, 6]]
416-
pixel_order = XY(origin=Compass.W, clockwise=True)
417-
x = self.reorder(pixel_order, x)
418-
nside = self._nside()
419-
*shape, _ = x.shape
420-
x = x.reshape((*shape, 12, nside, nside))
421-
output = torch.full((*shape, 5 * nside, 5 * nside), device=x.device, dtype=x.dtype, fill_value=fill_value)
422-
423-
for j in range(len(grid)):
424-
for i in range(len(grid[0])):
425-
face = grid[j][i]
426-
if face != -1:
427-
output[j * nside : (j + 1) * nside, i * nside : (i + 1) * nside] = x[face]
428-
return output
425+
return to_rotated_pixelization(x, fill_value)
429426

430427

431428
class HEALPixPadFunction(torch.autograd.Function):
@@ -602,3 +599,158 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
602599
weight = weight.unsqueeze(-3)
603600
out = torch.nn.functional.conv3d(input, weight, bias, stride, padding, dilation, groups)
604601
return einops.rearrange(out, "n c f x y -> n c () (f x y)")
602+
603+
604+
def _pixels_to_rings(nside: int, p: ArrayT) -> ArrayT:
605+
"""Get the ring number of a pixel ``i`` in RING order"""
606+
# See eq (2-5) of Gorski
607+
xp = _get_array_library(p)
608+
npix = 12 * nside * nside
609+
ncap = 2 * nside * (nside - 1)
610+
611+
i_north = xp.floor(0.5 * (1 + np.sqrt(1 + 2 * p)))
612+
j_north = p - 2 * (i_north - 1) * i_north
613+
614+
p_eq = p - ncap
615+
i_eq = xp.floor(p_eq / (4 * nside)) + nside - 1
616+
j_eq = p_eq % (4 * nside)
617+
618+
p_south = npix - p - 1
619+
i_south = xp.floor(0.5 * (1 + np.sqrt(1 + 2 * p_south)))
620+
j_south = p_south - 2 * (i_south - 1) * i_south
621+
length_south = i_south * 4
622+
623+
i = i_north - 1
624+
i = xp.where(p >= ncap, i_eq, i)
625+
i = xp.where(p >= (npix - ncap), 4 * nside - i_south - 1, i)
626+
627+
j = j_north
628+
j = xp.where(p >= ncap, j_eq, j)
629+
j = xp.where(p >= (npix - ncap), length_south - 1 - j_south, j)
630+
631+
return i.astype(int), j.astype(int)
632+
633+
634+
def ring_length(nside: int, i: ArrayT) -> ArrayT:
635+
"""The number of pixels in ring 0 <= i < 4 * nside - 2"""
636+
xp = _get_array_library(i)
637+
638+
length_north = 4 * (i + 1)
639+
length_eq = 4 * nside
640+
length_south = (4 * nside - i - 1) * 4
641+
642+
length = length_north
643+
# test i =1, nside = 1
644+
length = xp.where(i >= nside, length_eq, length)
645+
# test: i = 2, nside=1, should have len 4
646+
length = xp.where(i >= nside * 3 - 1, length_south, length)
647+
return length
648+
649+
650+
def ring2double(nside: int, p: ArrayT):
651+
"""Compute the (i,j) index in the double pixelization scheme of Calabretta (2007)
652+
653+
This is a visually appealing way to visualize healpix data without any
654+
interpolation.
655+
656+
See Fig 5
657+
658+
Calabretta, M. R., & Roukema, B. F. (2007). Mapping on the HEALPix grid. Monthly Notices of the Royal Astronomical Society, 381(2), 865–872. https://doi.org/10.1111/j.1365-2966.2007.12297.x
659+
660+
"""
661+
xp = _get_array_library(p)
662+
n = nside
663+
i, j = _pixels_to_rings(n, p)
664+
n_per_pyramid = ring_length(n, i) // 4
665+
666+
pyramid = j // n_per_pyramid
667+
left = n - i
668+
jp_north = 2 * pyramid * n + left + 2 * (j % n_per_pyramid)
669+
jp_eq = (n - i) % 2 + 2 * j
670+
671+
left = i - 3 * n + 2
672+
jp_south = 2 * pyramid * n + left + 2 * (j % n_per_pyramid)
673+
674+
jp = xp.where(i >= n, jp_eq, jp_north)
675+
jp = xp.where(i >= 3 * n, jp_south, jp)
676+
677+
return i, jp
678+
679+
680+
def to_rotated_pixelization(x, fill_value=math.nan):
681+
"""Convert an array to a 2D-iamge w/ the rotated pixelization"""
682+
683+
numpy_out = False
684+
if isinstance(x, np.ndarray):
685+
numpy_out = True
686+
x = torch.from_numpy(x)
687+
688+
grid = [[6, 9, -1, -1, -1], [1, 5, 8, -1, -1], [-1, 0, 4, 11, -1], [-1, -1, 3, 7, 10], [-1, -1, -1, 2, 6]]
689+
pixel_order = XY(origin=Compass.W, clockwise=True)
690+
self = Grid(npix2level(x.shape[-1]))
691+
x = self.reorder(pixel_order, x)
692+
nside = self._nside()
693+
*shape, _ = x.shape
694+
x = x.reshape((*shape, 12, nside, nside))
695+
output = torch.full((*shape, 5 * nside, 5 * nside), device=x.device, dtype=x.dtype, fill_value=fill_value)
696+
697+
for j in range(len(grid)):
698+
for i in range(len(grid[0])):
699+
face = grid[j][i]
700+
if face != -1:
701+
output[j * nside : (j + 1) * nside, i * nside : (i + 1) * nside] = x[face]
702+
703+
if numpy_out:
704+
return output.numpy()
705+
else:
706+
return output
707+
708+
709+
def to_double_pixelization(x: ArrayT, fill_value=0) -> ArrayT:
710+
"""Convert the array x to 2D-image w/ the double pixelization
711+
712+
``x`` must be in RING pixel order
713+
714+
"""
715+
xp = _get_array_library(x)
716+
717+
n = npix2nside(x.shape[-1])
718+
i, jp = ring2double(n, np.arange(12 * n * n))
719+
out = xp.zeros_like(x, shape=x.shape[:-1] + (4 * n, 8 * n + 1), dtype=xp.float32)
720+
num = xp.zeros_like(out, dtype=xp.int32)
721+
out[i, jp] = x
722+
num[i, jp] += 1
723+
724+
out[i, jp + 1] = x
725+
num[i, jp + 1] += 1
726+
727+
out[i, jp - 1] += x
728+
num[i, jp - 1] += 1
729+
out[num == 0] = fill_value
730+
num[num == 0] = 1
731+
out /= num
732+
return out
733+
734+
735+
def zonal_average(x: ArrayT) -> ArrayT:
736+
"""Compute the zonal average of a map in ring format"""
737+
xp = _get_array_library(x)
738+
if x.ndim != 2:
739+
raise ValueError()
740+
741+
npix = x.shape[-1]
742+
nside = npix2nside(npix)
743+
744+
iring, _ = _pixels_to_rings(nside, np.arange(npix))
745+
nring = iring.max() + 1
746+
if isinstance(x, np.ndarray):
747+
batch = np.arange(x.shape[0])
748+
else:
749+
batch = torch.arange(x.shape[0], device=x.device)
750+
751+
i_flat = batch[:, None] * nring + iring
752+
i_flat = i_flat.ravel()
753+
num = xp.bincount(i_flat, weights=x.ravel(), minlength=nring * x.shape[0])
754+
denom = xp.bincount(i_flat, minlength=nring * x.shape[0])
755+
average = num / denom
756+
return average.reshape(x.shape[0], nring)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
HealPIX Pixelization
17+
--------------------
18+
19+
HealPIX maps can be viewed as a 2D image rotated by 45 deg or alternatively with
20+
double pixelization that is not rotated. This is useful for quick visualization
21+
with image viewers without distorting the native pixels of the image.
22+
"""
23+
# %%
24+
import matplotlib.pyplot as plt
25+
import numpy as np
26+
27+
from earth2grid import healpix
28+
29+
n = 8
30+
npix = 12 * n * n
31+
ncap = 2 * n * (n - 1)
32+
p = np.arange(npix)
33+
34+
grid = healpix.Grid(healpix.nside2level(n))
35+
i, jp = healpix.ring2double(n, p)
36+
plt.figure(figsize=(10, 3))
37+
# plt.scatter(jp, i, c=grid.lon[p])
38+
plt.scatter(jp + 1, i, c=grid.lon[p])
39+
plt.grid()
40+
41+
# %%
42+
n = 4
43+
npix = 12 * n * n
44+
ncap = 2 * n * (n - 1)
45+
p = np.arange(npix)
46+
47+
grid = healpix.Grid(healpix.nside2level(n))
48+
out = healpix.to_double_pixelization(grid.lon)
49+
plt.imshow(out)
50+
51+
# %%
52+
out = healpix.to_rotated_pixelization(grid.lon)
53+
plt.imshow(out)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
0 0 0 0 0 1 1 1 0 2 2 2 0 3 3 3 0
2+
4 4 4 5 5 6 6 7 7 8 8 9 9 10 10 11 11
3+
12 12 13 13 14 14 15 15 16 16 17 17 18 18 19 19 12
4+
20 20 20 21 21 22 22 23 23 24 24 25 25 26 26 27 27
5+
28 28 29 29 30 30 31 31 32 32 33 33 34 34 35 35 28
6+
36 36 36 37 37 38 38 39 39 40 40 41 41 42 42 43 43
7+
0 44 44 44 0 45 45 45 0 46 46 46 0 47 47 47 0
8+
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

tests/test_healpix.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,71 @@ def test_latlon_cuda_set_device_regression():
173173
grid.lat
174174
finally:
175175
torch.set_default_device(default)
176+
177+
178+
def test_zonal_average():
179+
# hpx 2 in ring order
180+
x = np.array(
181+
[
182+
0,
183+
0,
184+
0,
185+
0,
186+
1,
187+
1,
188+
1,
189+
1,
190+
1,
191+
1,
192+
1,
193+
1,
194+
2,
195+
2,
196+
2,
197+
2,
198+
2,
199+
2,
200+
2,
201+
2,
202+
3,
203+
3,
204+
3,
205+
3,
206+
3,
207+
3,
208+
3,
209+
3,
210+
4,
211+
4,
212+
4,
213+
4,
214+
4,
215+
4,
216+
4,
217+
4,
218+
5,
219+
5,
220+
5,
221+
5,
222+
5,
223+
5,
224+
5,
225+
5,
226+
6,
227+
6,
228+
6,
229+
6,
230+
]
231+
)
232+
x = x[None]
233+
zonal = healpix.zonal_average(x)
234+
assert zonal.shape == (1, 7)
235+
assert np.all(zonal == np.arange(7))
236+
237+
238+
def test_to_double_pixelization(regtest):
239+
n = 2
240+
x = np.arange(12 * n * n)
241+
x = healpix.to_double_pixelization(x)
242+
assert x.dtype == x.dtype
243+
np.savetxt(regtest, x, fmt="%d")

0 commit comments

Comments
 (0)