@@ -608,15 +608,15 @@ def _pixels_to_rings(nside: int, p: ArrayT) -> ArrayT:
608
608
npix = 12 * nside * nside
609
609
ncap = 2 * nside * (nside - 1 )
610
610
611
- i_north = xp .floor (0.5 * (1 + np .sqrt (1 + 2 * p )))
611
+ i_north = xp .floor (0.5 * (1 + xp .sqrt (1 + 2 * p )))
612
612
j_north = p - 2 * (i_north - 1 ) * i_north
613
613
614
614
p_eq = p - ncap
615
615
i_eq = xp .floor (p_eq / (4 * nside )) + nside - 1
616
616
j_eq = p_eq % (4 * nside )
617
617
618
618
p_south = npix - p - 1
619
- i_south = xp .floor (0.5 * (1 + np .sqrt (1 + 2 * p_south )))
619
+ i_south = xp .floor (0.5 * (1 + xp .sqrt (1 + 2 * p_south )))
620
620
j_south = p_south - 2 * (i_south - 1 ) * i_south
621
621
length_south = i_south * 4
622
622
@@ -628,7 +628,7 @@ def _pixels_to_rings(nside: int, p: ArrayT) -> ArrayT:
628
628
j = xp .where (p >= ncap , j_eq , j )
629
629
j = xp .where (p >= (npix - ncap ), length_south - 1 - j_south , j )
630
630
631
- return i . astype ( int ), j . astype ( int )
631
+ return _to_int ( i ), _to_int ( j )
632
632
633
633
634
634
def ring_length (nside : int , i : ArrayT ) -> ArrayT :
@@ -706,18 +706,45 @@ def to_rotated_pixelization(x, fill_value=math.nan):
706
706
return output
707
707
708
708
709
+ def _arange_like (n , like ):
710
+ if isinstance (like , np .ndarray ):
711
+ batch = np .arange (n )
712
+ else :
713
+ batch = torch .arange (n , device = like .device )
714
+ return batch
715
+
716
+
717
+ def _to_int (x ):
718
+ if isinstance (x , np .ndarray ):
719
+ return x .astype (int )
720
+ else :
721
+ return x .int ()
722
+
723
+
724
+ def _zeros_like (x , shape = None , dtype = None ):
725
+ if isinstance (x , np .ndarray ):
726
+ return np .zeros_like (x , shape = shape , dtype = dtype )
727
+ else :
728
+ return torch .zeros (shape or x .shape , dtype = dtype , device = x .device )
729
+
730
+
709
731
def to_double_pixelization (x : ArrayT , fill_value = 0 ) -> ArrayT :
710
732
"""Convert the array x to 2D-image w/ the double pixelization
711
733
712
734
``x`` must be in RING pixel order
713
735
714
736
"""
715
737
xp = _get_array_library (x )
738
+ dtype = xp .float32
716
739
717
740
n = npix2nside (x .shape [- 1 ])
718
- i , jp = ring2double (n , np .arange (12 * n * n ))
719
- out = xp .zeros_like (x , shape = x .shape [:- 1 ] + (4 * n , 8 * n + 1 ), dtype = xp .float32 )
720
- num = xp .zeros_like (out , dtype = xp .int32 )
741
+ i , jp = ring2double (n , _arange_like (12 * n * n , x ))
742
+ out = _zeros_like (x , shape = x .shape [:- 1 ] + (4 * n , 8 * n + 1 ), dtype = dtype )
743
+ num = _zeros_like (out , dtype = xp .int32 )
744
+
745
+ if torch .is_tensor (x ):
746
+ x = x .to (out )
747
+
721
748
out [i , jp ] = x
722
749
num [i , jp ] += 1
723
750
@@ -732,25 +759,26 @@ def to_double_pixelization(x: ArrayT, fill_value=0) -> ArrayT:
732
759
return out
733
760
734
761
735
- def zonal_average (x : ArrayT ) -> ArrayT :
762
+ def zonal_average (x : ArrayT , dim = - 1 ) -> ArrayT :
736
763
"""Compute the zonal average of a map in ring format"""
737
764
xp = _get_array_library (x )
738
- if x .ndim != 2 :
739
- raise ValueError ()
765
+
766
+ dim = dim % x .ndim
767
+ shape = [x .shape [i ] for i in range (x .ndim ) if i != dim ]
768
+ x = xp .moveaxis (x , dim , - 1 )
769
+ x = x .reshape ([- 1 , x .shape [- 1 ]])
740
770
741
771
npix = x .shape [- 1 ]
742
772
nside = npix2nside (npix )
743
773
744
- iring , _ = _pixels_to_rings (nside , np . arange (npix ))
774
+ iring , _ = _pixels_to_rings (nside , _arange_like (npix , like = x ))
745
775
nring = iring .max () + 1
746
- if isinstance (x , np .ndarray ):
747
- batch = np .arange (x .shape [0 ])
748
- else :
749
- batch = torch .arange (x .shape [0 ], device = x .device )
776
+ batch = _arange_like (x .shape [0 ], x )
750
777
751
778
i_flat = batch [:, None ] * nring + iring
752
779
i_flat = i_flat .ravel ()
753
780
num = xp .bincount (i_flat , weights = x .ravel (), minlength = nring * x .shape [0 ])
754
781
denom = xp .bincount (i_flat , minlength = nring * x .shape [0 ])
755
782
average = num / denom
756
- return average .reshape (x .shape [0 ], nring )
783
+ average = average .reshape ((* shape , nring )) # type: ignore
784
+ return xp .moveaxis (average , - 1 , dim )
0 commit comments