Skip to content

Commit 1f6cdcb

Browse files
committed
Fix the _validate_data_input function after refacting data_kind
1 parent 419a080 commit 1f6cdcb

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

pygmt/helpers/utils.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,31 @@ def _validate_data_input(
9696
if data is None: # data is None
9797
if x is None and y is None: # both x and y are None
9898
if required_data: # data is not optional
99-
raise GMTInvalidInput("No input data provided.")
99+
msg = "No input data provided."
100+
raise GMTInvalidInput(msg)
100101
elif x is None or y is None: # either x or y is None
101-
raise GMTInvalidInput("Must provide both x and y.")
102+
msg = "Must provide both x and y."
103+
raise GMTInvalidInput(msg)
102104
if required_z and z is None: # both x and y are not None, now check z
103-
raise GMTInvalidInput("Must provide x, y, and z.")
105+
msg = "Must provide x, y, and z."
106+
raise GMTInvalidInput(msg)
104107
else: # data is not None
105108
if x is not None or y is not None or z is not None:
106-
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
107-
# For 'matrix' kind, check if data has the required z column
108-
if kind == "matrix" and required_z:
109-
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
110-
if len(data.shape) == 1 and data.shape[0] < 3:
111-
raise GMTInvalidInput("data must provide x, y, and z columns.")
112-
if len(data.shape) > 1 and data.shape[1] < 3:
113-
raise GMTInvalidInput("data must provide x, y, and z columns.")
114-
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
115-
raise GMTInvalidInput("data must provide x, y, and z columns.")
109+
msg = "Too much data. Use either data or x/y/z."
110+
raise GMTInvalidInput(msg)
111+
# check if data has the required z column
112+
if required_z:
113+
msg = "data must provide x, y, and z columns."
114+
if kind == "matrix" and data.shape[1] < 3:
115+
raise GMTInvalidInput(msg)
116+
if kind == "vectors":
117+
if hasattr(data, "shape") and (
118+
(len(data.shape) == 1 and data.shape[0] < 3)
119+
or (len(data.shape) > 1 and data.shape[1] < 3)
120+
): # np.ndarray or pd.DataFrame
121+
raise GMTInvalidInput(msg)
122+
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
123+
raise GMTInvalidInput(msg)
116124

117125

118126
def _check_encoding(

0 commit comments

Comments
 (0)