From c18e6012dc65866e5230b64fc3ee058e71e1f63f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 9 Jul 2025 15:59:19 +0200 Subject: [PATCH 1/4] add NDPointIndex examples (wip advanced example) --- docs/blocks/ndpoint.md | 158 ++++++++++++++++++++++++++++++++++++++++- docs/conf.py | 2 + requirements.txt | 1 + 3 files changed, 160 insertions(+), 1 deletion(-) diff --git a/docs/blocks/ndpoint.md b/docs/blocks/ndpoint.md index 275a8a4..12840f9 100644 --- a/docs/blocks/ndpoint.md +++ b/docs/blocks/ndpoint.md @@ -1 +1,157 @@ -# Tree-based indexes with `NDPointIndex` +--- +jupytext: + text_representation: + format_name: myst +kernelspec: + display_name: Python 3 + name: python +--- + +# Nearest neighbors with `NDPointIndex` + +## Highlights + +1. {py:class}`xarray.indexes.NDPointIndex` is useful for dealing with + n-dimensional coordinate variables representing irregular data. +1. It enables point-wise (nearest-neighbors) data selection using Xarray's + [advanced indexing](https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing) + capabilities. +1. By default, a {py:class}`scipy.spatial.KDTree` is used under the hood for + fast lookup of point data. Although experimental, it is possible to plug in + alternative structures to `NDPointIndex` (see the advanced example below). + +## Basic example + +Let's create a dataset with random points. 
+ +```{code-cell} python +import numpy as np +import matplotlib.pyplot as plt +import xarray as xr +``` + +```{code-cell} python +--- +tags: [remove-cell] +--- +%xmode minimal +xr.set_options( + display_expand_indexes=True, + display_expand_data=False, +); +``` + +```{code-cell} python +shape = (5, 10) +xx = xr.DataArray(np.random.uniform(0, 10, size=shape), dims=("y", "x")) +yy = xr.DataArray(np.random.uniform(0, 5, size=shape), dims=("y", "x")) +data = (xx - 5)**2 + (yy - 2.5)**2 + +ds = xr.Dataset(data_vars={"data": data}, coords={"xx": xx, "yy": yy}) +ds +``` + +```{code-cell} python +ds.plot.scatter(x="xx", y="yy", hue="data"); +``` + +### Assigning + +```{code-cell} python +ds_index = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex) +ds_index +``` + +### Point-wise indexing + +Select one value. + +```{code-cell} python +ds_index.sel(xx=3.4, yy=4.2, method="nearest") +``` + +Select multiple points in the `x`/`y` dimension space, using +{py:class}`xarray.DataArray` objects as input labels. + +```{code-cell} python +# create a regular grid as query points +ds_grid = xr.Dataset(coords={"x": range(10), "y": range(5)}) + +# selection supports broadcasting of the input labels +ds_selection = ds_index.sel( + xx=ds_grid.x, yy=ds_grid.y, method="nearest" +) + +# assign selection results to the grid +# -> nearest neighbor interpolation +ds_grid["data"] = ds_selection.data.variable + +ds_grid +``` + +```{code-cell} python +ds_grid.data.plot(x="x", y="y") +ds.plot.scatter(x="xx", y="yy", c="k") +plt.show() +``` + +## Advanced example + +This example is based on the Regional Ocean Modeling System (ROMS) [Xarray +example](https://docs.xarray.dev/en/stable/examples/ROMS_ocean_model.html). + +```{note} +An alternative solution to this example is to use +{py:class}`xarray.indexes.CoordinateTransformIndex` (see {doc}`transform`) with the +horizontal coordinate transformations defined in ROMS. 
+``` + +```{code-cell} python +ds_roms = xr.tutorial.open_dataset("ROMS_example") +ds_roms +``` + +The dataset above is represented on a curvilinear grid with 2-dimensional +`lat_rho` and `lon_rho` coordinate variables (in degrees). The default kd-tree +structure used by {py:class}`~xarray.indexes.NDPointIndex` isn't best suited for +these latitude and longitude coordinates. Fortunately, there a way of using +alternative structures. Here let's use {py:class}`sklearn.neighbors.BallTree` +with the `haversine` distance metric. + +```{code-cell} python +from sklearn.neighbors import BallTree +from xarray.indexes.nd_point_index import TreeAdapter + + +class SklearnGeoBallTreeAdapter(TreeAdapter): + + def __init__(self, points: np.ndarray, options: dict): + options = {**options, 'metric': 'haversine'}  # copy: do not mutate the caller's dict + self._balltree = BallTree(np.deg2rad(points), **options)  # haversine expects radians + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._balltree.query(np.deg2rad(points)) + + def equals(self, other: "SklearnGeoBallTreeAdapter") -> bool: + return np.array_equal(self._balltree.data, other._balltree.data) +``` + +```{note} +Using alternative structures via custom `TreeAdapter` subclasses is an +experimental feature! + +The adapter above based on {py:class}`sklearn.neighbors.BallTree` will +eventually be available in the [xoak](https://github.com/xarray-contrib/xoak) +package along with other useful adapters. 
+``` + +### Assigning + +```{code-cell} python +ds_roms_index = ds_roms.set_xindex( + ("lat_rho", "lon_rho"), + xr.indexes.NDPointIndex, + tree_adapter_cls=SklearnGeoBallTreeAdapter, +) +ds_roms_index +``` diff --git a/docs/conf.py b/docs/conf.py index b109771..8c267ae 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -116,6 +116,7 @@ "python": ("https://docs.python.org/3/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), "numpy": ("https://numpy.org/doc/stable", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), "xarray": ("https://docs.xarray.dev/en/latest/", None), "rasterix": ("https://rasterix.readthedocs.io/en/latest/", None), "shapely": ("https://shapely.readthedocs.io/en/latest/", None), @@ -124,4 +125,5 @@ "geopandas": ("https://geopandas.readthedocs.io/en/stable/", None), "pint-xarray": ("https://pint-xarray.readthedocs.io/en/latest/", None), "pint": ("https://pint.readthedocs.io/en/stable/", None), + "sklearn": ("https://scikit-learn.org/stable/", None), } diff --git a/requirements.txt b/requirements.txt index 5688e1a..676c6c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ git+https://github.com/dcherian/rolodex pint-xarray cf_xarray git+https://github.com/pydata/xarray +scikit-learn From 260df50e685e605c7bdca6013b6909381cad10a1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 9 Jul 2025 16:21:44 +0200 Subject: [PATCH 2/4] advanced example: add indexing --- docs/blocks/ndpoint.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/blocks/ndpoint.md b/docs/blocks/ndpoint.md index 12840f9..687c352 100644 --- a/docs/blocks/ndpoint.md +++ b/docs/blocks/ndpoint.md @@ -155,3 +155,27 @@ ds_roms_index = ds_roms.set_xindex( ) ds_roms_index ``` + +### Indexing + +```{code-cell} python +ds_trajectory = xr.Dataset( + coords={ + "lat": ('trajectory', np.linspace(28, 30, 50)), + "lon": ('trajectory', np.linspace(-93, -88, 50)), + }, +) + +ds_roms_selection = ds_roms_index.sel( 
+ lat_rho=ds_trajectory.lat, + lon_rho=ds_trajectory.lon, + method="nearest", +) +ds_roms_selection +``` + +```{code-cell} python + +ds_roms_selection.plot.scatter(x="lat_rho", y="lat_rho", hue="zeta") +plt.show() +``` From 79508c1334d808bea07c6a5509a0c56236da5afa Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 9 Jul 2025 18:43:31 +0200 Subject: [PATCH 3/4] explain with using sklearn.BallTree --- docs/blocks/ndpoint.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/blocks/ndpoint.md b/docs/blocks/ndpoint.md index 687c352..6aac9b9 100644 --- a/docs/blocks/ndpoint.md +++ b/docs/blocks/ndpoint.md @@ -115,8 +115,9 @@ The dataset above is represented on a curvilinear grid with 2-dimensional `lat_rho` and `lon_rho` coordinate variables (in degrees). The default kd-tree structure used by {py:class}`~xarray.indexes.NDPointIndex` isn't best suited for these latitude and longitude coordinates. Fortunately, there a way of using -alternative structures. Here let's use {py:class}`sklearn.neighbors.BallTree` -with the `haversine` distance metric. +alternative structures. Here let's use {py:class}`sklearn.neighbors.BallTree`, +which supports providing distance metrics such as `haversine` that will better +work with latitude and longitude data. ```{code-cell} python from sklearn.neighbors import BallTree From 94bdafe488c7705ff90cb0af58074c2547975d22 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 9 Jul 2025 13:17:38 -0700 Subject: [PATCH 4/4] edits --- docs/blocks/ndpoint.md | 50 ++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/docs/blocks/ndpoint.md b/docs/blocks/ndpoint.md index 6aac9b9..87d8054 100644 --- a/docs/blocks/ndpoint.md +++ b/docs/blocks/ndpoint.md @@ -18,9 +18,9 @@ kernelspec: capabilities. 1. By default, a {py:class}`scipy.spatial.KDTree` is used under the hood for fast lookup of point data. 
Although experimental, it is possible to plug in - alternative structures to `NDPointIndex` (see the advanced example below). + alternative structures to `NDPointIndex` (see {ref}`advanced`). -## Basic example +## Basic Example: Default KDTree Let's create a dataset with random points. @@ -95,27 +95,42 @@ ds.plot.scatter(x="xx", y="yy", c="k") plt.show() ``` +(advanced)= + ## Advanced example This example is based on the Regional Ocean Modeling System (ROMS) [Xarray example](https://docs.xarray.dev/en/stable/examples/ROMS_ocean_model.html). -```{note} -An alternative solution to this example is to use -{py:class}`xarray.indexes.CoordinateTransformIndex` (see {doc}`transform`) with the -horizontal coordinate transformations defined in ROMS. -``` - ```{code-cell} python ds_roms = xr.tutorial.open_dataset("ROMS_example") ds_roms ``` The dataset above is represented on a curvilinear grid with 2-dimensional -`lat_rho` and `lon_rho` coordinate variables (in degrees). The default kd-tree -structure used by {py:class}`~xarray.indexes.NDPointIndex` isn't best suited for -these latitude and longitude coordinates. Fortunately, there a way of using -alternative structures. Here let's use {py:class}`sklearn.neighbors.BallTree`, +`lat_rho` and `lon_rho` coordinate variables (in degrees). We will illustrate sampling a +straight-line trajectory through this field. + +```{code-cell} python +import matplotlib.pyplot as plt + +ds_trajectory = xr.Dataset( + coords={ + "lat": ('trajectory', np.linspace(28, 30, 50)), + "lon": ('trajectory', np.linspace(-93, -88, 50)), + }, +) + +ds_roms.salt.isel(s_rho=-1, ocean_time=0).plot(x="lon_rho", y="lat_rho") +plt.plot( + ds_trajectory.lon.data, ds_trajectory.lat.data, marker='.', color='k', ms=4, ls="none", +) +plt.show() +``` + +The default kd-tree structure used by {py:class}`~xarray.indexes.NDPointIndex` +isn't best suited for these latitude and longitude coordinates. Fortunately, there +is a way of using alternative structures. 
Here let's use {py:class}`sklearn.neighbors.BallTree`, which supports providing distance metrics such as `haversine` that will better work with latitude and longitude data. @@ -138,7 +153,7 @@ class SklearnGeoBallTreeAdapter(TreeAdapter): ``` ```{note} -Using alternative structures via custom `TreeAdapter` subclasses is an +Using alternative structures via custom {py:class}`~xarray.indexes.TreeAdapter` subclasses is an experimental feature! The adapter above based on {py:class}`sklearn.neighbors.BallTree` will @@ -160,13 +175,6 @@ ds_roms_index ### Indexing ```{code-cell} python -ds_trajectory = xr.Dataset( - coords={ - "lat": ('trajectory', np.linspace(28, 30, 50)), - "lon": ('trajectory', np.linspace(-93, -88, 50)), - }, -) - ds_roms_selection = ds_roms_index.sel( lat_rho=ds_trajectory.lat, lon_rho=ds_trajectory.lon, @@ -176,7 +184,7 @@ ds_roms_selection ``` ```{code-cell} python - -ds_roms_selection.plot.scatter(x="lat_rho", y="lat_rho", hue="zeta") +plt.figure() +ds_roms_selection.plot.scatter(x="lon_rho", y="lat_rho", hue="zeta") plt.show() ```