Skip to content

Commit 0d6864c

Browse files
authored
Fix indexing bug; add tests (#21)
1 parent 2516e3a commit 0d6864c

File tree

2 files changed

+50
-13
lines changed

2 files changed

+50
-13
lines changed

src/rasterix/raster_index.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,13 @@ def sel(self, labels, method=None, tolerance=None):
209209
label = labels[coord_name]
210210

211211
if isinstance(label, slice):
212+
if label.start is None:
213+
label = slice(0, label.stop, label.step)
212214
if label.step is None:
213215
# continuous interval slice indexing (preserves the index)
214216
pos = self.transform.reverse({coord_name: np.array([label.start, label.stop])})
215-
pos = np.round(pos[self.dim]).astype("int")
217+
# np.round rounds to even, this way we round upwards
218+
pos = np.floor(pos[self.dim] + 0.5).astype("int")
216219
new_start = max(pos[0], 0)
217220
new_stop = min(pos[1], self.axis_transform.size)
218221
return IndexSelResult({self.dim: slice(new_start, new_stop)})
@@ -368,7 +371,7 @@ def sel(self, labels: dict[Any, Any], method=None, tolerance=None) -> IndexSelRe
368371
for coord_names, index in self._wrapped_indexes.items():
369372
if not isinstance(coord_names, tuple):
370373
coord_names = (coord_names,)
371-
index_labels = {k: v for k, v in labels if k in coord_names}
374+
index_labels = {k: v for k, v in labels.items() if k in coord_names}
372375
if index_labels:
373376
results.append(index.sel(index_labels, method=method, tolerance=tolerance))
374377

@@ -403,3 +406,14 @@ def __repr__(self):
403406
items += [repr(coord_names) + ":", textwrap.indent(repr(index), " ")]
404407

405408
return "RasterIndex\n" + "\n".join(items)
409+
410+
def transform(self) -> Affine:
411+
"""Returns Affine transform for top-left corners."""
412+
if len(self._wrapped_indexes) > 1:
413+
x = self._wrapped_indexes["x"].axis_transform.affine
414+
y = self._wrapped_indexes["y"].axis_transform.affine
415+
aff = Affine(x.a, x.b, x.c, y.d, y.e, y.f)
416+
else:
417+
index = next(iter(self._wrapped_indexes.values()))
418+
aff = index.affine
419+
return aff * Affine.translation(-0.5, -0.5)

tests/test_raster_index.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,44 @@
1+
import numpy as np
2+
import rioxarray # noqa
13
import xarray as xr
4+
from affine import Affine
25

36
from rasterix import RasterIndex
47

58

6-
def test_rectilinear():
7-
source = "/vsicurl/https://noaadata.apps.nsidc.org/NOAA/G02135/south/daily/geotiff/2024/01_Jan/S_20240101_concentration_v3.0.tif"
8-
da_no_raster_index = xr.open_dataarray(source, engine="rasterio")
9-
x_dim = da_no_raster_index.rio.x_dim
10-
y_dim = da_no_raster_index.rio.y_dim
9+
def set_raster_index(obj):
10+
x_dim = obj.rio.x_dim
11+
y_dim = obj.rio.y_dim
1112

1213
index = RasterIndex.from_transform(
13-
da_no_raster_index.rio.transform(),
14-
da_no_raster_index.sizes[x_dim],
15-
da_no_raster_index.sizes[y_dim],
16-
x_dim=x_dim,
17-
y_dim=y_dim,
14+
obj.rio.transform(), obj.sizes[x_dim], obj.sizes[y_dim], x_dim=x_dim, y_dim=y_dim
1815
)
1916
coords = xr.Coordinates.from_xindex(index)
20-
da_raster_index = da_no_raster_index.assign_coords(coords)
17+
return obj.assign_coords(coords)
18+
19+
20+
def test_rectilinear():
21+
source = "/vsicurl/https://noaadata.apps.nsidc.org/NOAA/G02135/south/daily/geotiff/2024/01_Jan/S_20240101_concentration_v3.0.tif"
22+
da_no_raster_index = xr.open_dataarray(source, engine="rasterio")
23+
da_raster_index = set_raster_index(da_no_raster_index)
2124
assert da_raster_index.equals(da_no_raster_index)
25+
26+
27+
# TODO: parameterize over
28+
# 1. y points up;
29+
# 2. y points down
30+
def test_sel_slice():
31+
ds = xr.Dataset({"foo": (("y", "x"), np.ones((10, 12)))})
32+
transform = Affine.identity()
33+
ds = ds.rio.write_transform(transform)
34+
ds = set_raster_index(ds)
35+
36+
assert ds.xindexes["x"].transform() == transform
37+
38+
actual = ds.sel(x=slice(4), y=slice(3, 5))
39+
assert isinstance(actual.xindexes["x"], RasterIndex)
40+
assert isinstance(actual.xindexes["y"], RasterIndex)
41+
actual_transform = actual.xindexes["x"].transform()
42+
43+
assert actual_transform == actual.rio.transform()
44+
assert actual_transform == (transform * Affine.translation(0, 3))

0 commit comments

Comments
 (0)