Skip to content

Commit a917dad

Browse files
weiji14seisman
andauthored
Allow passing region to GMTBackendEntrypoint.open_dataset (#3932)
Support passing in a region as a Sequence [xmin, xmax, ymin, ymax] or ISO country code to `xarray.open_dataset` when using `engine="gmt"`. * Refactor _load_remote_dataset internals to use xr.load_dataarray Remove duplicated code calling GMT read, since `xr.load_dataarray(engine="gmt")` now works with region argument. * Update TypeError regex for test_xarray_backend_gmt_read_invalid_kind * Don't need to re-load GMTDataArray accessor info in GMTBackendEntrypoint GMTDataArrayAccessor info should already be loaded by calling `virtualfile_to_raster` which calls `self.read_virtualfile(vfname, kind=kind).contents.to_xarray()` that sets registration and gtype from the header. * Add doctest for load_dataarray with region argument * Remove @kwargs_to_strings from _load_remote_dataset * Sort list of source files alphabetically --------- Co-authored-by: Dongdong Tian <seisman.info@gmail.com>
1 parent 12f9776 commit a917dad

File tree

3 files changed

+82
-34
lines changed

3 files changed

+82
-34
lines changed

pygmt/datasets/load_remote_dataset.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
from typing import Any, Literal, NamedTuple
88

99
import xarray as xr
10-
from pygmt.clib import Session
1110
from pygmt.exceptions import GMTInvalidInput
12-
from pygmt.helpers import build_arg_list, kwargs_to_strings
13-
from pygmt.src import which
1411

1512
with contextlib.suppress(ImportError):
1613
# rioxarray is needed to register the rio accessor
@@ -502,7 +499,6 @@ class GMTRemoteDataset(NamedTuple):
502499
}
503500

504501

505-
@kwargs_to_strings(region="sequence")
506502
def _load_remote_dataset(
507503
name: str,
508504
prefix: str,
@@ -581,23 +577,9 @@ def _load_remote_dataset(
581577
raise GMTInvalidInput(msg)
582578

583579
fname = f"@{prefix}_{resolution}_{reg}"
584-
kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[dataset.kind]}
585-
with Session() as lib:
586-
with lib.virtualfile_out(kind=dataset.kind) as voutgrd:
587-
lib.call_module(
588-
module="read",
589-
args=[fname, voutgrd, *build_arg_list(kwdict)],
590-
)
591-
grid = lib.virtualfile_to_raster(
592-
kind=dataset.kind, outgrid=None, vfname=voutgrd
593-
)
594-
595-
# Full path to the grid
596-
source: str | list = which(fname, verbose="q")
597-
if resinfo.tiled:
598-
source = sorted(source)[0] # get first grid for tiled grids
599-
# Manually add source to xarray.DataArray encoding to make the GMT accessors work.
600-
grid.encoding["source"] = source
580+
grid = xr.load_dataarray(
581+
fname, engine="gmt", raster_kind=dataset.kind, region=region
582+
)
601583

602584
# Add some metadata to the grid
603585
grid.attrs["description"] = dataset.description

pygmt/tests/test_xarray_backend.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def test_xarray_backend_load_dataarray():
4040

4141
def test_xarray_backend_gmt_open_nc_grid():
4242
"""
43-
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF
44-
grids.
43+
Ensure that passing engine='gmt' to xarray.open_dataarray works to open a netCDF
44+
grid.
4545
"""
4646
with xr.open_dataarray(
4747
"@static_earth_relief.nc", engine="gmt", raster_kind="grid"
@@ -52,10 +52,29 @@ def test_xarray_backend_gmt_open_nc_grid():
5252
assert da.gmt.registration is GridRegistration.PIXEL
5353

5454

55+
def test_xarray_backend_gmt_open_nc_grid_with_region_bbox():
56+
"""
57+
Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray
58+
works to open a netCDF grid over a specific bounding box.
59+
"""
60+
with xr.open_dataarray(
61+
"@static_earth_relief.nc",
62+
engine="gmt",
63+
raster_kind="grid",
64+
region=[-52, -48, -18, -12],
65+
) as da:
66+
assert da.sizes == {"lat": 6, "lon": 4}
67+
npt.assert_allclose(da.lat, [-17.5, -16.5, -15.5, -14.5, -13.5, -12.5])
68+
npt.assert_allclose(da.lon, [-51.5, -50.5, -49.5, -48.5])
69+
assert da.dtype == "float32"
70+
assert da.gmt.gtype is GridType.GEOGRAPHIC
71+
assert da.gmt.registration is GridRegistration.PIXEL
72+
73+
5574
def test_xarray_backend_gmt_open_tif_image():
5675
"""
57-
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF
58-
images.
76+
Ensure that passing engine='gmt' to xarray.open_dataarray works to open a GeoTIFF
77+
image.
5978
"""
6079
with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da:
6180
assert da.sizes == {"band": 3, "y": 180, "x": 360}
@@ -64,6 +83,22 @@ def test_xarray_backend_gmt_open_tif_image():
6483
assert da.gmt.registration is GridRegistration.PIXEL
6584

6685

86+
def test_xarray_backend_gmt_open_tif_image_with_region_iso():
87+
"""
88+
Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray
89+
works to open a GeoTIFF image over a specific ISO country code border.
90+
"""
91+
with xr.open_dataarray(
92+
"@earth_day_01d", engine="gmt", raster_kind="image", region="BN"
93+
) as da:
94+
assert da.sizes == {"band": 3, "lat": 2, "lon": 2}
95+
npt.assert_allclose(da.lat, [5.5, 4.5])
96+
npt.assert_allclose(da.lon, [114.5, 115.5])
97+
assert da.dtype == "uint8"
98+
assert da.gmt.gtype is GridType.GEOGRAPHIC
99+
assert da.gmt.registration is GridRegistration.PIXEL
100+
101+
67102
def test_xarray_backend_gmt_load_grd_grid():
68103
"""
69104
Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD
@@ -88,9 +123,7 @@ def test_xarray_backend_gmt_read_invalid_kind():
88123
"""
89124
with pytest.raises(
90125
TypeError,
91-
match=re.escape(
92-
"GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'"
93-
),
126+
match=re.escape("missing a required argument: 'raster_kind'"),
94127
):
95128
xr.open_dataarray("nokind.nc", engine="gmt")
96129

pygmt/xarray/backend.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
An xarray backend for reading raster grid/image files using the 'gmt' engine.
33
"""
44

5+
from collections.abc import Sequence
56
from typing import Literal
67

78
import xarray as xr
89
from pygmt._typing import PathLike
910
from pygmt.clib import Session
1011
from pygmt.exceptions import GMTInvalidInput
11-
from pygmt.helpers import build_arg_list
12+
from pygmt.helpers import build_arg_list, kwargs_to_strings
1213
from pygmt.src.which import which
1314
from xarray.backends import BackendEntrypoint
1415

@@ -30,6 +31,9 @@ class GMTBackendEntrypoint(BackendEntrypoint):
3031
- ``"grid"``: for reading single-band raster grids
3132
- ``"image"``: for reading multi-band raster images
3233
34+
Optionally, you can pass in a ``region`` in the form of a sequence [*xmin*, *xmax*,
35+
*ymin*, *ymax*] or an ISO country code.
36+
3337
Examples
3438
--------
3539
Read a single-band netCDF file using ``raster_kind="grid"``
@@ -68,18 +72,45 @@ class GMTBackendEntrypoint(BackendEntrypoint):
6872
* band (band) uint8... 1 2 3
6973
Attributes:...
7074
long_name: z
75+
76+
Load a single-band netCDF file using ``raster_kind="grid"`` over a bounding box
77+
region.
78+
79+
>>> da_grid = xr.load_dataarray(
80+
... "@tut_bathy.nc", engine="gmt", raster_kind="grid", region=[-64, -62, 32, 33]
81+
... )
82+
>>> da_grid
83+
<xarray.DataArray 'z' (lat: 13, lon: 25)>...
84+
array([[-4369., -4587., -4469., -4409., -4587., -4505., -4403., -4405.,
85+
-4466., -4595., -4609., -4608., -4606., -4607., -4607., -4597.,
86+
...
87+
-4667., -4642., -4677., -4795., -4797., -4800., -4803., -4818.,
88+
-4820.]], dtype=float32)
89+
Coordinates:
90+
* lat (lat) float64... 32.0 32.08 32.17 32.25 ... 32.83 32.92 33.0
91+
* lon (lon) float64... -64.0 -63.92 -63.83 ... -62.17 -62.08 -62.0
92+
Attributes:...
93+
Conventions: CF-1.7
94+
title: ETOPO5 global topography
95+
history: grdreformat -fg bermuda.grd bermuda.nc=ns
96+
description: /home/elepaio5/data/grids/etopo5.i2
97+
actual_range: [-4968. -4315.]
98+
long_name: Topography
99+
units: m
71100
"""
72101

73102
description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT."
74-
open_dataset_parameters = ("filename_or_obj", "raster_kind")
103+
open_dataset_parameters = ("filename_or_obj", "raster_kind", "region")
75104
url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html"
76105

106+
@kwargs_to_strings(region="sequence")
77107
def open_dataset( # type: ignore[override]
78108
self,
79109
filename_or_obj: PathLike,
80110
*,
81111
drop_variables=None, # noqa: ARG002
82112
raster_kind: Literal["grid", "image"],
113+
region: Sequence[float] | str | None = None,
83114
# other backend specific keyword arguments
84115
# `chunks` and `cache` DO NOT go here, they are handled by xarray
85116
) -> xr.Dataset:
@@ -94,14 +125,17 @@ def open_dataset( # type: ignore[override]
94125
:gmt-docs:`reference/features.html#grid-file-format`.
95126
raster_kind
96127
Whether to read the file as a "grid" (single-band) or "image" (multi-band).
128+
region
129+
The subregion of the grid or image to load, in the form of a sequence
130+
[*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code.
97131
"""
98132
if raster_kind not in {"grid", "image"}:
99133
msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'."
100134
raise GMTInvalidInput(msg)
101135

102136
with Session() as lib:
103137
with lib.virtualfile_out(kind=raster_kind) as voutfile:
104-
kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]}
138+
kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[raster_kind]}
105139
lib.call_module(
106140
module="read",
107141
args=[filename_or_obj, voutfile, *build_arg_list(kwdict)],
@@ -111,9 +145,8 @@ def open_dataset( # type: ignore[override]
111145
vfname=voutfile, kind=raster_kind
112146
)
113147
# Add "source" encoding
114-
source = which(fname=filename_or_obj)
148+
source: str | list = which(fname=filename_or_obj, verbose="q")
115149
raster.encoding["source"] = (
116-
source[0] if isinstance(source, list) else source
150+
sorted(source)[0] if isinstance(source, list) else source
117151
)
118-
_ = raster.gmt # Load GMTDataArray accessor information
119152
return raster.to_dataset()

0 commit comments

Comments
 (0)