Skip to content

Add .downscale(), .upscale(), and .rescale() methods #141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions xdggs/accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from collections.abc import Callable

import numpy as np
import numpy.typing as npt
import xarray as xr

from xdggs.grid import DGGSInfo
from xdggs.healpix import HealpixInfo
from xdggs.healpix import downscale as healpix_downscale
from xdggs.index import DGGSIndex
from xdggs.plotting import explore

Expand Down Expand Up @@ -209,3 +214,87 @@ def explore(self, *, cmap="viridis", center=None, alpha=None, coords=None):
alpha=alpha,
coords=coords,
)

def downscale(self, level: int, agg: Callable = None):
"""Aggregate data to a lower grid level.

Parameters
----------
level : int, optional
The target level of the grid you want to group towards. This is the level of the resulting data.
agg : callable, default: np.mean
The aggregation function to use. This function must accept a 1D array and return a scalar value.

Returns
-------
xarray.Dataset or xarray.DataArray
The downscaled data.
"""
if agg is None:
agg = np.mean

assert_valid_level(level)

if self.grid_info.level < level:
raise ValueError(
f"Can't downscale to level {level} from data on level {self.grid_info.level}. Did you mean upscale?"
)

offset = self.grid_info.level - level

if not isinstance(self.grid_info, HealpixInfo):
raise ValueError(
"Downscaling is currently only supported for Healpix grids."
)

return healpix_downscale(
self._obj, offset=offset, agg=agg, grid_info=self.grid_info
)

def upscale(self, level: int):
if not isinstance(level, int):
raise ValueError(
f"Expected level to be of type {{int}}. Got {type(level).__name__}"
)

if self.grid_info.level > level:
raise ValueError(
f"Can't upscale to level {level} from data on level {self.grid_info.level}. Did you mean downscale?"
)

offset = level - self.grid_info.level # noqa

if not isinstance(self.grid_info, HealpixInfo):
raise ValueError("Upscaling is currently only supported for Healpix grids.")

raise NotImplementedError()

def rescale(self, level: int, downscale_agg: Callable | None = None):
"""Rescale the data to a different grid level by either upscaling or downscaling.

Parameters
----------
level : int
The target level of the grid you want to group towards. This is the level of the resulting data.
downscale_agg : callable, default: np.mean
The aggregation function to use if downscaling. This function must accept a 1D array and return a scalar value.

Returns
-------
xarray.Dataset or xarray.DataArray
The rescaled data.
"""
assert_valid_level(level)

if self.grid_info.level < level:
return self.upscale(level)
else:
return self.downscale(level, agg=downscale_agg)


def assert_valid_level(level: int) -> None:
if not isinstance(level, int):
raise ValueError(f"level must be an integer, got {type(level).__name__}")

if level < 0:
raise ValueError(f"level must be a non-negative integer, got {level}")
27 changes: 25 additions & 2 deletions xdggs/healpix.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import json
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, ClassVar, Literal, TypeVar

try:
from collections.abc import Callable
from typing import Self
except ImportError: # pragma: no cover
from typing_extensions import Self
Expand Down Expand Up @@ -88,6 +91,26 @@ def center_around_prime_meridian(lon, lat):
return result


def downscale(
obj: xr.DataArray | xr.Dataset,
*,
offset: int,
agg: Callable,
grid_info: HealpixInfo,
):
if not grid_info.nest:
raise NotImplementedError(
"Downscaling is only supported for nested Healpix grids."
)

if offset == 0:
return obj

upper_cell_membership = np.floor(obj.cell_ids / (4**offset))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, not sure how expensive this line is with high offsets?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there should be a constant cost (a binary right-shift), so I wouldn't worry too much about it. This implementation doesn't do that, so I'd recommend switching to healpix_geo.nested.zoom_to (nothing for ring, as hierarchy operations don't really make sense there)


return obj.groupby(upper_cell_membership).reduce(agg)


@dataclass(frozen=True)
class HealpixInfo(DGGSInfo):
"""
Expand Down Expand Up @@ -324,11 +347,11 @@ def __init__(

@classmethod
def from_variables(
cls: type["HealpixIndex"],
cls: type[HealpixIndex],
variables: Mapping[Any, xr.Variable],
*,
options: Mapping[str, Any],
) -> "HealpixIndex":
) -> HealpixIndex:
_, var, dim = _extract_cell_id_variable(variables)

grid_info = HealpixInfo.from_dict(var.attrs | options)
Expand Down