Support wraparound indexing in longitude #47

Draft. Wants to merge 3 commits into base: main.
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -154,7 +154,8 @@
napoleon_type_aliases = {
# general terms
"sequence": ":term:`sequence`",
"iterable": ":term:`iterable`",
"Hashable": ":term:`sequence`",
"iterable": "~collections.abc.Hashable",
"callable": ":py:func:`callable`",
"dict_like": ":term:`dict-like <mapping>`",
"dict-like": ":term:`dict-like <mapping>`",
11 changes: 11 additions & 0 deletions docs/raster_index/design_choices.md
@@ -6,3 +6,14 @@

 In designing {py:class}`RasterIndex`, we faced a few thorny questions. Below we discuss these considerations, and the approach we've taken.
 Ultimately, there are no easy answers, only tradeoffs to be made.
+
+## Handling the `GeoTransform` attribute
+
+GDAL _chooses_ to track the affine transform using a `GeoTransform` attribute on a `spatial_ref` variable. The `"spatial_ref"` variable is a
+"grid mapping" variable (as termed by the CF conventions) and also records CRS information. Currently, our design is that
+{py:class}`xproj.CRSIndex` controls the CRS information and handles the creation of the `"spatial_ref"` variable or, more generally,
+the grid mapping variable. Thus, {py:class}`RasterIndex` _cannot_ keep the `"GeoTransform"` attribute on `"spatial_ref"` up-to-date
+because it does not control that variable.
+
+Therefore, {py:func}`assign_index` deletes the `"GeoTransform"` attribute on the grid mapping variable if it is detected, after using it
+to construct the affine matrix.
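
To make the `GeoTransform` discussion concrete, here is a minimal sketch in plain Python of how a space-separated GDAL `GeoTransform` string maps pixel positions to coordinates. It is not rasterix code: the helper names `parse_geotransform` and `pixel_center` are hypothetical, and the sample string is made up to match the 2-degree global grid used in this PR's tests. The coefficient order follows the GDAL convention (x origin, pixel width, row rotation, y origin, column rotation, pixel height).

```python
def parse_geotransform(value: str) -> tuple[float, ...]:
    """Split a space-separated GeoTransform string into six floats (hypothetical helper)."""
    coeffs = tuple(float(part) for part in value.split())
    if len(coeffs) != 6:
        raise ValueError(f"expected 6 coefficients, got {len(coeffs)}")
    return coeffs


def pixel_center(gt: tuple[float, ...], col: int, row: int) -> tuple[float, float]:
    """Return the (x, y) of the center of pixel (col, row) using GDAL's coefficient order."""
    x0, dx, rx, y0, ry, dy = gt
    # GDAL's origin is the outer corner of the top-left pixel, so shift by half a pixel.
    c, r = col + 0.5, row + 0.5
    return (x0 + c * dx + r * rx, y0 + c * ry + r * dy)


# Assumed example: a global grid with 2-degree pixels, origin at (-1, 89).
gt = parse_geotransform("-1.0 2.0 0.0 89.0 0.0 -2.0")
print(pixel_center(gt, 0, 0))     # (0.0, 88.0)
print(pixel_center(gt, 179, 88))  # (358.0, -88.0)
```

The half-pixel shift in `pixel_center` plays the same role as the `Affine.translation(0.5, 0.5)` step visible in `from_transform` later in this diff.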
3 changes: 2 additions & 1 deletion docs/rasterize/exactextract.ipynb
@@ -28,9 +28,10 @@
    "outputs": [],
    "source": [
     "import xarray as xr\n",
+    "import xproj # noqa\n",
     "\n",
     "ds = xr.tutorial.open_dataset(\"eraint_uvz\")\n",
-    "ds = ds.rio.write_crs(\"epsg:4326\")\n",
+    "ds = ds.proj.assign_crs(spatial_ref=\"epsg:4326\")\n",
     "ds"
    ]
   },
3 changes: 2 additions & 1 deletion docs/rasterize/geometry_mask.ipynb
@@ -28,9 +28,10 @@
    "outputs": [],
    "source": [
     "import xarray as xr\n",
+    "import xproj # noqa\n",
     "\n",
     "ds = xr.tutorial.open_dataset(\"eraint_uvz\")[[\"u\"]]\n",
-    "ds = ds.rio.write_crs(\"epsg:4326\")\n",
+    "ds = ds.proj.assign_crs(spatial_ref=\"epsg:4326\")\n",
     "ds"
    ]
   },
3 changes: 2 additions & 1 deletion docs/rasterize/rasterio.ipynb
@@ -28,9 +28,10 @@
    "outputs": [],
    "source": [
     "import xarray as xr\n",
+    "import xproj # noqa\n",
     "\n",
     "ds = xr.tutorial.open_dataset(\"eraint_uvz\")\n",
-    "ds = ds.rio.write_crs(\"epsg:4326\")\n",
+    "ds = ds.proj.assign_crs(spatial_ref=\"epsg:4326\")\n",
     "ds"
    ]
   },
1 change: 0 additions & 1 deletion pyproject.toml
@@ -54,7 +54,6 @@ docs = [
     "pooch",
     "dask-geopandas",
     "rasterio",
-    "rioxarray",
     "exactextract",
     "sparse",
     "netCDF4",
36 changes: 29 additions & 7 deletions src/rasterix/raster_index.py
@@ -21,6 +21,7 @@

 from rasterix.odc_compat import BoundingBox, bbox_intersection, bbox_union, maybe_int, snap_grid
 from rasterix.rioxarray_compat import guess_dims
+from rasterix.utils import get_affine

 T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")

@@ -35,10 +36,14 @@
 def assign_index(obj: T_Xarray, *, x_dim: str | None = None, y_dim: str | None = None) -> T_Xarray:
     """Assign a RasterIndex to an Xarray DataArray or Dataset.

+    By default, the affine transform is guessed by first looking for a ``GeoTransform`` attribute
+    on a CF "grid mapping" variable (commonly ``"spatial_ref"``). If not present, then the affine is determined from 1D coordinate
+    variables named ``x_dim`` and ``y_dim`` provided to this function.
+
     Parameters
     ----------
     obj : xarray.DataArray or xarray.Dataset
-        The object to assign the index to. Must have a rio accessor with a transform.
+        The object to assign the index to.
     x_dim : str, optional
         Name of the x dimension. If None, will be automatically detected.
     y_dim : str, optional
@@ -49,22 +54,37 @@ def assign_index(obj: T_Xarray, *, x_dim: str | None = None, y_dim: str | None =
     xarray.DataArray or xarray.Dataset
         The input object with RasterIndex coordinates assigned.

+    Notes
+    -----
+    The "grid mapping" variable is determined following the CF conventions:
+
+    - If a DataArray is provided, we look for an attribute named ``"grid_mapping"``.
+    - For a Dataset, we pull the first detected ``"grid_mapping"`` attribute when iterating over data variables.
+
+    The value of this attribute is a variable name containing projection information (commonly ``"spatial_ref"``).
+    We then look for a ``"GeoTransform"`` attribute on this variable (following GDAL convention).
+
+    References
+    ----------
+    - `CF conventions document <http://cfconventions.org/Data/cf-conventions/cf-conventions-1.11/cf-conventions.html#grid-mappings-and-projections>`_.
+    - `GDAL docs on GeoTransform <https://gdal.org/en/stable/tutorials/geotransforms_tut.html>`_.
+
     Examples
     --------
     >>> import xarray as xr
-    >>> import rioxarray # Required for rio accessor
+    >>> import rioxarray # Required for reading TIFF
     >>> da = xr.open_dataset("path/to/raster.tif", engine="rasterio")
     >>> indexed_da = assign_index(da)
     """
-    import rioxarray # noqa
-
     if x_dim is None or y_dim is None:
         guessed_x, guessed_y = guess_dims(obj)
         x_dim = x_dim or guessed_x
         y_dim = y_dim or guessed_y

+    affine = get_affine(obj, x_dim=x_dim, y_dim=y_dim, clear_transform=True)
+
     index = RasterIndex.from_transform(
-        obj.rio.transform(), width=obj.sizes[x_dim], height=obj.sizes[y_dim], x_dim=x_dim, y_dim=y_dim
+        affine, width=obj.sizes[x_dim], height=obj.sizes[y_dim], x_dim=x_dim, y_dim=y_dim, crs=obj.proj.crs
     )
     coords = Coordinates.from_xindex(index)
     return obj.assign_coords(coords)
@@ -283,6 +303,8 @@ def isel( # type: ignore[override]
         # return PandasIndex(values, new_dim, coord_dtype=values.dtype)

     def sel(self, labels, method=None, tolerance=None):
+        # CoordinateTransformIndex only supports "nearest"
+        method = method or "nearest"
         coord_name = self.axis_transform.coord_name
         label = labels[coord_name]

@@ -515,8 +537,8 @@ def from_transform(
         affine = affine * Affine.translation(0.5, 0.5)

         if affine.is_rectilinear and affine.b == affine.d == 0:
-            x_transform = AxisAffineTransform(affine, width, "x", x_dim, is_xaxis=True)
-            y_transform = AxisAffineTransform(affine, height, "y", y_dim, is_xaxis=False)
+            x_transform = AxisAffineTransform(affine, width, x_dim, x_dim, is_xaxis=True)
+            y_transform = AxisAffineTransform(affine, height, y_dim, y_dim, is_xaxis=False)
             index = (
                 AxisAffineTransformIndex(x_transform),
                 AxisAffineTransformIndex(y_transform),
7 changes: 4 additions & 3 deletions src/rasterix/rasterize/rasterio.py
@@ -15,7 +15,8 @@
 from rasterio.features import geometry_mask as geometry_mask_rio
 from rasterio.features import rasterize as rasterize_rio

-from .utils import XAXIS, YAXIS, clip_to_bbox, get_affine, is_in_memory, prepare_for_dask
+from ..utils import get_affine
+from .utils import XAXIS, YAXIS, clip_to_bbox, is_in_memory, prepare_for_dask

 F = TypeVar("F", bound=Callable[..., Any])

@@ -161,7 +162,7 @@ def rasterize(
     obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

     rasterize_kwargs = dict(
-        all_touched=all_touched, merge_alg=merge_alg, affine=get_affine(obj, xdim=xdim, ydim=ydim), env=env
+        all_touched=all_touched, merge_alg=merge_alg, affine=get_affine(obj, x_dim=xdim, y_dim=ydim), env=env
     )
     # FIXME: box.crs == geometries.crs

@@ -325,7 +326,7 @@ def geometry_mask(
     obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

     geometry_mask_kwargs = dict(
-        all_touched=all_touched, affine=get_affine(obj, xdim=xdim, ydim=ydim), env=env
+        all_touched=all_touched, affine=get_affine(obj, x_dim=xdim, y_dim=ydim), env=env
     )

     if is_in_memory(obj=obj, geometries=geometries):
8 changes: 8 additions & 0 deletions src/rasterix/rioxarray_compat.py
@@ -204,6 +204,8 @@ def guess_dims(obj: T_Xarray) -> tuple[str, str]:
         x_dim = "longitude"
         y_dim = "latitude"
     else:
+        x_dim = None
+        y_dim = None
         # look for coordinates with CF attributes
         for coord in obj.coords:
             # make sure to only look in 1D coordinates
@@ -213,12 +215,18 @@ def guess_dims(obj: T_Xarray) -> tuple[str, str]:
             if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or (
                 obj.coords[coord].attrs.get("standard_name", "").lower()
                 in ("longitude", "projection_x_coordinate")
+                or obj.coords[coord].attrs.get("units", "").lower() in ("degrees_east",)
             ):
                 x_dim = coord
             elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or (
                 obj.coords[coord].attrs.get("standard_name", "").lower()
                 in ("latitude", "projection_y_coordinate")
+                or obj.coords[coord].attrs.get("units", "").lower() in ("degrees_north",)
             ):
                 y_dim = coord

+    if not x_dim or not y_dim:
+        raise ValueError(
+            "Could not guess names of x, y coordinate variables. Please explicitly pass `x_dim` and `y_dim`."
+        )
     return x_dim, y_dim
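
The attribute-based fallback in `guess_dims` can be illustrated without xarray. The sketch below is an assumption-laden simplification, not the library's implementation: it takes a plain mapping from coordinate name to attribute dict, covers only the CF-attribute branch shown in the hunk above (not the earlier name-based checks), and the names `guess_axis` and `guess_dims_from_attrs` are hypothetical.

```python
from __future__ import annotations


def guess_axis(attrs: dict) -> str | None:
    """Classify one coordinate as "x", "y", or None from its CF-style attributes."""
    axis = attrs.get("axis", "").upper()
    standard_name = attrs.get("standard_name", "").lower()
    units = attrs.get("units", "").lower()
    if axis == "X" or standard_name in ("longitude", "projection_x_coordinate") or units == "degrees_east":
        return "x"
    if axis == "Y" or standard_name in ("latitude", "projection_y_coordinate") or units == "degrees_north":
        return "y"
    return None


def guess_dims_from_attrs(coords: dict[str, dict]) -> tuple[str, str]:
    """Return (x_dim, y_dim), raising if either cannot be identified."""
    x_dim = y_dim = None
    for name, attrs in coords.items():
        axis = guess_axis(attrs)
        if axis == "x":
            x_dim = name
        elif axis == "y":
            y_dim = name
    if not x_dim or not y_dim:
        raise ValueError("Could not guess names of x, y coordinate variables.")
    return x_dim, y_dim


coords = {"time": {}, "lat": {"units": "degrees_north"}, "lon": {"axis": "X"}}
print(guess_dims_from_attrs(coords))  # ('lon', 'lat')
```

Coordinates named `lat` and `lon` with no attributes at all would raise here, which is the situation exercised by `test_assign_index_cant_guess_error` in this PR.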
67 changes: 67 additions & 0 deletions src/rasterix/utils.py
@@ -0,0 +1,67 @@
+import xarray as xr
+from affine import Affine
+
+
+def get_grid_mapping_var(obj: xr.Dataset | xr.DataArray) -> xr.DataArray | None:
+    grid_mapping_var = None
+    if isinstance(obj, xr.DataArray):
+        if maybe := obj.attrs.get("grid_mapping", None):
+            if maybe in obj.coords:
+                grid_mapping_var = maybe
+    else:
+        # for datasets, grab the first one for simplicity
+        for var in obj.data_vars.values():
+            if maybe := var.attrs.get("grid_mapping"):
+                if maybe in obj.coords:
+                    # make sure it exists and is not an out-of-date attribute
+                    grid_mapping_var = maybe
+                    break
+    if grid_mapping_var is None and "spatial_ref" in obj.coords:
+        # hardcode this
+        grid_mapping_var = "spatial_ref"
+    if grid_mapping_var is not None:
+        return obj[grid_mapping_var]
+    return None
+
+
+def get_affine(
+    obj: xr.Dataset | xr.DataArray, *, x_dim="x", y_dim="y", clear_transform: bool = False
+) -> Affine:
+    """
+    Grabs an affine transform from an Xarray object.
+
+    This function first looks for the ``"GeoTransform"`` attribute on the grid mapping variable
+    (commonly ``"spatial_ref"``). If it is not present, the transform is guessed from the 1D
+    coordinate variables named ``x_dim`` and ``y_dim``.
+
+    Parameters
+    ----------
+    obj: xr.DataArray or xr.Dataset
+    x_dim: str, optional
+        Name of the X dimension coordinate variable.
+    y_dim: str, optional
+        Name of the Y dimension coordinate variable.
+    clear_transform: bool
+        Whether to delete the ``GeoTransform`` attribute if detected.
+
+    Returns
+    -------
+    affine: Affine
+    """
+    grid_mapping_var = get_grid_mapping_var(obj)
+    if grid_mapping_var is not None and (transform := grid_mapping_var.attrs.get("GeoTransform")):
+        if clear_transform:
+            del grid_mapping_var.attrs["GeoTransform"]
+        return Affine.from_gdal(*map(float, transform.split(" ")))
+    else:
+        x = obj.coords[x_dim]
+        y = obj.coords[y_dim]
+        if x.ndim != 1:
+            raise ValueError(f"Coordinate variable {x_dim=!r} must be 1D.")
+        if y.ndim != 1:
+            raise ValueError(f"Coordinate variable {y_dim=!r} must be 1D.")
+        dx = (x[1] - x[0]).item()
+        dy = (y[1] - y[0]).item()
+        return Affine.translation(
+            x[0].item() - dx / 2, (y[0] if dy < 0 else y[-1]).item() - dy / 2
+        ) * Affine.scale(dx, dy)
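
As a worked illustration of the coordinate-based fallback in `get_affine`, the sketch below derives GDAL-ordered coefficients from evenly spaced pixel-center coordinates using plain lists. It is a simplified stand-in rather than the function above: the name `geotransform_from_centers` is hypothetical, it returns a 6-tuple instead of an `Affine`, and it deliberately handles only descending `y` (north-up rasters), which is the layout used in this PR's tests.

```python
def geotransform_from_centers(x: list[float], y: list[float]) -> tuple[float, ...]:
    """Derive (x0, dx, 0, y0, 0, dy) from evenly spaced pixel-center coordinates."""
    if len(x) < 2 or len(y) < 2:
        raise ValueError("need at least two coordinate values along each axis")
    dx = x[1] - x[0]
    dy = y[1] - y[0]
    if dy >= 0:
        # Assumption of this sketch: only north-up rasters (descending y) are covered.
        raise ValueError("this sketch only covers descending y coordinates")
    # Centers sit half a pixel inside the outer corner, so step back by half a pixel.
    return (x[0] - dx / 2, dx, 0.0, y[0] - dy / 2, 0.0, dy)


lon = [2.0 * i for i in range(180)]        # 0, 2, ..., 358
lat = [88.0 - 2.0 * j for j in range(89)]  # 88, 86, ..., -88
print(geotransform_from_centers(lon, lat))  # (-1.0, 2.0, 0.0, 89.0, 0.0, -2.0)
```

The result places the raster origin at the outer corner of the first pixel, half a pixel west and north of the first center.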
Binary file modified tests/geometry_mask_snapshot.nc
Binary file not shown.
Binary file modified tests/rasterize_snapshot.nc
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/test_exact.py
@@ -9,14 +9,15 @@
 import sparse
 import xarray as xr
 import xarray.testing as xrt
+import xproj # noqa
 from exactextract import exact_extract
 from hypothesis import example, given, settings
 from xarray.tests import raise_if_dask_computes

 from rasterix.rasterize.exact import CoverageWeights, coverage, xy_to_raster_source

 dataset = xr.tutorial.open_dataset("eraint_uvz").rename({"latitude": "y", "longitude": "x"})
-dataset = dataset.rio.write_crs("epsg:4326")
+dataset = dataset.proj.assign_crs(spatial_ref="epsg:4326")
 world = gpd.read_file(geodatasets.get_path("naturalearth land"))
 XSIZE = dataset.x.size
 YSIZE = dataset.y.size
64 changes: 58 additions & 6 deletions tests/test_raster_index.py
@@ -6,9 +6,10 @@
 import rioxarray # noqa
 import xarray as xr
 from affine import Affine
-from xarray.testing import assert_identical
+from xarray.testing import assert_equal, assert_identical

 from rasterix import RasterIndex, assign_index
+from rasterix.utils import get_grid_mapping_var

 CRS_ATTRS = pyproj.CRS.from_epsg(4326).to_cf()

@@ -20,6 +21,32 @@ def dataset_from_transform(transform: str) -> xr.Dataset:
     ).pipe(assign_index)


+def test_grid_mapping_var():
+    obj = xr.DataArray()
+    assert get_grid_mapping_var(obj) is None
+
+    obj = xr.Dataset()
+    assert get_grid_mapping_var(obj) is None
+
+    obj = xr.DataArray(attrs={"grid_mapping": "spatial_ref"})
+    assert get_grid_mapping_var(obj) is None
+
+    obj = xr.DataArray(attrs={"grid_mapping": "spatial_ref"}, coords={"spatial_ref": 0})
+    assert_identical(get_grid_mapping_var(obj), obj["spatial_ref"])
+
+    obj = xr.Dataset({"foo": ((), 0, {"grid_mapping": "spatial_ref"})})
+    assert get_grid_mapping_var(obj) is None
+
+    obj = xr.Dataset(
+        {
+            "foo": ((), 0, {"grid_mapping": "spatial_ref_0"}),
+            "zoo": ((), 0, {"grid_mapping": "spatial_ref_1"}),
+        },
+        coords={"spatial_ref_1": 0},
+    )
+    assert_identical(get_grid_mapping_var(obj), obj["spatial_ref_1"])
+
+
 def test_set_xindex() -> None:
     coords = xr.Coordinates(coords={"x": np.arange(0.5, 12.5), "y": np.arange(0.5, 10.5)}, indexes={})
     ds = xr.Dataset(coords=coords)
@@ -56,18 +83,17 @@ def test_raster_index_properties():
 def test_sel_slice():
     ds = xr.Dataset({"foo": (("y", "x"), np.ones((10, 12)))})
     transform = Affine.identity()
-    ds = ds.rio.write_transform(transform)
+    ds.coords["spatial_ref"] = ((), 0, {"GeoTransform": " ".join(map(str, transform.to_gdal()))})
     ds = assign_index(ds)

+    assert "GeoTransform" not in ds.spatial_ref.attrs
     assert ds.xindexes["x"].transform() == transform

     actual = ds.sel(x=slice(4), y=slice(3, 5))
     assert isinstance(actual.xindexes["x"], RasterIndex)
     assert isinstance(actual.xindexes["y"], RasterIndex)
-    actual_transform = actual.xindexes["x"].transform()
-
-    assert actual_transform == actual.rio.transform()
-    assert actual_transform == (transform * Affine.translation(0, 3))
+    actual_transform = actual.xindexes["x"].transform()
+    assert actual_transform == transform * Affine.translation(0, 3)


 def test_crs() -> None:
@@ -363,3 +389,29 @@ def test_repr() -> None:

     index3 = RasterIndex.from_transform(Affine.identity(), width=12, height=10, crs="epsg:31370")
     assert repr(index3).startswith("RasterIndex(crs=EPSG:31370)")
+
+
+def test_assign_index_cant_guess_error():
+    ds = xr.Dataset(
+        {"sst": (("time", "lat", "lon"), np.ones((1, 89, 180)))},
+        coords={"lat": np.arange(88, -89, -2), "lon": np.arange(0, 360, 2)},
+    )
+    with pytest.raises(ValueError, match="guess"):
+        assign_index(ds)
+
+
+def test_wraparound_indexing_longitude():
+    ds = xr.Dataset(
+        {"sst": (("time", "lat", "lon"), np.random.random((1, 89, 180)))},
+        coords={
+            "lat": ("lat", np.arange(88, -89, -2), {"axis": "Y"}),
+            "lon": ("lon", np.arange(0, 360, 2), {"axis": "X"}),
+        },
+    )
+    indexed = ds.pipe(assign_index)
+    # We lose existing attrs when calling ``assign_index``.
+    assert_equal(ds.sel(lon=[220, 240]), indexed.sel(lon=[-140, -120]))
+    assert_equal(ds.sel(lon=220), indexed.sel(lon=-140))
Member commented:

For example

assert_equal(ds.sel(lon=[0, 0]), indexed.sel(lon=[0, 359.9]))

should pass with some proper wraparound but raises

IndexError: index 180 is out of bounds for axis 2 with size 180

for indexed.sel(lon=359.9)

@dcherian (Collaborator, Author) commented on Jul 1, 2025:


hmmm.. using method="nearest" by default is quite wrong, but CoordinateTransformIndex requires it

+    assert_equal(ds.sel(lon=slice(220, 240)), indexed.sel(lon=slice(-140, -120)))
+    assert_equal(ds.sel(lon=slice(240, 220)), indexed.sel(lon=slice(-120, -140)))
+    # assert_equal(ds.sel(lon=[220]), indexed.sel(lon=[-140])) # FIXME
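
The wraparound behavior discussed in the review thread can be illustrated in isolation. The following is a minimal sketch in plain Python, not the PR's implementation: `wrapped_nearest_index` is a hypothetical helper, and it assumes an evenly spaced, periodic axis whose `size * step` spans exactly 360 degrees (as with the `lon` coordinate in the test above, which runs from 0 to 358 in steps of 2).

```python
import math


def wrapped_nearest_index(label: float, start: float, step: float, size: int) -> int:
    """Nearest integer position of `label` on a periodic, evenly spaced axis."""
    position = (label - start) / step
    # Round half up explicitly; the built-in round() rounds halves to even.
    nearest = math.floor(position + 0.5)
    # Modulo wraps positions below 0 or at/above `size` back onto the axis.
    return nearest % size


# Longitude axis from the test: 0, 2, ..., 358 (180 points).
print(wrapped_nearest_index(220, start=0, step=2, size=180))    # 110
print(wrapped_nearest_index(-140, start=0, step=2, size=180))   # 110, same cell as 220
print(wrapped_nearest_index(359.9, start=0, step=2, size=180))  # 0, wraps instead of index 180
```

Under these assumptions, the label `359.9` resolves to position 0 rather than the out-of-bounds position 180 reported in the review comment.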