Skip to content

Commit 747d3d2

Browse files
benbovydcherian
andauthored
Add NDPointIndex examples (#25)
Co-authored-by: Deepak Cherian <deepak@cherian.net>
1 parent 6045101 commit 747d3d2

File tree

3 files changed

+193
-1
lines changed

3 files changed

+193
-1
lines changed

docs/blocks/ndpoint.md

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,190 @@
1-
# Tree-based indexes with `NDPointIndex`
1+
---
2+
jupytext:
3+
text_representation:
4+
format_name: myst
5+
kernelspec:
6+
display_name: Python 3
7+
name: python
8+
---
9+
10+
# Nearest neighbors with `NDPointIndex`
11+
12+
## Highlights
13+
14+
1. {py:class}`xarray.indexes.NDPointIndex` is useful for dealing with
15+
n-dimensional coordinate variables representing irregular data.
16+
1. It enables point-wise (nearest-neighbors) data selection using Xarray's
17+
[advanced indexing](https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing)
18+
capabilities.
19+
1. By default, a {py:class}`scipy.spatial.KDTree` is used under the hood for
20+
fast lookup of point data. Although experimental, it is possible to plug in
21+
alternative structures to `NDPointIndex` (See {ref}`advanced`).
22+
23+
## Basic Example: Default KDTree
24+
25+
Let's create a dataset with random points.
26+
27+
```{code-cell} python
28+
import numpy as np
29+
import matplotlib.pyplot as plt
30+
import xarray as xr
31+
```
32+
33+
```{code-cell} python
34+
---
35+
tags: [remove-cell]
36+
---
37+
%xmode minimal
38+
xr.set_options(
39+
display_expand_indexes=True,
40+
display_expand_data=False,
41+
);
42+
```
43+
44+
```{code-cell} python
45+
shape = (5, 10)
46+
xx = xr.DataArray(np.random.uniform(0, 10, size=shape), dims=("y", "x"))
47+
yy = xr.DataArray(np.random.uniform(0, 5, size=shape), dims=("y", "x"))
48+
data = (xx - 5)**2 + (yy - 2.5)**2
49+
50+
ds = xr.Dataset(data_vars={"data": data}, coords={"xx": xx, "yy": yy})
51+
ds
52+
```
53+
54+
```{code-cell} python
55+
ds.plot.scatter(x="xx", y="yy", hue="data");
56+
```
57+
58+
### Assigning
59+
60+
```{code-cell} python
61+
ds_index = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex)
62+
ds_index
63+
```
64+
65+
### Point-wise indexing
66+
67+
Select one value.
68+
69+
```{code-cell} python
70+
ds_index.sel(xx=3.4, yy=4.2, method="nearest")
71+
```
72+
73+
Select multiple points in the `x`/`y` dimension space, using
74+
{py:class}`xarray.DataArray` objects as input labels.
75+
76+
```{code-cell} python
77+
# create a regular grid as query points
78+
ds_grid = xr.Dataset(coords={"x": range(10), "y": range(5)})
79+
80+
# selection supports broadcasting of the input labels
81+
ds_selection = ds_index.sel(
82+
xx=ds_grid.x, yy=ds_grid.y, method="nearest"
83+
)
84+
85+
# assign selection results to the grid
86+
# -> nearest neighbor interpolation
87+
ds_grid["data"] = ds_selection.data.variable
88+
89+
ds_grid
90+
```
91+
92+
```{code-cell} python
93+
ds_grid.data.plot(x="x", y="y")
94+
ds.plot.scatter(x="xx", y="yy", c="k")
95+
plt.show()
96+
```
97+
98+
(advanced)=
99+
100+
## Advanced example
101+
102+
This example is based on the Regional Ocean Modeling System (ROMS) [Xarray
103+
example](https://docs.xarray.dev/en/stable/examples/ROMS_ocean_model.html).
104+
105+
```{code-cell} python
106+
ds_roms = xr.tutorial.open_dataset("ROMS_example")
107+
ds_roms
108+
```
109+
110+
The dataset above is represented on a curvilinear grid with 2-dimensional
111+
`lat_rho` and `lon_rho` coordinate variables (in degrees). We will illustrate sampling a
112+
straight line trajectory through this field.
113+
114+
```{code-cell} python
115+
import matplotlib.pyplot as plt
116+
117+
ds_trajectory = xr.Dataset(
118+
coords={
119+
"lat": ('trajectory', np.linspace(28, 30, 50)),
120+
"lon": ('trajectory', np.linspace(-93, -88, 50)),
121+
},
122+
)
123+
124+
ds_roms.salt.isel(s_rho=-1, ocean_time=0).plot(x="lon_rho", y="lat_rho")
125+
plt.plot(
126+
ds_trajectory.lon.data, ds_trajectory.lat.data, marker='.', color='k', ms=4, ls="none",
127+
)
128+
plt.show()
129+
```
130+
131+
The default kd-tree structure used by {py:class}`~xarray.indexes.NDPointIndex`
132+
isn't best suited for these latitude and longitude coordinates. Fortunately, there
133+
is a way of using alternative structures. Here let's use {py:class}`sklearn.neighbors.BallTree`,
134+
which supports providing distance metrics such as `haversine` that will better
135+
work with latitude and longitude data.
136+
137+
```{code-cell} python
138+
from sklearn.neighbors import BallTree
139+
from xarray.indexes.nd_point_index import TreeAdapter
140+
141+
142+
class SklearnGeoBallTreeAdapter(TreeAdapter):
143+
144+
def __init__(self, points: np.ndarray, options: dict):
145+
options.update({'metric': 'haversine'})
146+
self._balltree = BallTree(np.deg2rad(points), **options)
147+
148+
def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
149+
return self._balltree.query(np.deg2rad(points))
150+
151+
def equals(self, other: "SklearnGeoBallTreeAdapter") -> bool:
152+
return np.array_equal(self._balltree.data, other._kdtree.data)
153+
```
154+
155+
```{note}
156+
Using alternative structures via custom {py:class}`~xarray.indexes.TreeAdapter` subclasses is an
157+
experimental feature!
158+
159+
The adapter above based on {py:class}`sklearn.neighbors.BallTree` will
160+
eventually be available in the [xoak](https://github.com/xarray-contrib/xoak)
161+
package along with other useful adapters.
162+
```
163+
164+
### Assigning
165+
166+
```{code-cell} python
167+
ds_roms_index = ds_roms.set_xindex(
168+
("lat_rho", "lon_rho"),
169+
xr.indexes.NDPointIndex,
170+
tree_adapter_cls=SklearnGeoBallTreeAdapter,
171+
)
172+
ds_roms_index
173+
```
174+
175+
### Indexing
176+
177+
```{code-cell} python
178+
ds_roms_selection = ds_roms_index.sel(
179+
lat_rho=ds_trajectory.lat,
180+
lon_rho=ds_trajectory.lon,
181+
method="nearest",
182+
)
183+
ds_roms_selection
184+
```
185+
186+
```{code-cell} python
187+
plt.figure()
188+
ds_roms_selection.plot.scatter(x="lat_rho", y="lat_rho", hue="zeta")
189+
plt.show()
190+
```

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
"python": ("https://docs.python.org/3/", None),
117117
"pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
118118
"numpy": ("https://numpy.org/doc/stable", None),
119+
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
119120
"xarray": ("https://docs.xarray.dev/en/latest/", None),
120121
"rasterix": ("https://rasterix.readthedocs.io/en/latest/", None),
121122
"shapely": ("https://shapely.readthedocs.io/en/latest/", None),
@@ -124,5 +125,6 @@
124125
"geopandas": ("https://geopandas.org/en/stable/", None),
125126
"pint-xarray": ("https://pint-xarray.readthedocs.io/en/latest/", None),
126127
"pint": ("https://pint.readthedocs.io/en/stable/", None),
128+
"sklearn": ("https://scikit-learn.org/stable/", None),
127129
"astropy": ("https://docs.astropy.org/en/latest/", None),
128130
}

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ pint-xarray
2323
cf_xarray
2424
astropy
2525
git+https://github.com/pydata/xarray
26+
scikit-learn

0 commit comments

Comments
 (0)