Skip to content

Commit 839caa1

Browse files
committed
Refactor again
1 parent 5e0b5d1 commit 839caa1

File tree

2 files changed

+23
-47
lines changed

2 files changed

+23
-47
lines changed

pygmt/clib/session.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,23 +1829,25 @@ def virtualfile_in(
18291829
... print(fout.read().strip())
18301830
<vector memory>: N = 3 <7/9> <4/6> <1/3>
18311831
"""
1832+
# Check if the combination of data, and x/y/z is valid.
1833+
if data is not None and any(v is not None for v in (x, y, z)):
1834+
msg = "Too much data. Use either data or x/y/z."
1835+
raise GMTInvalidInput(msg)
1836+
1837+
# Determine the kind of data.
18321838
kind = data_kind(data, required=required_data)
1833-
_validate_data_input(
1834-
data=data,
1835-
x=x,
1836-
y=y,
1837-
z=z,
1838-
required_z=required_z,
1839-
required_data=required_data,
1840-
kind=kind,
1841-
)
18421839

1840+
# Check if the kind of data is valid.
18431841
if check_kind:
18441842
valid_kinds = ("file", "arg") if required_data is False else ("file",)
1845-
if check_kind == "raster":
1846-
valid_kinds += ("grid", "image")
1847-
elif check_kind == "vector":
1848-
valid_kinds += ("empty", "matrix", "vectors", "geojson")
1843+
match check_kind:
1844+
case "raster":
1845+
valid_kinds += ("grid", "image")
1846+
case "vector":
1847+
valid_kinds += ("empty", "matrix", "vectors", "geojson")
1848+
case _:
1849+
msg = f"Unrecognized check_kind: {check_kind}."
1850+
raise GMTInvalidInput(msg)
18491851
if kind not in valid_kinds:
18501852
msg = f"Unrecognized data type for {check_kind}: {type(data)}."
18511853
raise GMTInvalidInput(msg)
@@ -1898,6 +1900,8 @@ def virtualfile_in(
18981900
_virtualfile_from = self.virtualfile_from_vectors
18991901
_data = data.T
19001902

1903+
_validate_data_input(data=_data, required_z=required_z, kind=kind)
1904+
19011905
# Finally create the virtualfile from the data, to be passed into GMT
19021906
file_context = _virtualfile_from(_data)
19031907
return file_context

pygmt/helpers/utils.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pathlib import Path
1616
from typing import Any, Literal
1717

18-
import numpy as np
1918
import xarray as xr
2019
from pygmt.encodings import charset
2120
from pygmt.exceptions import GMTInvalidInput
@@ -46,15 +45,7 @@
4645
]
4746

4847

49-
def _validate_data_input(
50-
data=None,
51-
x=None,
52-
y=None,
53-
z=None,
54-
required_z: bool = False,
55-
required_data: bool = True,
56-
kind: Kind | None = None,
57-
) -> None:
48+
def _validate_data_input(data: Any, kind: Kind, required_z: bool = False) -> None:
5849
"""
5950
Check if the combination of data/x/y/z is valid.
6051
@@ -126,29 +117,18 @@ def _validate_data_input(
126117
GMTInvalidInput
127118
If the data input is not valid.
128119
"""
129-
# Check if too much data is provided.
130-
if data is not None and any(v is not None for v in (x, y, z)):
131-
msg = "Too much data. Use either data or x/y/z."
132-
raise GMTInvalidInput(msg)
133-
134-
# Determine the data kind if not provided.
135-
kind = kind or data_kind(data, required=required_data)
136-
137120
# Determine the required number of columns based on the required_z flag.
138121
required_cols = 3 if required_z else 1
139122

140123
# Check based on the data kind.
141124
match kind:
142-
case "empty": # data is given via a series vectors like x/y/z.
143-
if x is None and y is None:
144-
msg = "No input data provided."
125+
case "empty": # data = [x, y, z]
126+
if required_z and len(data) < 3:
127+
msg = "Must provide x, y, and z."
145128
raise GMTInvalidInput(msg)
146-
if x is None or y is None:
129+
if any(v is None for v in data):
147130
msg = "Must provide both x and y."
148131
raise GMTInvalidInput(msg)
149-
if required_z and z is None:
150-
msg = "Must provide x, y, and z."
151-
raise GMTInvalidInput(msg)
152132
case "matrix": # 2-D numpy.ndarray
153133
if (actual_cols := data.shape[1]) < required_cols:
154134
msg = (
@@ -157,16 +137,8 @@ def _validate_data_input(
157137
)
158138
raise GMTInvalidInput(msg)
159139
case "vectors":
160-
# The if-else block should match the codes in the virtualfile_in function.
161-
if hasattr(data, "items") and not hasattr(data, "to_frame"):
162-
# Dict, pandas.DataFrame, or xarray.Dataset, but not pd.Series.
163-
_data = [array for _, array in data.items()]
164-
else:
165-
# Python list, tuple, numpy.ndarray, and pandas.Series types
166-
_data = np.atleast_2d(np.asanyarray(data).T)
167-
168140
# Check if the number of columns is sufficient.
169-
if (actual_cols := len(_data)) < required_cols:
141+
if (actual_cols := len(data)) < required_cols:
170142
msg = (
171143
f"Need at least {required_cols} columns but {actual_cols} "
172144
"column(s) are given."

0 commit comments

Comments
 (0)