diff --git a/docs/conf.py b/docs/conf.py index 366d125..9ecbaae 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -154,7 +154,8 @@ napoleon_type_aliases = { # general terms "sequence": ":term:`sequence`", "iterable": ":term:`iterable`", + "Hashable": "~collections.abc.Hashable", "callable": ":py:func:`callable`", "dict_like": ":term:`dict-like <mapping>`", "dict-like": ":term:`dict-like <mapping>`", diff --git a/docs/raster_index/design_choices.md b/docs/raster_index/design_choices.md index 8dbb63a..b09c9ba 100644 --- a/docs/raster_index/design_choices.md +++ b/docs/raster_index/design_choices.md @@ -6,3 +6,14 @@ In designing {py:class}`RasterIndex`, we faced a few thorny questions. Below we discuss these considerations, and the approach we've taken. Ultimately, there are no easy answers and tradeoffs to be made. + +## Handling the `GeoTransform` attribute + +GDAL _chooses_ to track the affine transform using a `GeoTransform` attribute on a `spatial_ref` variable. The `"spatial_ref"` variable is a +"grid mapping" variable (as termed by the CF-conventions). It also records CRS information. Currently, our design is that +{py:class}`xproj.CRSIndex` controls the CRS information and handles the creation of the `"spatial_ref"` variable, or more generally, +the grid mapping variable. Thus, {py:class}`RasterIndex` _cannot_ keep the `"GeoTransform"` attribute on `"spatial_ref"` up-to-date +because it does not control that variable. + +Consequently, {py:func}`assign_index` will delete the `"GeoTransform"` attribute on the grid mapping variable if it is detected, after using it +to construct the affine matrix.
diff --git a/docs/rasterize/exactextract.ipynb b/docs/rasterize/exactextract.ipynb index 5b89fb5..468d895 100644 --- a/docs/rasterize/exactextract.ipynb +++ b/docs/rasterize/exactextract.ipynb @@ -28,9 +28,10 @@ "outputs": [], "source": [ "import xarray as xr\n", + "import xproj # noqa\n", "\n", "ds = xr.tutorial.open_dataset(\"eraint_uvz\")\n", - "ds = ds.rio.write_crs(\"epsg:4326\")\n", + "ds = ds.proj.assign_crs(spatial_ref=\"epsg:4326\")\n", "ds" ] }, diff --git a/docs/rasterize/geometry_mask.ipynb b/docs/rasterize/geometry_mask.ipynb index 685f383..718634d 100644 --- a/docs/rasterize/geometry_mask.ipynb +++ b/docs/rasterize/geometry_mask.ipynb @@ -28,9 +28,10 @@ "outputs": [], "source": [ "import xarray as xr\n", + "import xproj # noqa\n", "\n", "ds = xr.tutorial.open_dataset(\"eraint_uvz\")[[\"u\"]]\n", - "ds = ds.rio.write_crs(\"epsg:4326\")\n", + "ds = ds.proj.assign_crs(spatial_ref=\"epsg:4326\")\n", "ds" ] }, diff --git a/docs/rasterize/rasterio.ipynb b/docs/rasterize/rasterio.ipynb index b51928c..c6ac31f 100644 --- a/docs/rasterize/rasterio.ipynb +++ b/docs/rasterize/rasterio.ipynb @@ -28,9 +28,10 @@ "outputs": [], "source": [ "import xarray as xr\n", + "import xproj # noqa\n", "\n", "ds = xr.tutorial.open_dataset(\"eraint_uvz\")\n", - "ds = ds.rio.write_crs(\"epsg:4326\")\n", + "ds = ds.proj.assign_crs(spatial_ref=\"epsg:4326\")\n", "ds" ] }, diff --git a/pyproject.toml b/pyproject.toml index ea65179..a0d7cae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ docs = [ "pooch", "dask-geopandas", "rasterio", - "rioxarray", "exactextract", "sparse", "netCDF4", diff --git a/src/rasterix/raster_index.py b/src/rasterix/raster_index.py index 6731567..35cb771 100644 --- a/src/rasterix/raster_index.py +++ b/src/rasterix/raster_index.py @@ -21,6 +21,7 @@ from rasterix.odc_compat import BoundingBox, bbox_intersection, bbox_union, maybe_int, snap_grid from rasterix.rioxarray_compat import guess_dims +from rasterix.utils import get_affine T_Xarray = 
TypeVar("T_Xarray", "DataArray", "Dataset") @@ -35,10 +36,14 @@ def assign_index(obj: T_Xarray, *, x_dim: str | None = None, y_dim: str | None = None) -> T_Xarray: """Assign a RasterIndex to an Xarray DataArray or Dataset. + By default, the affine transform is guessed by first looking for a ``GeoTransform`` attribute + on a CF "grid mapping" variable (commonly ``"spatial_ref"``). If not present, then the affine is determined from 1D coordinate + variables named ``x_dim`` and ``y_dim`` provided to this function. + Parameters ---------- obj : xarray.DataArray or xarray.Dataset - The object to assign the index to. Must have a rio accessor with a transform. + The object to assign the index to. x_dim : str, optional Name of the x dimension. If None, will be automatically detected. y_dim : str, optional @@ -49,22 +54,37 @@ def assign_index(obj: T_Xarray, *, x_dim: str | None = None, y_dim: str | None = xarray.DataArray or xarray.Dataset The input object with RasterIndex coordinates assigned. + Notes + ----- + The "grid mapping" variable is determined following the CF conventions: + + - If a DataArray is provided, we look for an attribute named ``"grid_mapping"``. + - For a Dataset, we pull the first detected ``"grid_mapping"`` attribute when iterating over data variables. + + The value of this attribute is a variable name containing projection information (commonly ``"spatial_ref"``). + We then look for a ``"GeoTransform"`` attribute on this variable (following GDAL convention). + + References + ---------- + - `CF conventions document <https://cfconventions.org/cf-conventions/cf-conventions.html#grid-mappings-and-projections>`_. + - `GDAL docs on GeoTransform <https://gdal.org/en/stable/tutorials/geotransforms_tut.html>`_.
+ Examples -------- >>> import xarray as xr - >>> import rioxarray # Required for rio accessor + >>> import rioxarray # Required for reading TIFF >>> da = xr.open_dataset("path/to/raster.tif", engine="rasterio") >>> indexed_da = assign_index(da) """ - import rioxarray # noqa - if x_dim is None or y_dim is None: guessed_x, guessed_y = guess_dims(obj) x_dim = x_dim or guessed_x y_dim = y_dim or guessed_y + affine = get_affine(obj, x_dim=x_dim, y_dim=y_dim, clear_transform=True) + index = RasterIndex.from_transform( - obj.rio.transform(), width=obj.sizes[x_dim], height=obj.sizes[y_dim], x_dim=x_dim, y_dim=y_dim + affine, width=obj.sizes[x_dim], height=obj.sizes[y_dim], x_dim=x_dim, y_dim=y_dim, crs=obj.proj.crs ) coords = Coordinates.from_xindex(index) return obj.assign_coords(coords) @@ -283,6 +303,8 @@ def isel( # type: ignore[override] # return PandasIndex(values, new_dim, coord_dtype=values.dtype) def sel(self, labels, method=None, tolerance=None): + # CoordinateTransformIndex only supports "nearest" + method = method or "nearest" coord_name = self.axis_transform.coord_name label = labels[coord_name] @@ -515,8 +537,8 @@ def from_transform( affine = affine * Affine.translation(0.5, 0.5) if affine.is_rectilinear and affine.b == affine.d == 0: - x_transform = AxisAffineTransform(affine, width, "x", x_dim, is_xaxis=True) - y_transform = AxisAffineTransform(affine, height, "y", y_dim, is_xaxis=False) + x_transform = AxisAffineTransform(affine, width, x_dim, x_dim, is_xaxis=True) + y_transform = AxisAffineTransform(affine, height, y_dim, y_dim, is_xaxis=False) index = ( AxisAffineTransformIndex(x_transform), AxisAffineTransformIndex(y_transform), diff --git a/src/rasterix/rasterize/rasterio.py b/src/rasterix/rasterize/rasterio.py index b69f002..9c2459a 100644 --- a/src/rasterix/rasterize/rasterio.py +++ b/src/rasterix/rasterize/rasterio.py @@ -15,7 +15,8 @@ from rasterio.features import geometry_mask as geometry_mask_rio from rasterio.features import rasterize as 
rasterize_rio -from .utils import XAXIS, YAXIS, clip_to_bbox, get_affine, is_in_memory, prepare_for_dask +from ..utils import get_affine +from .utils import XAXIS, YAXIS, clip_to_bbox, is_in_memory, prepare_for_dask F = TypeVar("F", bound=Callable[..., Any]) @@ -161,7 +162,7 @@ def rasterize( obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim) rasterize_kwargs = dict( - all_touched=all_touched, merge_alg=merge_alg, affine=get_affine(obj, xdim=xdim, ydim=ydim), env=env + all_touched=all_touched, merge_alg=merge_alg, affine=get_affine(obj, x_dim=xdim, y_dim=ydim), env=env ) # FIXME: box.crs == geometries.crs @@ -325,7 +326,7 @@ def geometry_mask( obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim) geometry_mask_kwargs = dict( - all_touched=all_touched, affine=get_affine(obj, xdim=xdim, ydim=ydim), env=env + all_touched=all_touched, affine=get_affine(obj, x_dim=xdim, y_dim=ydim), env=env ) if is_in_memory(obj=obj, geometries=geometries): diff --git a/src/rasterix/rioxarray_compat.py b/src/rasterix/rioxarray_compat.py index 933fa38..ac56884 100644 --- a/src/rasterix/rioxarray_compat.py +++ b/src/rasterix/rioxarray_compat.py @@ -204,6 +204,8 @@ def guess_dims(obj: T_Xarray) -> tuple[str, str]: x_dim = "longitude" y_dim = "latitude" else: + x_dim = None + y_dim = None # look for coordinates with CF attributes for coord in obj.coords: # make sure to only look in 1D coordinates @@ -213,12 +215,18 @@ def guess_dims(obj: T_Xarray) -> tuple[str, str]: if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or ( obj.coords[coord].attrs.get("standard_name", "").lower() in ("longitude", "projection_x_coordinate") + or obj.coords[coord].attrs.get("units", "").lower() in ("degrees_east",) ): x_dim = coord elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or ( obj.coords[coord].attrs.get("standard_name", "").lower() in ("latitude", "projection_y_coordinate") + or obj.coords[coord].attrs.get("units", "").lower() in ("degrees_north",) ): y_dim = coord + if 
not x_dim or not y_dim: + raise ValueError( + "Could not guess names of x, y coordinate variables. Please explicitly pass `x_dim` and `y_dim`." + ) return x_dim, y_dim diff --git a/src/rasterix/utils.py b/src/rasterix/utils.py new file mode 100644 index 0000000..c0ec361 --- /dev/null +++ b/src/rasterix/utils.py @@ -0,0 +1,90 @@ +import xarray as xr +from affine import Affine + + +def get_grid_mapping_var(obj: xr.Dataset | xr.DataArray) -> xr.DataArray | None: + """ + Find the CF "grid mapping" variable of an Xarray object. + + For a DataArray, the ``"grid_mapping"`` attribute is consulted. For a Dataset, the first + data variable whose ``"grid_mapping"`` attribute names an existing coordinate is used. + If neither yields a coordinate, a coordinate named ``"spatial_ref"`` is used when present. + + Returns + ------- + xr.DataArray or None + The grid mapping coordinate variable, or None if one cannot be found. + """ + grid_mapping_var = None + if isinstance(obj, xr.DataArray): + if maybe := obj.attrs.get("grid_mapping", None): + if maybe in obj.coords: + grid_mapping_var = maybe + else: + # for datasets, grab the first one for simplicity + for var in obj.data_vars.values(): + if maybe := var.attrs.get("grid_mapping"): + if maybe in obj.coords: + # make sure it exists and is not an out-of-date attribute + grid_mapping_var = maybe + break + if grid_mapping_var is None and "spatial_ref" in obj.coords: + # hardcode this + grid_mapping_var = "spatial_ref" + if grid_mapping_var is not None: + return obj[grid_mapping_var] + return None + + +def get_affine( + obj: xr.Dataset | xr.DataArray, *, x_dim="x", y_dim="y", clear_transform: bool = False +) -> Affine: + """ + Grabs an affine transform from an Xarray object. + + This method will first look for the ``"GeoTransform"`` attribute on the grid mapping + variable (see ``get_grid_mapping_var``). If not found, it will auto-guess the transform + from the 1D coordinate variables named ``x_dim`` and ``y_dim``. + + Parameters + ---------- + obj: xr.DataArray or xr.Dataset + x_dim: str, optional + Name of the X dimension coordinate variable. + y_dim: str, optional + Name of the Y dimension coordinate variable. + clear_transform: bool + Whether to delete the ``GeoTransform`` attribute if detected.
+ + Returns + ------- + affine: Affine + + Raises + ------ + ValueError + If the transform must be guessed and a coordinate variable is not 1D + or has fewer than two elements. + """ + grid_mapping_var = get_grid_mapping_var(obj) + if grid_mapping_var is not None and (transform := grid_mapping_var.attrs.get("GeoTransform")): + if clear_transform: + del grid_mapping_var.attrs["GeoTransform"] + # split() rather than split(" ") so repeated or padding whitespace does not yield empty fields. + return Affine.from_gdal(*map(float, transform.split())) + else: + x = obj.coords[x_dim] + y = obj.coords[y_dim] + if x.ndim != 1: + raise ValueError(f"Coordinate variable {x_dim=!r} must be 1D.") + if y.ndim != 1: + raise ValueError(f"Coordinate variable {y_dim=!r} must be 1D.") + if x.size < 2 or y.size < 2: + raise ValueError( + f"Coordinate variables {x_dim=!r} and {y_dim=!r} must each have at least two values to infer the grid spacing." + ) + dx = (x[1] - x[0]).item() + dy = (y[1] - y[0]).item() + # Treat coordinate values as cell centers: the origin is half a step before the first + # coordinate along each axis, whatever the sign of the step. + return Affine.translation(x[0].item() - dx / 2, y[0].item() - dy / 2) * Affine.scale(dx, dy) diff --git a/tests/geometry_mask_snapshot.nc b/tests/geometry_mask_snapshot.nc index 8b9f14d..f817bd4 100644 Binary files a/tests/geometry_mask_snapshot.nc and b/tests/geometry_mask_snapshot.nc differ diff --git a/tests/rasterize_snapshot.nc b/tests/rasterize_snapshot.nc index 17d649a..9f95128 100644 Binary files a/tests/rasterize_snapshot.nc and b/tests/rasterize_snapshot.nc differ diff --git a/tests/test_exact.py b/tests/test_exact.py index f615a2c..17c20a2 100644 --- a/tests/test_exact.py +++ b/tests/test_exact.py @@ -9,6 +9,7 @@ import sparse import xarray as xr import xarray.testing as xrt +import xproj # noqa from exactextract import exact_extract from hypothesis import example, given, settings from xarray.tests import raise_if_dask_computes @@ -16,7 +17,7 @@ from rasterix.rasterize.exact import CoverageWeights, coverage, xy_to_raster_source dataset = xr.tutorial.open_dataset("eraint_uvz").rename({"latitude": "y", "longitude": "x"}) -dataset = dataset.rio.write_crs("epsg:4326") +dataset = dataset.proj.assign_crs(spatial_ref="epsg:4326") world = gpd.read_file(geodatasets.get_path("naturalearth land")) XSIZE = dataset.x.size YSIZE = dataset.y.size diff --git a/tests/test_raster_index.py b/tests/test_raster_index.py index 4c42826..8b5873d 100644 --- 
a/tests/test_raster_index.py +++ b/tests/test_raster_index.py @@ -6,9 +6,10 @@ import rioxarray # noqa import xarray as xr from affine import Affine -from xarray.testing import assert_identical +from xarray.testing import assert_equal, assert_identical from rasterix import RasterIndex, assign_index +from rasterix.utils import get_grid_mapping_var CRS_ATTRS = pyproj.CRS.from_epsg(4326).to_cf() @@ -20,6 +21,32 @@ def dataset_from_transform(transform: str) -> xr.Dataset: ).pipe(assign_index) +def test_grid_mapping_var(): + obj = xr.DataArray() + assert get_grid_mapping_var(obj) is None + + obj = xr.Dataset() + assert get_grid_mapping_var(obj) is None + + obj = xr.DataArray(attrs={"grid_mapping": "spatial_ref"}) + assert get_grid_mapping_var(obj) is None + + obj = xr.DataArray(attrs={"grid_mapping": "spatial_ref"}, coords={"spatial_ref": 0}) + assert_identical(get_grid_mapping_var(obj), obj["spatial_ref"]) + + obj = xr.Dataset({"foo": ((), 0, {"grid_mapping": "spatial_ref"})}) + assert get_grid_mapping_var(obj) is None + + obj = xr.Dataset( + { + "foo": ((), 0, {"grid_mapping": "spatial_ref_0"}), + "zoo": ((), 0, {"grid_mapping": "spatial_ref_1"}), + }, + coords={"spatial_ref_1": 0}, + ) + assert_identical(get_grid_mapping_var(obj), obj["spatial_ref_1"]) + + def test_set_xindex() -> None: coords = xr.Coordinates(coords={"x": np.arange(0.5, 12.5), "y": np.arange(0.5, 10.5)}, indexes={}) ds = xr.Dataset(coords=coords) @@ -56,18 +83,17 @@ def test_raster_index_properties(): def test_sel_slice(): ds = xr.Dataset({"foo": (("y", "x"), np.ones((10, 12)))}) transform = Affine.identity() - ds = ds.rio.write_transform(transform) + ds.coords["spatial_ref"] = ((), 0, {"GeoTransform": " ".join(map(str, transform.to_gdal()))}) ds = assign_index(ds) - + assert "GeoTransform" not in ds.spatial_ref.attrs assert ds.xindexes["x"].transform() == transform actual = ds.sel(x=slice(4), y=slice(3, 5)) assert isinstance(actual.xindexes["x"], RasterIndex) assert isinstance(actual.xindexes["y"], 
RasterIndex) - actual_transform = actual.xindexes["x"].transform() - assert actual_transform == actual.rio.transform() - assert actual_transform == (transform * Affine.translation(0, 3)) + actual_transform = actual.xindexes["x"].transform() + assert actual_transform == transform * Affine.translation(0, 3) def test_crs() -> None: @@ -363,3 +389,29 @@ def test_repr() -> None: index3 = RasterIndex.from_transform(Affine.identity(), width=12, height=10, crs="epsg:31370") assert repr(index3).startswith("RasterIndex(crs=EPSG:31370)") + + +def test_assign_index_cant_guess_error(): + ds = xr.Dataset( + {"sst": (("time", "lat", "lon"), np.ones((1, 89, 180)))}, + coords={"lat": np.arange(88, -89, -2), "lon": np.arange(0, 360, 2)}, + ) + with pytest.raises(ValueError, match="guess"): + assign_index(ds) + + +def test_wraparound_indexing_longitude(): + ds = xr.Dataset( + {"sst": (("time", "lat", "lon"), np.random.random((1, 89, 180)))}, + coords={ + "lat": ("lat", np.arange(88, -89, -2), {"axis": "Y"}), + "lon": ("lon", np.arange(0, 360, 2), {"axis": "X"}), + }, + ) + indexed = ds.pipe(assign_index) + # We lose existing attrs when calling ``assign_index``. 
+ assert_equal(ds.sel(lon=[220, 240]), indexed.sel(lon=[-140, -120])) + assert_equal(ds.sel(lon=220), indexed.sel(lon=-140)) + assert_equal(ds.sel(lon=slice(220, 240)), indexed.sel(lon=slice(-140, -120))) + assert_equal(ds.sel(lon=slice(240, 220)), indexed.sel(lon=slice(-120, -140))) + # assert_equal(ds.sel(lon=[220]), indexed.sel(lon=[-140])) # FIXME diff --git a/tests/test_rasterize.py b/tests/test_rasterize.py index 2ba5f95..f2d4c84 100644 --- a/tests/test_rasterize.py +++ b/tests/test_rasterize.py @@ -4,30 +4,37 @@ import pytest import rioxarray # noqa import xarray as xr +import xproj # noqa from xarray.tests import raise_if_dask_computes from rasterix.rasterize.rasterio import geometry_mask, rasterize -@pytest.mark.parametrize("clip", [True, False]) -def test_rasterize(clip): +@pytest.fixture +def dataset(): + ds = xr.tutorial.open_dataset("eraint_uvz") + ds = ds.proj.assign_crs(spatial_ref="epsg:4326") + ds["spatial_ref"].attrs = ds.proj.crs.to_cf() + return ds + + +@pytest.mark.parametrize("clip", [False, True]) +def test_rasterize(clip, dataset): fname = "rasterize_snapshot.nc" try: snapshot = xr.load_dataarray(fname) except FileNotFoundError: - snapshot = xr.load_dataarray(f"./tests/{fname}") + fname = f"./tests/{fname}" + snapshot = xr.load_dataarray(fname) if clip: snapshot = snapshot.sel(latitude=slice(83.25, None)) - ds = xr.tutorial.open_dataset("eraint_uvz") - ds = ds.rio.write_crs("epsg:4326") world = gpd.read_file(geodatasets.get_path("naturalearth land")) - kwargs = dict(xdim="longitude", ydim="latitude", clip=clip) - rasterized = rasterize(ds, world[["geometry"]], **kwargs) + rasterized = rasterize(dataset, world[["geometry"]], **kwargs) xr.testing.assert_identical(rasterized, snapshot) - chunked = ds.chunk(latitude=119, longitude=-1) + chunked = dataset.chunk(latitude=119, longitude=-1) with raise_if_dask_computes(): drasterized = rasterize(chunked, world[["geometry"]], **kwargs) xr.testing.assert_identical(rasterized, drasterized) @@ -42,7 
+49,7 @@ def test_rasterize(clip): @pytest.mark.parametrize("invert", [False, True]) @pytest.mark.parametrize("clip", [False, True]) -def test_geometry_mask(clip, invert): +def test_geometry_mask(clip, invert, dataset): fname = "geometry_mask_snapshot.nc" try: snapshot = xr.load_dataarray(fname) @@ -53,15 +60,13 @@ def test_geometry_mask(clip, invert): if invert: snapshot = ~snapshot - ds = xr.tutorial.open_dataset("eraint_uvz") - ds = ds.rio.write_crs("epsg:4326") world = gpd.read_file(geodatasets.get_path("naturalearth land")) kwargs = dict(xdim="longitude", ydim="latitude", clip=clip, invert=invert) - rasterized = geometry_mask(ds, world[["geometry"]], **kwargs) + rasterized = geometry_mask(dataset, world[["geometry"]], **kwargs) xr.testing.assert_identical(rasterized, snapshot) - chunked = ds.chunk(latitude=119, longitude=-1) + chunked = dataset.chunk(latitude=119, longitude=-1) with raise_if_dask_computes(): drasterized = geometry_mask(chunked, world[["geometry"]], **kwargs) xr.testing.assert_identical(drasterized, snapshot)