diff --git a/xdggs/accessor.py b/xdggs/accessor.py index 1852c981..65a8c2a1 100644 --- a/xdggs/accessor.py +++ b/xdggs/accessor.py @@ -1,7 +1,12 @@ +from collections.abc import Callable + +import numpy as np import numpy.typing as npt import xarray as xr from xdggs.grid import DGGSInfo +from xdggs.healpix import HealpixInfo +from xdggs.healpix import downscale as healpix_downscale from xdggs.index import DGGSIndex from xdggs.plotting import explore @@ -209,3 +214,87 @@ def explore(self, *, cmap="viridis", center=None, alpha=None, coords=None): alpha=alpha, coords=coords, ) + + def downscale(self, level: int, agg: Callable = None): + """Aggregate data to a lower grid level. + + Parameters + ---------- + level : int, optional + The target level of the grid you want to group towards. This is the level of the resulting data. + agg : callable, default: np.mean + The aggregation function to use. This function must accept a 1D array and return a scalar value. + + Returns + ------- + xarray.Dataset or xarray.DataArray + The downscaled data. + """ + if agg is None: + agg = np.mean + + assert_valid_level(level) + + if self.grid_info.level < level: + raise ValueError( + f"Can't downscale to level {level} from data on level {self.grid_info.level}. Did you mean upscale?" + ) + + offset = self.grid_info.level - level + + if not isinstance(self.grid_info, HealpixInfo): + raise ValueError( + "Downscaling is currently only supported for Healpix grids." + ) + + return healpix_downscale( + self._obj, offset=offset, agg=agg, grid_info=self.grid_info + ) + + def upscale(self, level: int): + if not isinstance(level, int): + raise ValueError( + f"Expected level to be of type {{int}}. Got {type(level).__name__}" + ) + + if self.grid_info.level > level: + raise ValueError( + f"Can't upscale to level {level} from data on level {self.grid_info.level}. Did you mean downscale?" + ) + + offset = level - self.grid_info.level # noqa + + if not isinstance(self.grid_info, HealpixInfo): + raise ValueError("Upscaling is currently only supported for Healpix grids.") + + raise NotImplementedError() + + def rescale(self, level: int, downscale_agg: Callable | None = None): + """Rescale the data to a different grid level by either upscaling or downscaling. + + Parameters + ---------- + level : int + The target level of the grid you want to group towards. This is the level of the resulting data. + downscale_agg : callable, default: np.mean + The aggregation function to use if downscaling. This function must accept a 1D array and return a scalar value. + + Returns + ------- + xarray.Dataset or xarray.DataArray + The rescaled data. + """ + assert_valid_level(level) + + if self.grid_info.level < level: + return self.upscale(level) + else: + return self.downscale(level, agg=downscale_agg) + + +def assert_valid_level(level: int) -> None: + if not isinstance(level, int): + raise ValueError(f"level must be an integer, got {type(level).__name__}") + + if level < 0: + raise ValueError(f"level must be a non-negative integer, got {level}") diff --git a/xdggs/healpix.py b/xdggs/healpix.py index 2da42bf5..9874c7be 100644 --- a/xdggs/healpix.py +++ b/xdggs/healpix.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import json from collections.abc import Mapping from dataclasses import dataclass from typing import Any, ClassVar, Literal, TypeVar try: + from collections.abc import Callable from typing import Self except ImportError: # pragma: no cover from typing_extensions import Self @@ -88,6 +91,26 @@ def center_around_prime_meridian(lon, lat): return result +def downscale( + obj: xr.DataArray | xr.Dataset, + *, + offset: int, + agg: Callable, + grid_info: HealpixInfo, +): + if not grid_info.nest: + raise NotImplementedError( + "Downscaling is only supported for nested Healpix grids." + ) + + if offset == 0: + return obj + + upper_cell_membership = np.floor(obj.cell_ids / (4**offset)) + + return obj.groupby(upper_cell_membership).reduce(agg) + + @dataclass(frozen=True) class HealpixInfo(DGGSInfo): """ @@ -324,11 +347,11 @@ def __init__( @classmethod def from_variables( - cls: type["HealpixIndex"], + cls: type[HealpixIndex], variables: Mapping[Any, xr.Variable], *, options: Mapping[str, Any], - ) -> "HealpixIndex": + ) -> HealpixIndex: _, var, dim = _extract_cell_id_variable(variables) grid_info = HealpixInfo.from_dict(var.attrs | options)