Skip to content

Commit 3616f3a

Browse files
committed
typing
1 parent c657a3f commit 3616f3a

File tree

3 files changed

+99
-31
lines changed

3 files changed

+99
-31
lines changed

xdggs/accessor.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,64 @@
1+
import numpy.typing as npt
12
import xarray as xr
23

34
from xdggs.index import DGGSIndex
45

56

67
@xr.register_dataset_accessor("dggs")
8+
@xr.register_dataarray_accessor("dggs")
79
class DGGSAccessor:
8-
def __init__(self, obj):
10+
_obj: xr.Dataset | xr.DataArray
11+
_index: DGGSIndex | None
12+
_name: str
13+
14+
def __init__(self, obj: xr.Dataset | xr.DataArray):
915
self._obj = obj
1016

11-
indexes = {k: idx for k, idx in obj.xindexes.items() if isinstance(idx, DGGSIndex)}
12-
if len(indexes) > 1:
13-
raise ValueError("Only one DGGSIndex per object is supported")
17+
index = None
18+
name = ""
19+
for k, idx in obj.xindexes.items():
20+
if isinstance(idx, DGGSIndex):
21+
if index is not None:
22+
raise ValueError("Only one DGGSIndex per object is supported")
23+
index = idx
24+
name = k
25+
self._name = name
26+
self._index = index
27+
28+
@property
29+
def index(self) -> DGGSIndex:
30+
if self._index is None:
31+
raise ValueError("no DGGSIndex found on this Dataset or DataArray")
32+
return self._index
33+
34+
def sel_latlon(
35+
self, latitude: npt.ArrayLike, longitude: npt.ArrayLike
36+
) -> xr.Dataset | xr.DataArray:
37+
"""Select grid cells from latitude/longitude data.
1438
15-
self._name, self._index = next(iter(indexes.items()))
39+
Parameters
40+
----------
41+
latitude : array-like
42+
Latitude coordinates (degrees).
43+
longitude : array-like
44+
Longitude coordinates (degrees).
1645
17-
def sel_latlon(self, lat, lon):
18-
"""Point-wise, nearest-neighbor selection from lat/lon data."""
46+
Returns
47+
-------
48+
subset
49+
A new :py:class:`xarray.Dataset` or :py:class:`xarray.DataArray`
50+
with all cells that contain the input latitude/longitude data points.
1951
20-
cell_indexers = {self._name: self._index._latlon2cellid(lat, lon)}
52+
"""
53+
cell_indexers = {self._name: self.index._latlon2cellid(latitude, longitude)}
2154
return self._obj.sel(cell_indexers)
2255

23-
def assign_latlon_coords(self):
24-
"""Return a new object with latitude and longitude coordinates
25-
of the cell centers."""
56+
def assign_latlon_coords(self) -> xr.Dataset | xr.DataArray:
57+
"""Return a new Dataset or DataArray with new "latitude" and "longitude"
58+
coordinates representing the grid cell centers."""
2659

27-
lat_data, lon_data = self._index.cell_centers
60+
lat_data, lon_data = self.index.cell_centers
2861
return self._obj.assign_coords(
29-
latitude=(self._index._dim, lat_data),
30-
longitude=(self._index._dim, lon_data),
62+
latitude=(self.index._dim, lat_data),
63+
longitude=(self.index._dim, lon_data),
3164
)

xdggs/healpix.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,57 @@
1+
from collections.abc import Mapping
2+
from typing import Any
3+
14
import healpy
5+
import numpy as np
6+
import xarray as xr
7+
from xarray.indexes import PandasIndex
28

39
from xdggs.index import DGGSIndex
410
from xdggs.utils import _extract_cell_id_variable, register_dggs
511

612

713
@register_dggs("healpix")
814
class HealpixIndex(DGGSIndex):
9-
def __init__(self, cell_ids, dim, nside, nest, rot_latlon):
15+
def __init__(
16+
self,
17+
cell_ids: Any | PandasIndex,
18+
dim: str,
19+
nside: int,
20+
nest: bool,
21+
rot_latlon: tuple[float, float],
22+
):
1023
super().__init__(cell_ids, dim)
1124

1225
self._nside = nside
1326
self._nest = nest
1427
self._rot_latlon = rot_latlon
1528

1629
@classmethod
17-
def from_variables(cls, variables, *, options):
18-
name, var, dim = _extract_cell_id_variable(variables)
30+
def from_variables(
31+
cls: type["HealpixIndex"],
32+
variables: Mapping[Any, xr.Variable],
33+
*,
34+
options: Mapping[str, Any],
35+
) -> "HealpixIndex":
36+
_, var, dim = _extract_cell_id_variable(variables)
1937

2038
nside = var.attrs.get("nside", options.get("nside"))
2139
nest = var.attrs.get("nest", options.get("nest", False))
2240
rot_latlon = var.attrs.get("rot_latlon", options.get("rot_latlon", (0.0, 0.0)))
2341

2442
return cls(var.data, dim, nside, nest, rot_latlon)
2543

26-
def _replace(self, new_pd_index):
44+
def _replace(self, new_pd_index: PandasIndex):
2745
return type(self)(new_pd_index, self._dim, self._nside, self._nest, self._rot_latlon)
2846

29-
def _latlon2cellid(self, lat, lon):
47+
def _latlon2cellid(self, lat: Any, lon: Any) -> np.ndarray:
3048
return healpy.ang2pix(self._nside, -lon, lat, lonlat=True, nest=self._nest)
3149

32-
def _cellid2latlon(self, cell_ids):
50+
def _cellid2latlon(self, cell_ids: Any) -> tuple[np.ndarray, np.ndarray]:
3351
lon, lat = healpy.pix2ang(self._nside, cell_ids, nest=self._nest, lonlat=True)
3452
return lat, -lon
3553

36-
def _repr_inline_(self, max_width):
54+
def _repr_inline_(self, max_width: int):
3755
return (
3856
f"HealpixIndex(nside={self._nside}, nest={self._nest}, rot_latlon={self._rot_latlon!r})"
3957
)

xdggs/index.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1+
from collections.abc import Hashable, Mapping
2+
from typing import Any, Union
3+
4+
import numpy as np
5+
import xarray as xr
16
from xarray.indexes import Index, PandasIndex
27

38
from .utils import GRID_REGISTRY, _extract_cell_id_variable
49

510

611
class DGGSIndex(Index):
7-
def __init__(self, cell_ids, dim):
12+
_dim: str
13+
_pd_index: PandasIndex
14+
15+
def __init__(self, cell_ids: Any | PandasIndex, dim: str):
816
self._dim = dim
917

1018
if isinstance(cell_ids, PandasIndex):
@@ -13,18 +21,27 @@ def __init__(self, cell_ids, dim):
1321
self._pd_index = PandasIndex(cell_ids, dim)
1422

1523
@classmethod
16-
def from_variables(cls, variables, *, options):
17-
name, var, dim = _extract_cell_id_variable(variables)
24+
def from_variables(
25+
cls: type["DGGSIndex"],
26+
variables: Mapping[Any, xr.Variable],
27+
*,
28+
options: Mapping[str, Any],
29+
) -> "DGGSIndex":
30+
_, var, _ = _extract_cell_id_variable(variables)
1831

1932
grid_name = var.attrs["grid_name"]
2033
cls = GRID_REGISTRY[grid_name]
2134

2235
return cls.from_variables(variables, options=options)
2336

24-
def create_variables(self, variables=None):
37+
def create_variables(
38+
self, variables: Mapping[Any, xr.Variable] | None = None
39+
) -> dict[Hashable, xr.Variable]:
2540
return self._pd_index.create_variables(variables)
2641

27-
def isel(self, indexers):
42+
def isel(
43+
self: "DGGSIndex", indexers: Mapping[Any, int | np.ndarray | xr.Variable]
44+
) -> Union["DGGSIndex", None]:
2845
new_pd_index = self._pd_index.isel(indexers)
2946
if new_pd_index is not None:
3047
return self._replace(new_pd_index)
@@ -36,17 +53,17 @@ def sel(self, labels, method=None, tolerance=None):
3653
raise ValueError("finding nearest grid cell has no meaning")
3754
return self._pd_index.sel(labels, method=method, tolerance=tolerance)
3855

39-
def _replace(self, new_pd_index):
56+
def _replace(self, new_pd_index: PandasIndex):
4057
raise NotImplementedError()
4158

42-
def _latlon2cellid(self, lat, lon):
59+
def _latlon2cellid(self, lat: Any, lon: Any) -> np.ndarray:
4360
"""convert latitude / longitude points to cell ids."""
4461
raise NotImplementedError()
4562

46-
def _cellid2latlon(self, cell_ids):
47-
"""convert cell centers to latitude / longitude."""
63+
def _cellid2latlon(self, cell_ids: Any) -> tuple[np.ndarray, np.ndarray]:
64+
"""convert cell ids to latitude / longitude (cell centers)."""
4865
raise NotImplementedError()
4966

5067
@property
51-
def cell_centers(self):
68+
def cell_centers(self) -> tuple[np.ndarray, np.ndarray]:
5269
return self._cellid2latlon(self._pd_index.index.values)

0 commit comments

Comments
 (0)