Skip to content

Commit 234486b

Browse files
committed
Support wraparound indexing
1 parent fb36168 commit 234486b

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

src/rasterix/raster_index.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,8 @@ def isel( # type: ignore[override]
303303
# return PandasIndex(values, new_dim, coord_dtype=values.dtype)
304304

305305
def sel(self, labels, method=None, tolerance=None):
306+
# CoordinateTransformIndex only supports "nearest"
307+
method = method or "nearest"
306308
coord_name = self.axis_transform.coord_name
307309
label = labels[coord_name]
308310

@@ -535,8 +537,8 @@ def from_transform(
535537
affine = affine * Affine.translation(0.5, 0.5)
536538

537539
if affine.is_rectilinear and affine.b == affine.d == 0:
538-
x_transform = AxisAffineTransform(affine, width, "x", x_dim, is_xaxis=True)
539-
y_transform = AxisAffineTransform(affine, height, "y", y_dim, is_xaxis=False)
540+
x_transform = AxisAffineTransform(affine, width, x_dim, x_dim, is_xaxis=True)
541+
y_transform = AxisAffineTransform(affine, height, y_dim, y_dim, is_xaxis=False)
540542
index = (
541543
AxisAffineTransformIndex(x_transform),
542544
AxisAffineTransformIndex(y_transform),

src/rasterix/rioxarray_compat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ def guess_dims(obj: T_Xarray) -> tuple[str, str]:
204204
x_dim = "longitude"
205205
y_dim = "latitude"
206206
else:
207+
x_dim = None
208+
y_dim = None
207209
# look for coordinates with CF attributes
208210
for coord in obj.coords:
209211
# make sure to only look in 1D coordinates
@@ -213,12 +215,18 @@ def guess_dims(obj: T_Xarray) -> tuple[str, str]:
213215
if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or (
214216
obj.coords[coord].attrs.get("standard_name", "").lower()
215217
in ("longitude", "projection_x_coordinate")
218+
or obj.coords[coord].attrs.get("units", "").lower() in ("degrees_east",)
216219
):
217220
x_dim = coord
218221
elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or (
219222
obj.coords[coord].attrs.get("standard_name", "").lower()
220223
in ("latitude", "projection_y_coordinate")
224+
or obj.coords[coord].attrs.get("units", "").lower() in ("degrees_north",)
221225
):
222226
y_dim = coord
223227

228+
if not x_dim or not y_dim:
229+
raise ValueError(
230+
"Could not guess names of x, y coordinate variables. Please explicitly pass `x_dim` and `y_dim`."
231+
)
224232
return x_dim, y_dim

tests/test_raster_index.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import rioxarray # noqa
77
import xarray as xr
88
from affine import Affine
9-
from xarray.testing import assert_identical
9+
from xarray.testing import assert_equal, assert_identical
1010

1111
from rasterix import RasterIndex, assign_index
1212
from rasterix.utils import get_grid_mapping_var
@@ -389,3 +389,29 @@ def test_repr() -> None:
389389

390390
index3 = RasterIndex.from_transform(Affine.identity(), width=12, height=10, crs="epsg:31370")
391391
assert repr(index3).startswith("RasterIndex(crs=EPSG:31370)")
392+
393+
394+
def test_assign_index_cant_guess_error():
395+
ds = xr.Dataset(
396+
{"sst": (("time", "lat", "lon"), np.ones((1, 89, 180)))},
397+
coords={"lat": np.arange(88, -89, -2), "lon": np.arange(0, 360, 2)},
398+
)
399+
with pytest.raises(ValueError, match="guess"):
400+
assign_index(ds)
401+
402+
403+
def test_wraparound_indexing_longitude():
404+
ds = xr.Dataset(
405+
{"sst": (("time", "lat", "lon"), np.random.random((1, 89, 180)))},
406+
coords={
407+
"lat": ("lat", np.arange(88, -89, -2), {"axis": "Y"}),
408+
"lon": ("lon", np.arange(0, 360, 2), {"axis": "X"}),
409+
},
410+
)
411+
indexed = ds.pipe(assign_index)
412+
# We lose existing attrs when calling ``assign_index``.
413+
assert_equal(ds.sel(lon=[220, 240]), indexed.sel(lon=[-140, -120]))
414+
assert_equal(ds.sel(lon=220), indexed.sel(lon=-140))
415+
assert_equal(ds.sel(lon=slice(220, 240)), indexed.sel(lon=slice(-140, -120)))
416+
assert_equal(ds.sel(lon=slice(240, 220)), indexed.sel(lon=slice(-120, -140)))
417+
# assert_equal(ds.sel(lon=[220]), indexed.sel(lon=[-140])) # FIXME

0 commit comments

Comments
 (0)