Skip to content

Add NDPointIndex examples #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 181 additions & 1 deletion docs/blocks/ndpoint.md
Original file line number Diff line number Diff line change
@@ -1 +1,181 @@
# Tree-based indexes with `NDPointIndex`
---
jupytext:
text_representation:
format_name: myst
kernelspec:
display_name: Python 3
name: python
---

# Nearest neighbors with `NDPointIndex`

## Highlights

1. {py:class}`xarray.indexes.NDPointIndex` is useful for dealing with
n-dimensional coordinate variables representing irregular data.
1. It enables point-wise (nearest-neighbors) data selection using Xarray's
[advanced indexing](https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing)
capabilities.
1. By default, a {py:class}`scipy.spatial.KDTree` is used under the hood for
fast lookup of point data. Although experimental, it is possible to plug in
alternative structures to `NDPointIndex` (see the advanced example below).

## Basic example

Let's create a dataset with random points.

```{code-cell} python
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
```

```{code-cell} python
---
tags: [remove-cell]
---
%xmode minimal
xr.set_options(
display_expand_indexes=True,
display_expand_data=False,
);
```

```{code-cell} python
shape = (5, 10)
xx = xr.DataArray(np.random.uniform(0, 10, size=shape), dims=("y", "x"))
yy = xr.DataArray(np.random.uniform(0, 5, size=shape), dims=("y", "x"))
data = (xx - 5)**2 + (yy - 2.5)**2

ds = xr.Dataset(data_vars={"data": data}, coords={"xx": xx, "yy": yy})
ds
```

```{code-cell} python
ds.plot.scatter(x="xx", y="yy", hue="data");
```

### Assigning

```{code-cell} python
ds_index = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex)
ds_index
```

### Point-wise indexing

Select one value.

```{code-cell} python
ds_index.sel(xx=3.4, yy=4.2, method="nearest")
```

Select multiple points in the `x`/`y` dimension space, using
{py:class}`xarray.DataArray` objects as input labels.

```{code-cell} python
# create a regular grid as query points
ds_grid = xr.Dataset(coords={"x": range(10), "y": range(5)})

# selection supports broadcasting of the input labels
ds_selection = ds_index.sel(
xx=ds_grid.x, yy=ds_grid.y, method="nearest"
)

# assign selection results to the grid
# -> nearest neighbor interpolation
ds_grid["data"] = ds_selection.data.variable

ds_grid
```

```{code-cell} python
ds_grid.data.plot(x="x", y="y")
ds.plot.scatter(x="xx", y="yy", c="k")
plt.show()
```

## Advanced example

This example is based on the Regional Ocean Modeling System (ROMS) [Xarray
example](https://docs.xarray.dev/en/stable/examples/ROMS_ocean_model.html).

```{note}
An alternative solution to this example is to use
{py:class}`xarray.indexes.CoordinateTransformIndex` (see {doc}`transform`) with the
horizontal coordinate transformations defined in ROMS.
```

```{code-cell} python
ds_roms = xr.tutorial.open_dataset("ROMS_example")
ds_roms
```

The dataset above is represented on a curvilinear grid with 2-dimensional
`lat_rho` and `lon_rho` coordinate variables (in degrees). The default kd-tree
structure used by {py:class}`~xarray.indexes.NDPointIndex` isn't best suited for
these latitude and longitude coordinates. Fortunately, there a way of using
alternative structures. Here let's use {py:class}`sklearn.neighbors.BallTree`
with the `haversine` distance metric.

```{code-cell} python
from sklearn.neighbors import BallTree
from xarray.indexes.nd_point_index import TreeAdapter


class SklearnGeoBallTreeAdapter(TreeAdapter):

def __init__(self, points: np.ndarray, options: dict):
options.update({'metric': 'haversine'})
self._balltree = BallTree(np.deg2rad(points), **options)

def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return self._balltree.query(np.deg2rad(points))

def equals(self, other: "SklearnGeoBallTreeAdapter") -> bool:
return np.array_equal(self._balltree.data, other._kdtree.data)
```

```{note}
Using alternative structures via custom `TreeAdapter` subclasses is an
experimental feature!

The adapter above based on {py:class}`sklearn.neighbors.BallTree` will
eventually be available in the [xoak](https://github.com/xarray-contrib/xoak)
package along with other useful adapters.
```

### Assigning

```{code-cell} python
ds_roms_index = ds_roms.set_xindex(
("lat_rho", "lon_rho"),
xr.indexes.NDPointIndex,
tree_adapter_cls=SklearnGeoBallTreeAdapter,
)
ds_roms_index
```

### Indexing

```{code-cell} python
ds_trajectory = xr.Dataset(
coords={
"lat": ('trajectory', np.linspace(28, 30, 50)),
"lon": ('trajectory', np.linspace(-93, -88, 50)),
},
)

ds_roms_selection = ds_roms_index.sel(
lat_rho=ds_trajectory.lat,
lon_rho=ds_trajectory.lon,
method="nearest",
)
ds_roms_selection
```

```{code-cell} python

ds_roms_selection.plot.scatter(x="lat_rho", y="lat_rho", hue="zeta")
plt.show()
```
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"python": ("https://docs.python.org/3/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
"numpy": ("https://numpy.org/doc/stable", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"xarray": ("https://docs.xarray.dev/en/latest/", None),
"rasterix": ("https://rasterix.readthedocs.io/en/latest/", None),
"shapely": ("https://shapely.readthedocs.io/en/latest/", None),
Expand All @@ -124,4 +125,5 @@
"geopandas": ("https://geopandas.readthedocs.io/en/stable/", None),
"pint-xarray": ("https://pint-xarray.readthedocs.io/en/latest/", None),
"pint": ("https://pint.readthedocs.io/en/stable/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ git+https://github.com/dcherian/rolodex
pint-xarray
cf_xarray
git+https://github.com/pydata/xarray
scikit-learn