Add NDPointIndex examples #25

Merged
merged 6 commits into from
Jul 9, 2025
191 changes: 190 additions & 1 deletion docs/blocks/ndpoint.md
@@ -1 +1,190 @@
# Tree-based indexes with `NDPointIndex`
---
jupytext:
  text_representation:
    format_name: myst
kernelspec:
  display_name: Python 3
  name: python
---

# Nearest neighbors with `NDPointIndex`

## Highlights

1. {py:class}`xarray.indexes.NDPointIndex` is useful for dealing with
   n-dimensional coordinate variables representing irregular data.
1. It enables point-wise (nearest-neighbors) data selection using Xarray's
   [advanced indexing](https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing)
   capabilities.
1. By default, a {py:class}`scipy.spatial.KDTree` is used under the hood for
   fast lookup of point data. It is also possible to plug alternative
   structures into `NDPointIndex`, although this feature is experimental
   (see {ref}`advanced`).

## Basic Example: Default KDTree

Let's create a dataset with random points.

```{code-cell} python
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
```

```{code-cell} python
---
tags: [remove-cell]
---
%xmode minimal
xr.set_options(
    display_expand_indexes=True,
    display_expand_data=False,
);
```

```{code-cell} python
shape = (5, 10)
xx = xr.DataArray(np.random.uniform(0, 10, size=shape), dims=("y", "x"))
yy = xr.DataArray(np.random.uniform(0, 5, size=shape), dims=("y", "x"))
data = (xx - 5)**2 + (yy - 2.5)**2

ds = xr.Dataset(data_vars={"data": data}, coords={"xx": xx, "yy": yy})
ds
```

```{code-cell} python
ds.plot.scatter(x="xx", y="yy", hue="data");
```

### Assigning

```{code-cell} python
ds_index = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex)
ds_index
```

### Point-wise indexing

Select one value.

```{code-cell} python
ds_index.sel(xx=3.4, yy=4.2, method="nearest")
```
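
To make explicit what "nearest" means here, the following is a minimal brute-force sketch in plain NumPy. It assumes Euclidean distance on the raw coordinate values and is not how `NDPointIndex` is implemented (the index uses a tree for fast lookup); the helper name `nearest_position` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Irregular 2-D coordinates with the same shape as the data they describe.
xx = rng.uniform(0, 10, size=(5, 10))
yy = rng.uniform(0, 5, size=(5, 10))


def nearest_position(xx, yy, x0, y0):
    """Return the (row, col) position of the point closest to (x0, y0)."""
    dist2 = (xx - x0) ** 2 + (yy - y0) ** 2  # squared Euclidean distance
    return np.unravel_index(np.argmin(dist2), dist2.shape)


row, col = nearest_position(xx, yy, 3.4, 4.2)
print(row, col, xx[row, col], yy[row, col])
```

A tree-based index should return the same point (ties aside) without computing the distance to every point.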

Select multiple points in the `x`/`y` dimension space, using
{py:class}`xarray.DataArray` objects as input labels.

```{code-cell} python
# create a regular grid as query points
ds_grid = xr.Dataset(coords={"x": range(10), "y": range(5)})

# selection supports broadcasting of the input labels
ds_selection = ds_index.sel(
    xx=ds_grid.x, yy=ds_grid.y, method="nearest"
)

# assign selection results to the grid
# -> nearest neighbor interpolation
ds_grid["data"] = ds_selection.data.variable

ds_grid
```
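
The broadcasting step can be pictured with plain NumPy. This is only a sketch under simplifying assumptions: Xarray broadcasts by dimension name, whereas the axis order chosen below is arbitrary, and the variable names are illustrative.

```python
import numpy as np

x = np.arange(10)  # plays the role of ds_grid.x
y = np.arange(5)   # plays the role of ds_grid.y

# Broadcast the two 1-D label arrays against each other: every (x, y)
# combination becomes one query point.
xq, yq = np.broadcast_arrays(x[:, np.newaxis], y[np.newaxis, :])

# Stack into an (n_points, 2) array, one row per query point.
query_points = np.column_stack([xq.ravel(), yq.ravel()])
print(xq.shape, query_points.shape)  # (10, 5) (50, 2)
```

Each of the 50 rows is then looked up independently, and the results are reshaped back to the broadcast shape.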

```{code-cell} python
ds_grid.data.plot(x="x", y="y")
ds.plot.scatter(x="xx", y="yy", c="k")
plt.show()
```

(advanced)=

## Advanced example

This example is based on the Regional Ocean Modeling System (ROMS) [Xarray
example](https://docs.xarray.dev/en/stable/examples/ROMS_ocean_model.html).

```{code-cell} python
ds_roms = xr.tutorial.open_dataset("ROMS_example")
ds_roms
```

The dataset above is represented on a curvilinear grid with 2-dimensional
`lat_rho` and `lon_rho` coordinate variables (in degrees). We will illustrate sampling a
straight line trajectory through this field.

```{code-cell} python
import matplotlib.pyplot as plt

ds_trajectory = xr.Dataset(
    coords={
        "lat": ("trajectory", np.linspace(28, 30, 50)),
        "lon": ("trajectory", np.linspace(-93, -88, 50)),
    },
)

ds_roms.salt.isel(s_rho=-1, ocean_time=0).plot(x="lon_rho", y="lat_rho")
plt.plot(
    ds_trajectory.lon.data, ds_trajectory.lat.data,
    marker=".", color="k", ms=4, ls="none",
)
plt.show()
```

The default kd-tree structure used by {py:class}`~xarray.indexes.NDPointIndex`
is not well suited to latitude and longitude coordinates, because it measures
Euclidean distance directly on the values in degrees. Fortunately, alternative
structures can be plugged in. Here we use {py:class}`sklearn.neighbors.BallTree`,
which accepts distance metrics such as `haversine` that work better with
latitude and longitude data.
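
To see why the metric matters, here is a small standalone sketch of the haversine formula using only the standard library. The function name and the Earth radius constant are illustrative assumptions, not part of Xarray or scikit-learn.

```python
import math


def haversine(lat1, lon1, lat2, lon2):
    """Great-circle distance in radians on the unit sphere (inputs in degrees)."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = (
        math.sin(dphi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    )
    return 2 * math.asin(math.sqrt(a))


EARTH_RADIUS_KM = 6371.0  # mean Earth radius, an approximation

# One degree of longitude at the equator vs. at 60 degrees north.
at_equator = haversine(0, 0, 0, 1) * EARTH_RADIUS_KM
at_60n = haversine(60, 0, 60, 1) * EARTH_RADIUS_KM
print(f"{at_equator:.1f} km vs {at_60n:.1f} km")
```

One degree of longitude at 60°N covers about half the distance it does at the equator, while a Euclidean metric in degrees would treat both as equal.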

```{code-cell} python
from sklearn.neighbors import BallTree
from xarray.indexes.nd_point_index import TreeAdapter


class SklearnGeoBallTreeAdapter(TreeAdapter):

    def __init__(self, points: np.ndarray, options: dict):
        # the haversine metric expects (lat, lon) pairs in radians
        options.update({'metric': 'haversine'})
        self._balltree = BallTree(np.deg2rad(points), **options)

    def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # convert query points to radians to match the indexed points
        return self._balltree.query(np.deg2rad(points))

    def equals(self, other: "SklearnGeoBallTreeAdapter") -> bool:
        return np.array_equal(self._balltree.data, other._balltree.data)
```

```{note}
Using alternative structures via custom {py:class}`~xarray.indexes.TreeAdapter` subclasses is an
experimental feature!

The adapter above based on {py:class}`sklearn.neighbors.BallTree` will
eventually be available in the [xoak](https://github.com/xarray-contrib/xoak)
package along with other useful adapters.
```
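
The three methods the adapter above defines can be exercised without scikit-learn or Xarray. The class below is a hypothetical brute-force stand-in: it does not subclass `TreeAdapter`, and it assumes that `query` returns distances (in radians) and integer positions, as `BallTree.query` does.

```python
import numpy as np


class BruteForceHaversine:
    """Hypothetical stand-in exposing the same three methods as the adapter."""

    def __init__(self, points: np.ndarray, options: dict):
        # points: (n_points, 2) array of (lat, lon) pairs in degrees
        self._data = np.deg2rad(np.asarray(points, dtype=float))

    def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        q = np.deg2rad(np.asarray(points, dtype=float))
        lat1, lon1 = q[:, 0, None], q[:, 1, None]        # (n_query, 1)
        lat2, lon2 = self._data[:, 0], self._data[:, 1]  # (n_points,)
        a = (
            np.sin((lat2 - lat1) / 2) ** 2
            + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
        )
        # distance from every query point to every indexed point
        dist = 2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
        pos = dist.argmin(axis=1)
        return dist[np.arange(len(pos)), pos], pos

    def equals(self, other: "BruteForceHaversine") -> bool:
        return np.array_equal(self._data, other._data)


points = np.array([[28.0, -93.0], [29.0, -90.0], [30.0, -88.0]])
tree = BruteForceHaversine(points, {})
distances, positions = tree.query(np.array([[28.9, -90.2], [30.1, -88.1]]))
print(positions)  # [1 2]
```

A real tree structure answers the same query without evaluating every pairwise distance, which is what makes it practical for large grids.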

### Assigning

```{code-cell} python
ds_roms_index = ds_roms.set_xindex(
    ("lat_rho", "lon_rho"),
    xr.indexes.NDPointIndex,
    tree_adapter_cls=SklearnGeoBallTreeAdapter,
)
ds_roms_index
```

### Indexing

```{code-cell} python
ds_roms_selection = ds_roms_index.sel(
    lat_rho=ds_trajectory.lat,
    lon_rho=ds_trajectory.lon,
    method="nearest",
)
ds_roms_selection
```

```{code-cell} python
plt.figure()
ds_roms_selection.plot.scatter(x="lon_rho", y="lat_rho", hue="zeta")
plt.show()
```
2 changes: 2 additions & 0 deletions docs/conf.py
@@ -116,6 +116,7 @@
"python": ("https://docs.python.org/3/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
"numpy": ("https://numpy.org/doc/stable", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"xarray": ("https://docs.xarray.dev/en/latest/", None),
"rasterix": ("https://rasterix.readthedocs.io/en/latest/", None),
"shapely": ("https://shapely.readthedocs.io/en/latest/", None),
@@ -124,5 +125,6 @@
"geopandas": ("https://geopandas.org/en/stable/", None),
"pint-xarray": ("https://pint-xarray.readthedocs.io/en/latest/", None),
"pint": ("https://pint.readthedocs.io/en/stable/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"astropy": ("https://docs.astropy.org/en/latest/", None),
}
1 change: 1 addition & 0 deletions requirements.txt
@@ -23,3 +23,4 @@ pint-xarray
cf_xarray
astropy
git+https://github.com/pydata/xarray
scikit-learn