|
37 | 37 | from dataclasses import dataclass
|
38 | 38 | from enum import Enum
|
39 | 39 | from functools import lru_cache
|
40 |
| -from typing import Union |
| 40 | +from typing import TypeVar, Union |
41 | 41 |
|
42 | 42 | import einops
|
43 | 43 | import numpy as np
|
|
72 | 72 | __all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder", "ang2pix"]
|
73 | 73 |
|
74 | 74 |
|
| 75 | +def _get_array_library(x): |
| 76 | + if isinstance(x, np.ndarray): |
| 77 | + return np |
| 78 | + else: |
| 79 | + return torch |
| 80 | + |
| 81 | + |
| 82 | +ArrayT = TypeVar("ArrayT", np.ndarray, torch.Tensor) |
| 83 | + |
| 84 | + |
75 | 85 | def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
|
76 | 86 | """
|
77 | 87 | Pad each face consistently with its according neighbors in the HEALPix
|
@@ -412,20 +422,7 @@ def to_image(self, x: torch.Tensor, fill_value=torch.nan) -> torch.Tensor:
|
412 | 422 | """Use the 45 degree rotated grid pixelation
|
413 | 423 | i points to SE, j point to NE
|
414 | 424 | """
|
415 |
| - grid = [[6, 9, -1, -1, -1], [1, 5, 8, -1, -1], [-1, 0, 4, 11, -1], [-1, -1, 3, 7, 10], [-1, -1, -1, 2, 6]] |
416 |
| - pixel_order = XY(origin=Compass.W, clockwise=True) |
417 |
| - x = self.reorder(pixel_order, x) |
418 |
| - nside = self._nside() |
419 |
| - *shape, _ = x.shape |
420 |
| - x = x.reshape((*shape, 12, nside, nside)) |
421 |
| - output = torch.full((*shape, 5 * nside, 5 * nside), device=x.device, dtype=x.dtype, fill_value=fill_value) |
422 |
| - |
423 |
| - for j in range(len(grid)): |
424 |
| - for i in range(len(grid[0])): |
425 |
| - face = grid[j][i] |
426 |
| - if face != -1: |
427 |
| - output[j * nside : (j + 1) * nside, i * nside : (i + 1) * nside] = x[face] |
428 |
| - return output |
| 425 | + return to_rotated_pixelization(x, fill_value) |
429 | 426 |
|
430 | 427 |
|
431 | 428 | class HEALPixPadFunction(torch.autograd.Function):
|
@@ -602,3 +599,158 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
602 | 599 | weight = weight.unsqueeze(-3)
|
603 | 600 | out = torch.nn.functional.conv3d(input, weight, bias, stride, padding, dilation, groups)
|
604 | 601 | return einops.rearrange(out, "n c f x y -> n c () (f x y)")
|
| 602 | + |
| 603 | + |
| 604 | +def _pixels_to_rings(nside: int, p: ArrayT) -> ArrayT: |
| 605 | + """Get the ring number of a pixel ``i`` in RING order""" |
| 606 | + # See eq (2-5) of Gorski |
| 607 | + xp = _get_array_library(p) |
| 608 | + npix = 12 * nside * nside |
| 609 | + ncap = 2 * nside * (nside - 1) |
| 610 | + |
| 611 | + i_north = xp.floor(0.5 * (1 + np.sqrt(1 + 2 * p))) |
| 612 | + j_north = p - 2 * (i_north - 1) * i_north |
| 613 | + |
| 614 | + p_eq = p - ncap |
| 615 | + i_eq = xp.floor(p_eq / (4 * nside)) + nside - 1 |
| 616 | + j_eq = p_eq % (4 * nside) |
| 617 | + |
| 618 | + p_south = npix - p - 1 |
| 619 | + i_south = xp.floor(0.5 * (1 + np.sqrt(1 + 2 * p_south))) |
| 620 | + j_south = p_south - 2 * (i_south - 1) * i_south |
| 621 | + length_south = i_south * 4 |
| 622 | + |
| 623 | + i = i_north - 1 |
| 624 | + i = xp.where(p >= ncap, i_eq, i) |
| 625 | + i = xp.where(p >= (npix - ncap), 4 * nside - i_south - 1, i) |
| 626 | + |
| 627 | + j = j_north |
| 628 | + j = xp.where(p >= ncap, j_eq, j) |
| 629 | + j = xp.where(p >= (npix - ncap), length_south - 1 - j_south, j) |
| 630 | + |
| 631 | + return i.astype(int), j.astype(int) |
| 632 | + |
| 633 | + |
| 634 | +def ring_length(nside: int, i: ArrayT) -> ArrayT: |
| 635 | + """The number of pixels in ring 0 <= i < 4 * nside - 2""" |
| 636 | + xp = _get_array_library(i) |
| 637 | + |
| 638 | + length_north = 4 * (i + 1) |
| 639 | + length_eq = 4 * nside |
| 640 | + length_south = (4 * nside - i - 1) * 4 |
| 641 | + |
| 642 | + length = length_north |
| 643 | + # test i =1, nside = 1 |
| 644 | + length = xp.where(i >= nside, length_eq, length) |
| 645 | + # test: i = 2, nside=1, should have len 4 |
| 646 | + length = xp.where(i >= nside * 3 - 1, length_south, length) |
| 647 | + return length |
| 648 | + |
| 649 | + |
| 650 | +def ring2double(nside: int, p: ArrayT): |
| 651 | + """Compute the (i,j) index in the double pixelization scheme of Calabretta (2007) |
| 652 | +
|
| 653 | + This is a visually appealing way to visualize healpix data without any |
| 654 | + interpolation. |
| 655 | +
|
| 656 | + See Fig 5 |
| 657 | +
|
| 658 | + Calabretta, M. R., & Roukema, B. F. (2007). Mapping on the HEALPix grid. Monthly Notices of the Royal Astronomical Society, 381(2), 865–872. https://doi.org/10.1111/j.1365-2966.2007.12297.x |
| 659 | +
|
| 660 | + """ |
| 661 | + xp = _get_array_library(p) |
| 662 | + n = nside |
| 663 | + i, j = _pixels_to_rings(n, p) |
| 664 | + n_per_pyramid = ring_length(n, i) // 4 |
| 665 | + |
| 666 | + pyramid = j // n_per_pyramid |
| 667 | + left = n - i |
| 668 | + jp_north = 2 * pyramid * n + left + 2 * (j % n_per_pyramid) |
| 669 | + jp_eq = (n - i) % 2 + 2 * j |
| 670 | + |
| 671 | + left = i - 3 * n + 2 |
| 672 | + jp_south = 2 * pyramid * n + left + 2 * (j % n_per_pyramid) |
| 673 | + |
| 674 | + jp = xp.where(i >= n, jp_eq, jp_north) |
| 675 | + jp = xp.where(i >= 3 * n, jp_south, jp) |
| 676 | + |
| 677 | + return i, jp |
| 678 | + |
| 679 | + |
| 680 | +def to_rotated_pixelization(x, fill_value=math.nan): |
| 681 | + """Convert an array to a 2D-iamge w/ the rotated pixelization""" |
| 682 | + |
| 683 | + numpy_out = False |
| 684 | + if isinstance(x, np.ndarray): |
| 685 | + numpy_out = True |
| 686 | + x = torch.from_numpy(x) |
| 687 | + |
| 688 | + grid = [[6, 9, -1, -1, -1], [1, 5, 8, -1, -1], [-1, 0, 4, 11, -1], [-1, -1, 3, 7, 10], [-1, -1, -1, 2, 6]] |
| 689 | + pixel_order = XY(origin=Compass.W, clockwise=True) |
| 690 | + self = Grid(npix2level(x.shape[-1])) |
| 691 | + x = self.reorder(pixel_order, x) |
| 692 | + nside = self._nside() |
| 693 | + *shape, _ = x.shape |
| 694 | + x = x.reshape((*shape, 12, nside, nside)) |
| 695 | + output = torch.full((*shape, 5 * nside, 5 * nside), device=x.device, dtype=x.dtype, fill_value=fill_value) |
| 696 | + |
| 697 | + for j in range(len(grid)): |
| 698 | + for i in range(len(grid[0])): |
| 699 | + face = grid[j][i] |
| 700 | + if face != -1: |
| 701 | + output[j * nside : (j + 1) * nside, i * nside : (i + 1) * nside] = x[face] |
| 702 | + |
| 703 | + if numpy_out: |
| 704 | + return output.numpy() |
| 705 | + else: |
| 706 | + return output |
| 707 | + |
| 708 | + |
| 709 | +def to_double_pixelization(x: ArrayT, fill_value=0) -> ArrayT: |
| 710 | + """Convert the array x to 2D-image w/ the double pixelization |
| 711 | +
|
| 712 | + ``x`` must be in RING pixel order |
| 713 | +
|
| 714 | + """ |
| 715 | + xp = _get_array_library(x) |
| 716 | + |
| 717 | + n = npix2nside(x.shape[-1]) |
| 718 | + i, jp = ring2double(n, np.arange(12 * n * n)) |
| 719 | + out = xp.zeros_like(x, shape=x.shape[:-1] + (4 * n, 8 * n + 1), dtype=xp.float32) |
| 720 | + num = xp.zeros_like(out, dtype=xp.int32) |
| 721 | + out[i, jp] = x |
| 722 | + num[i, jp] += 1 |
| 723 | + |
| 724 | + out[i, jp + 1] = x |
| 725 | + num[i, jp + 1] += 1 |
| 726 | + |
| 727 | + out[i, jp - 1] += x |
| 728 | + num[i, jp - 1] += 1 |
| 729 | + out[num == 0] = fill_value |
| 730 | + num[num == 0] = 1 |
| 731 | + out /= num |
| 732 | + return out |
| 733 | + |
| 734 | + |
| 735 | +def zonal_average(x: ArrayT) -> ArrayT: |
| 736 | + """Compute the zonal average of a map in ring format""" |
| 737 | + xp = _get_array_library(x) |
| 738 | + if x.ndim != 2: |
| 739 | + raise ValueError() |
| 740 | + |
| 741 | + npix = x.shape[-1] |
| 742 | + nside = npix2nside(npix) |
| 743 | + |
| 744 | + iring, _ = _pixels_to_rings(nside, np.arange(npix)) |
| 745 | + nring = iring.max() + 1 |
| 746 | + if isinstance(x, np.ndarray): |
| 747 | + batch = np.arange(x.shape[0]) |
| 748 | + else: |
| 749 | + batch = torch.arange(x.shape[0], device=x.device) |
| 750 | + |
| 751 | + i_flat = batch[:, None] * nring + iring |
| 752 | + i_flat = i_flat.ravel() |
| 753 | + num = xp.bincount(i_flat, weights=x.ravel(), minlength=nring * x.shape[0]) |
| 754 | + denom = xp.bincount(i_flat, minlength=nring * x.shape[0]) |
| 755 | + average = num / denom |
| 756 | + return average.reshape(x.shape[0], nring) |
0 commit comments