Skip to content

Commit 5bcbf70

Browse files
Add typing to test_plot.py (#8889)
* Update pyproject.toml * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * Update test_plot.py * raise ValueError if too many dims are requested --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 56182f7 commit 5bcbf70

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ module = [
175175
"xarray.tests.test_merge",
176176
"xarray.tests.test_missing",
177177
"xarray.tests.test_parallelcompat",
178-
"xarray.tests.test_plot",
179178
"xarray.tests.test_sparse",
180179
"xarray.tests.test_ufuncs",
181180
"xarray.tests.test_units",

xarray/tests/test_plot.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import contextlib
44
import inspect
55
import math
6-
from collections.abc import Hashable
6+
from collections.abc import Generator, Hashable
77
from copy import copy
88
from datetime import date, datetime, timedelta
99
from typing import Any, Callable, Literal
@@ -85,52 +85,54 @@ def test_all_figures_closed():
8585

8686
@pytest.mark.flaky
8787
@pytest.mark.skip(reason="maybe flaky")
88-
def text_in_fig():
88+
def text_in_fig() -> set[str]:
8989
"""
9090
Return the set of all text in the figure
9191
"""
92-
return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)}
92+
return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
9393

9494

95-
def find_possible_colorbars():
95+
def find_possible_colorbars() -> list[mpl.collections.QuadMesh]:
9696
# nb. this function also matches meshes from pcolormesh
97-
return plt.gcf().findobj(mpl.collections.QuadMesh)
97+
return plt.gcf().findobj(mpl.collections.QuadMesh) # type: ignore[return-value] # mpl error?
9898

9999

100-
def substring_in_axes(substring, ax):
100+
def substring_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
101101
"""
102102
Return True if a substring is found anywhere in an axes
103103
"""
104-
alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)}
104+
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
105105
for txt in alltxt:
106106
if substring in txt:
107107
return True
108108
return False
109109

110110

111-
def substring_not_in_axes(substring, ax):
111+
def substring_not_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
112112
"""
113113
Return True if a substring is not found anywhere in an axes
114114
"""
115-
alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)}
115+
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
116116
check = [(substring not in txt) for txt in alltxt]
117117
return all(check)
118118

119119

120-
def property_in_axes_text(property, property_str, target_txt, ax):
120+
def property_in_axes_text(
121+
property, property_str, target_txt, ax: mpl.axes.Axes
122+
) -> bool:
121123
"""
122124
Return True if the specified text in an axes
123125
has the property assigned to property_str
124126
"""
125-
alltxt = ax.findobj(mpl.text.Text)
127+
alltxt: list[mpl.text.Text] = ax.findobj(mpl.text.Text) # type: ignore[assignment]
126128
check = []
127129
for t in alltxt:
128130
if t.get_text() == target_txt:
129131
check.append(plt.getp(t, property) == property_str)
130132
return all(check)
131133

132134

133-
def easy_array(shape, start=0, stop=1):
135+
def easy_array(shape: tuple[int, ...], start: float = 0, stop: float = 1) -> np.ndarray:
134136
"""
135137
Make an array with desired shape using np.linspace
136138
@@ -140,7 +142,7 @@ def easy_array(shape, start=0, stop=1):
140142
return a.reshape(shape)
141143

142144

143-
def get_colorbar_label(colorbar):
145+
def get_colorbar_label(colorbar) -> str:
144146
if colorbar.orientation == "vertical":
145147
return colorbar.ax.get_ylabel()
146148
else:
@@ -150,27 +152,27 @@ def get_colorbar_label(colorbar):
150152
@requires_matplotlib
151153
class PlotTestCase:
152154
@pytest.fixture(autouse=True)
153-
def setup(self):
155+
def setup(self) -> Generator:
154156
yield
155157
# Remove all matplotlib figures
156158
plt.close("all")
157159

158-
def pass_in_axis(self, plotmethod, subplot_kw=None):
160+
def pass_in_axis(self, plotmethod, subplot_kw=None) -> None:
159161
fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw)
160162
plotmethod(ax=axs[0])
161163
assert axs[0].has_data()
162164

163165
@pytest.mark.slow
164-
def imshow_called(self, plotmethod):
166+
def imshow_called(self, plotmethod) -> bool:
165167
plotmethod()
166168
images = plt.gca().findobj(mpl.image.AxesImage)
167169
return len(images) > 0
168170

169-
def contourf_called(self, plotmethod):
171+
def contourf_called(self, plotmethod) -> bool:
170172
plotmethod()
171173

172174
# Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8
173-
def matchfunc(x):
175+
def matchfunc(x) -> bool:
174176
return isinstance(
175177
x, (mpl.collections.PathCollection, mpl.contour.QuadContourSet)
176178
)
@@ -1248,14 +1250,16 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None:
12481250
def test_discrete_colormap_provided_boundary_norm(self) -> None:
12491251
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
12501252
primitive = self.darray.plot.contourf(norm=norm)
1251-
np.testing.assert_allclose(primitive.levels, norm.boundaries)
1253+
np.testing.assert_allclose(list(primitive.levels), norm.boundaries)
12521254

12531255
def test_discrete_colormap_provided_boundary_norm_matching_cmap_levels(
12541256
self,
12551257
) -> None:
12561258
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
12571259
primitive = self.darray.plot.contourf(norm=norm)
1258-
assert primitive.colorbar.norm.Ncmap == primitive.colorbar.norm.N
1260+
cbar = primitive.colorbar
1261+
assert cbar is not None
1262+
assert cbar.norm.Ncmap == cbar.norm.N # type: ignore[attr-defined] # Exists, debatable if public though.
12591263

12601264

12611265
class Common2dMixin:
@@ -2532,7 +2536,7 @@ def test_default_labels(self) -> None:
25322536

25332537
# Leftmost column should have array name
25342538
for ax in g.axs[:, 0]:
2535-
assert substring_in_axes(self.darray.name, ax)
2539+
assert substring_in_axes(str(self.darray.name), ax)
25362540

25372541
def test_test_empty_cell(self) -> None:
25382542
g = (
@@ -2635,7 +2639,7 @@ def test_facetgrid(self) -> None:
26352639
(True, "continuous", False, True),
26362640
],
26372641
)
2638-
def test_add_guide(self, add_guide, hue_style, legend, colorbar):
2642+
def test_add_guide(self, add_guide, hue_style, legend, colorbar) -> None:
26392643
meta_data = _infer_meta_data(
26402644
self.ds,
26412645
x="x",
@@ -2811,7 +2815,7 @@ def test_bad_args(
28112815
add_legend: bool | None,
28122816
add_colorbar: bool | None,
28132817
error_type: type[Exception],
2814-
):
2818+
) -> None:
28152819
with pytest.raises(error_type):
28162820
self.ds.plot.scatter(
28172821
x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar
@@ -3011,20 +3015,22 @@ def test_ncaxis_notinstalled_line_plot(self) -> None:
30113015
@requires_matplotlib
30123016
class TestAxesKwargs:
30133017
@pytest.fixture(params=[1, 2, 3])
3014-
def data_array(self, request):
3018+
def data_array(self, request) -> DataArray:
30153019
"""
30163020
Return a simple DataArray
30173021
"""
30183022
dims = request.param
30193023
if dims == 1:
30203024
return DataArray(easy_array((10,)))
3021-
if dims == 2:
3025+
elif dims == 2:
30223026
return DataArray(easy_array((10, 3)))
3023-
if dims == 3:
3027+
elif dims == 3:
30243028
return DataArray(easy_array((10, 3, 2)))
3029+
else:
3030+
raise ValueError(f"No DataArray implemented for {dims=}.")
30253031

30263032
@pytest.fixture(params=[1, 2])
3027-
def data_array_logspaced(self, request):
3033+
def data_array_logspaced(self, request) -> DataArray:
30283034
"""
30293035
Return a simple DataArray with logspaced coordinates
30303036
"""
@@ -3033,12 +3039,14 @@ def data_array_logspaced(self, request):
30333039
return DataArray(
30343040
np.arange(7), dims=("x",), coords={"x": np.logspace(-3, 3, 7)}
30353041
)
3036-
if dims == 2:
3042+
elif dims == 2:
30373043
return DataArray(
30383044
np.arange(16).reshape(4, 4),
30393045
dims=("y", "x"),
30403046
coords={"x": np.logspace(-1, 2, 4), "y": np.logspace(-5, -1, 4)},
30413047
)
3048+
else:
3049+
raise ValueError(f"No DataArray implemented for {dims=}.")
30423050

30433051
@pytest.mark.parametrize("xincrease", [True, False])
30443052
def test_xincrease_kwarg(self, data_array, xincrease) -> None:
@@ -3146,16 +3154,16 @@ def test_facetgrid_single_contour() -> None:
31463154

31473155

31483156
@requires_matplotlib
3149-
def test_get_axis_raises():
3157+
def test_get_axis_raises() -> None:
31503158
# test get_axis raises an error if trying to do invalid things
31513159

31523160
# cannot provide both ax and figsize
31533161
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
3154-
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something")
3162+
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something") # type: ignore[arg-type]
31553163

31563164
# cannot provide both ax and size
31573165
with pytest.raises(ValueError, match="both `size` and `ax`"):
3158-
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something")
3166+
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something") # type: ignore[arg-type]
31593167

31603168
# cannot provide both size and figsize
31613169
with pytest.raises(ValueError, match="both `figsize` and `size`"):
@@ -3167,7 +3175,7 @@ def test_get_axis_raises():
31673175

31683176
# cannot provide axis and subplot_kws
31693177
with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"):
3170-
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5)
3178+
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5) # type: ignore[arg-type]
31713179

31723180

31733181
@requires_matplotlib

0 commit comments

Comments
 (0)