Skip to content

BREAKING: Raise GMTValueError exception for invalid values part 2 #3998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pandas as pd
import xarray as xr
from packaging.version import Version
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTValueError


def dataarray_to_matrix(
Expand Down Expand Up @@ -49,7 +49,7 @@ def dataarray_to_matrix(

Raises
------
GMTInvalidInput
GMTValueError
If the grid has more than two dimensions or variable grid spacing.

Examples
Expand Down Expand Up @@ -92,8 +92,11 @@ def dataarray_to_matrix(
[2.0, 2.0]
"""
if len(grid.dims) != 2:
msg = f"Invalid number of grid dimensions 'len({grid.dims})'. Must be 2."
raise GMTInvalidInput(msg)
raise GMTValueError(
len(grid.dims),
description="number of grid dimensions",
reason="The grid must be 2-D.",
)

# Extract region and inc from the grid
region, inc = [], []
Expand All @@ -113,8 +116,11 @@ def dataarray_to_matrix(
)
warnings.warn(msg, category=RuntimeWarning, stacklevel=2)
if coord_inc == 0:
msg = f"Grid has a zero increment in the '{dim}' dimension."
raise GMTInvalidInput(msg)
raise GMTValueError(
coord_inc,
description="grid increment",
reason=f"Grid has a zero increment in the '{dim}' dimension.",
)
region.extend(
[
coord.min() - coord_inc / 2 * grid.gmt.registration,
Expand Down
12 changes: 8 additions & 4 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,9 +913,10 @@ def _check_dtype_and_dim(self, array: np.ndarray, ndim: int) -> int:

Raises
------
GMTValueError
If the array has the wrong number of dimensions.
GMTInvalidInput
If the array has the wrong number of dimensions or is an unsupported data
type.
If the array is an unsupported data type.

Examples
--------
Expand All @@ -933,8 +934,11 @@ def _check_dtype_and_dim(self, array: np.ndarray, ndim: int) -> int:
"""
# Check that the array has the given number of dimensions.
if array.ndim != ndim:
msg = f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
raise GMTInvalidInput(msg)
raise GMTValueError(
array.ndim,
description="array dimension",
reason=f"Expected a numpy {ndim}-D array, got {array.ndim}-D.",
)

# 1-D arrays can be numeric or text, 2-D arrays can only be numeric.
valid_dtypes = DTYPES if ndim == 1 else DTYPES_NUMERIC
Expand Down
5 changes: 2 additions & 3 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import xarray as xr
from pygmt._typing import PathLike
from pygmt.encodings import charset
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTInvalidInput, GMTValueError

# Type hints for the list of encodings supported by PyGMT.
Encoding = Literal[
Expand Down Expand Up @@ -597,8 +597,7 @@ def build_arg_list( # noqa: PLR0912
or os.fspath(outfile) in {"", ".", ".."}
or os.fspath(outfile).endswith(("/", "\\"))
):
msg = f"Invalid output file name '{outfile}'."
raise GMTInvalidInput(msg)
raise GMTValueError(outfile, description="output file name")
gmt_args.append(f"->{os.fspath(outfile)}")
return gmt_args

Expand Down
14 changes: 9 additions & 5 deletions pygmt/helpers/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Literal

from pygmt._typing import PathLike
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTInvalidInput, GMTValueError


def validate_output_table_type(
Expand Down Expand Up @@ -39,7 +39,7 @@ def validate_output_table_type(
>>> validate_output_table_type(output_type="invalid-type")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must specify 'output_type' either as 'file', ...
pygmt....GMTValueError: ...: 'invalid-type'. Expected one of: ...
>>> validate_output_table_type("file", outfile=None)
Traceback (most recent call last):
...
Expand All @@ -49,9 +49,13 @@ def validate_output_table_type(
... assert len(w) == 1
'file'
"""
if output_type not in {"file", "numpy", "pandas"}:
msg = "Must specify 'output_type' either as 'file', 'numpy', or 'pandas'."
raise GMTInvalidInput(msg)
_valids = {"pandas", "numpy", "file"}
if output_type not in _valids:
raise GMTValueError(
output_type,
description="value for parameter 'output_type'",
choices=_valids,
)
if output_type == "file" and outfile is None:
msg = "Must specify 'outfile' for output_type='file'."
raise GMTInvalidInput(msg)
Expand Down
12 changes: 4 additions & 8 deletions pygmt/src/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
from typing import Any, ClassVar, Literal

from pygmt.exceptions import GMTInvalidInput, GMTValueError
from pygmt.exceptions import GMTValueError
from pygmt.src.which import which


Expand Down Expand Up @@ -125,7 +125,7 @@ class _FocalMechanismConvention:
>>> _FocalMechanismConvention.from_params(["strike", "dip", "rake"])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Fail to determine focal mechanism convention...
pygmt.exceptions.GMTValueError: Invalid focal mechanism parameters: ...
"""

# Mapping of focal mechanism conventions to their parameters.
Expand Down Expand Up @@ -236,18 +236,14 @@ def from_params(

Raises
------
GMTInvalidInput
GMTValueError
If the focal mechanism convention cannot be determined from the given
parameters.
"""
for convention, param_list in cls._params.items():
if set(param_list).issubset(set(params)):
return cls(convention, component=component)
msg = (
"Fail to determine focal mechanism convention from the given parameters: "
f"{', '.join(params)}."
)
raise GMTInvalidInput(msg)
raise GMTValueError(params, description="focal mechanism parameters")


def _parse_coastline_resolution(
Expand Down
9 changes: 5 additions & 4 deletions pygmt/src/grd2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray as xr
from pygmt._typing import PathLike
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTValueError
from pygmt.helpers import (
build_arg_list,
fmt_docstring,
Expand Down Expand Up @@ -145,10 +145,11 @@ def grd2xyz(
output_type = validate_output_table_type(output_type, outfile=outfile)

if kwargs.get("o") is not None and output_type == "pandas":
msg = (
"If 'outcols' is specified, 'output_type' must be either 'numpy' or 'file'."
raise GMTValueError(
output_type,
description="value for parameter 'output_type'",
reason="Expected one of: 'numpy', 'file' if 'outcols' is specified.",
)
raise GMTInvalidInput(msg)
# Set the default column names for the pandas DataFrame header.
column_names: list[str] = ["x", "y", "z"]
# Let output pandas column names match input DataArray dimension names
Expand Down
11 changes: 6 additions & 5 deletions pygmt/src/hlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Sequence

import numpy as np
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTValueError

__doctest_skip__ = ["hlines"]

Expand Down Expand Up @@ -99,11 +99,12 @@ def hlines(

# Check if xmin/xmax are scalars or have the expected length.
if _xmin.size not in {1, nlines} or _xmax.size not in {1, nlines}:
msg = (
f"'xmin' and 'xmax' are expected to be scalars or have lengths '{nlines}', "
f"but lengths '{_xmin.size}' and '{_xmax.size}' are given."
_value = f"{_xmin.size}, {_xmax.size}"
raise GMTValueError(
_value,
description="size for 'xmin'/'xmax'",
reason=f"'xmin'/'xmax' are expected to be scalars or have lengths {nlines!r}.",
)
raise GMTInvalidInput(msg)

# Repeat xmin/xmax to match the length of y if they are scalars.
if nlines != 1:
Expand Down
9 changes: 6 additions & 3 deletions pygmt/src/meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from pygmt._typing import PathLike, TableLike
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTInvalidInput, GMTValueError
from pygmt.helpers import (
build_arg_list,
data_kind,
Expand Down Expand Up @@ -66,8 +66,11 @@ def _preprocess_spec(spec, colnames, override_cols):
}
ndiff = spec.shape[1] - len(colnames)
if ndiff not in extra_cols:
msg = f"Input array must have {len(colnames)} or two/three more columns."
raise GMTInvalidInput(msg)
raise GMTValueError(
spec.shape[1],
description="input array shape",
reason=f"Input array must have {len(colnames)} or two/three more columns.",
)
spec = dict(zip([*colnames, *extra_cols[ndiff]], spec.T, strict=False))

# Now, the input data is a dict or an ASCII file.
Expand Down
11 changes: 6 additions & 5 deletions pygmt/src/vlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Sequence

import numpy as np
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTValueError

__doctest_skip__ = ["vlines"]

Expand Down Expand Up @@ -99,11 +99,12 @@ def vlines(

# Check if ymin/ymax are scalars or have the expected length.
if _ymin.size not in {1, nlines} or _ymax.size not in {1, nlines}:
msg = (
f"'ymin' and 'ymax' are expected to be scalars or have lengths '{nlines}', "
f"but lengths '{_ymin.size}' and '{_ymax.size}' are given."
_value = f"{_ymin.size}, {_ymax.size}"
raise GMTValueError(
_value,
description="size for 'ymin'/'ymax'",
reason=f"'ymin'/'ymax' are expected to be scalars or have lengths {nlines!r}.",
)
raise GMTInvalidInput(msg)

# Repeat ymin/ymax to match the length of x if they are scalars.
if nlines != 1:
Expand Down
8 changes: 4 additions & 4 deletions pygmt/tests/test_clib_dataarray_to_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import xarray as xr
from pygmt.clib.conversion import dataarray_to_matrix
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTValueError


@pytest.mark.benchmark
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_dataarray_to_matrix_dims_fails():
y = np.arange(12)
z = np.arange(10)
grid = xr.DataArray(data, coords=[("z", z), ("y", y), ("x", x)])
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
dataarray_to_matrix(grid)


Expand All @@ -107,11 +107,11 @@ def test_dataarray_to_matrix_zero_inc_fails():
x = np.linspace(0, 1, 5)
y = np.zeros_like(x)
grid = xr.DataArray(data, coords=[("y", y), ("x", x)])
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
dataarray_to_matrix(grid)

y = np.linspace(0, 1, 5)
x = np.zeros_like(x)
grid = xr.DataArray(data, coords=[("y", y), ("x", x)])
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
dataarray_to_matrix(grid)
4 changes: 2 additions & 2 deletions pygmt/tests/test_clib_put_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from pygmt import clib
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.exceptions import GMTCLibError, GMTInvalidInput
from pygmt.exceptions import GMTCLibError, GMTInvalidInput, GMTValueError
from pygmt.helpers import GMTTempFile


Expand Down Expand Up @@ -257,5 +257,5 @@ def test_put_vector_2d_fails():
dim=[1, 6, 0, 0], # ncolumns, nrows, dtype, unused
)
data = np.array([[37, 12, 556], [37, 12, 556]], dtype=np.int32)
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
lib.put_vector(dataset, column=0, vector=data)
6 changes: 3 additions & 3 deletions pygmt/tests/test_grd2cpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest
from pygmt import Figure, grd2cpt
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTInvalidInput, GMTValueError
from pygmt.helpers import GMTTempFile
from pygmt.helpers.testing import load_static_earth_relief

Expand Down Expand Up @@ -37,15 +37,15 @@ def test_grd2cpt_blank_output(grid):
"""
Use incorrect setting by passing in blank file name to output parameter.
"""
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
grd2cpt(grid=grid, output="")


def test_grd2cpt_invalid_output(grid):
"""
Use incorrect setting by passing in invalid type to output parameter.
"""
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
grd2cpt(grid=grid, output=["some.cpt"])


Expand Down
4 changes: 2 additions & 2 deletions pygmt/tests/test_grd2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import pytest
from pygmt import grd2xyz
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTValueError
from pygmt.helpers.testing import load_static_earth_relief


Expand Down Expand Up @@ -38,5 +38,5 @@ def test_grd2xyz_pandas_output_with_o(grid):
"""
Test that grd2xyz fails when outcols is set and output_type is set to 'pandas'.
"""
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
grd2xyz(grid=grid, output_type="pandas", outcols="2")
4 changes: 2 additions & 2 deletions pygmt/tests/test_grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import xarray as xr
from pygmt import grdhisteq
from pygmt.enums import GridRegistration, GridType
from pygmt.exceptions import GMTInvalidInput
from pygmt.exceptions import GMTInvalidInput, GMTValueError
from pygmt.helpers import GMTTempFile
from pygmt.helpers.testing import load_static_earth_relief

Expand Down Expand Up @@ -136,7 +136,7 @@ def test_compute_bins_invalid_format(grid):
"""
Test that compute_bins fails with incorrect format.
"""
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
grdhisteq.compute_bins(grid=grid, output_type=1)
with pytest.raises(GMTInvalidInput):
grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c")
4 changes: 2 additions & 2 deletions pygmt/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
import xarray as xr
from pygmt import Figure
from pygmt.exceptions import GMTInvalidInput, GMTValueError
from pygmt.exceptions import GMTValueError
from pygmt.helpers import (
GMTTempFile,
args_in_kwargs,
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_build_arg_list_invalid_output(outfile):
"""
Test that build_arg_list raises an exception when output file name is invalid.
"""
with pytest.raises(GMTInvalidInput):
with pytest.raises(GMTValueError):
build_arg_list({}, outfile=outfile)


Expand Down
Loading
Loading