3
3
import contextlib
4
4
import inspect
5
5
import math
6
- from collections .abc import Hashable
6
+ from collections .abc import Generator , Hashable
7
7
from copy import copy
8
8
from datetime import date , datetime , timedelta
9
9
from typing import Any , Callable , Literal
@@ -85,52 +85,54 @@ def test_all_figures_closed():
85
85
86
86
@pytest .mark .flaky
87
87
@pytest .mark .skip (reason = "maybe flaky" )
88
- def text_in_fig ():
88
+ def text_in_fig () -> set [ str ] :
89
89
"""
90
90
Return the set of all text in the figure
91
91
"""
92
- return {t .get_text () for t in plt .gcf ().findobj (mpl .text .Text )}
92
+ return {t .get_text () for t in plt .gcf ().findobj (mpl .text .Text )} # type: ignore[attr-defined] # mpl error?
93
93
94
94
95
- def find_possible_colorbars ():
95
+ def find_possible_colorbars () -> list [ mpl . collections . QuadMesh ] :
96
96
# nb. this function also matches meshes from pcolormesh
97
- return plt .gcf ().findobj (mpl .collections .QuadMesh )
97
+ return plt .gcf ().findobj (mpl .collections .QuadMesh ) # type: ignore[return-value] # mpl error?
98
98
99
99
100
- def substring_in_axes (substring , ax ) :
100
+ def substring_in_axes (substring : str , ax : mpl . axes . Axes ) -> bool :
101
101
"""
102
102
Return True if a substring is found anywhere in an axes
103
103
"""
104
- alltxt = {t .get_text () for t in ax .findobj (mpl .text .Text )}
104
+ alltxt : set [ str ] = {t .get_text () for t in ax .findobj (mpl .text .Text )} # type: ignore[attr-defined] # mpl error?
105
105
for txt in alltxt :
106
106
if substring in txt :
107
107
return True
108
108
return False
109
109
110
110
111
- def substring_not_in_axes (substring , ax ) :
111
+ def substring_not_in_axes (substring : str , ax : mpl . axes . Axes ) -> bool :
112
112
"""
113
113
Return True if a substring is not found anywhere in an axes
114
114
"""
115
- alltxt = {t .get_text () for t in ax .findobj (mpl .text .Text )}
115
+ alltxt : set [ str ] = {t .get_text () for t in ax .findobj (mpl .text .Text )} # type: ignore[attr-defined] # mpl error?
116
116
check = [(substring not in txt ) for txt in alltxt ]
117
117
return all (check )
118
118
119
119
120
- def property_in_axes_text (property , property_str , target_txt , ax ):
120
+ def property_in_axes_text (
121
+ property , property_str , target_txt , ax : mpl .axes .Axes
122
+ ) -> bool :
121
123
"""
122
124
Return True if the specified text in an axes
123
125
has the property assigned to property_str
124
126
"""
125
- alltxt = ax .findobj (mpl .text .Text )
127
+ alltxt : list [ mpl . text . Text ] = ax .findobj (mpl .text .Text ) # type: ignore[assignment]
126
128
check = []
127
129
for t in alltxt :
128
130
if t .get_text () == target_txt :
129
131
check .append (plt .getp (t , property ) == property_str )
130
132
return all (check )
131
133
132
134
133
- def easy_array (shape , start = 0 , stop = 1 ) :
135
+ def easy_array (shape : tuple [ int , ...], start : float = 0 , stop : float = 1 ) -> np . ndarray :
134
136
"""
135
137
Make an array with desired shape using np.linspace
136
138
@@ -140,7 +142,7 @@ def easy_array(shape, start=0, stop=1):
140
142
return a .reshape (shape )
141
143
142
144
143
- def get_colorbar_label (colorbar ):
145
+ def get_colorbar_label (colorbar ) -> str :
144
146
if colorbar .orientation == "vertical" :
145
147
return colorbar .ax .get_ylabel ()
146
148
else :
@@ -150,27 +152,27 @@ def get_colorbar_label(colorbar):
150
152
@requires_matplotlib
151
153
class PlotTestCase :
152
154
@pytest .fixture (autouse = True )
153
- def setup (self ):
155
+ def setup (self ) -> Generator :
154
156
yield
155
157
# Remove all matplotlib figures
156
158
plt .close ("all" )
157
159
158
- def pass_in_axis (self , plotmethod , subplot_kw = None ):
160
+ def pass_in_axis (self , plotmethod , subplot_kw = None ) -> None :
159
161
fig , axs = plt .subplots (ncols = 2 , subplot_kw = subplot_kw )
160
162
plotmethod (ax = axs [0 ])
161
163
assert axs [0 ].has_data ()
162
164
163
165
@pytest .mark .slow
164
- def imshow_called (self , plotmethod ):
166
+ def imshow_called (self , plotmethod ) -> bool :
165
167
plotmethod ()
166
168
images = plt .gca ().findobj (mpl .image .AxesImage )
167
169
return len (images ) > 0
168
170
169
- def contourf_called (self , plotmethod ):
171
+ def contourf_called (self , plotmethod ) -> bool :
170
172
plotmethod ()
171
173
172
174
# Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8
173
- def matchfunc (x ):
175
+ def matchfunc (x ) -> bool :
174
176
return isinstance (
175
177
x , (mpl .collections .PathCollection , mpl .contour .QuadContourSet )
176
178
)
@@ -1248,14 +1250,16 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None:
1248
1250
def test_discrete_colormap_provided_boundary_norm (self ) -> None :
1249
1251
norm = mpl .colors .BoundaryNorm ([0 , 5 , 10 , 15 ], 4 )
1250
1252
primitive = self .darray .plot .contourf (norm = norm )
1251
- np .testing .assert_allclose (primitive .levels , norm .boundaries )
1253
+ np .testing .assert_allclose (list ( primitive .levels ) , norm .boundaries )
1252
1254
1253
1255
def test_discrete_colormap_provided_boundary_norm_matching_cmap_levels (
1254
1256
self ,
1255
1257
) -> None :
1256
1258
norm = mpl .colors .BoundaryNorm ([0 , 5 , 10 , 15 ], 4 )
1257
1259
primitive = self .darray .plot .contourf (norm = norm )
1258
- assert primitive .colorbar .norm .Ncmap == primitive .colorbar .norm .N
1260
+ cbar = primitive .colorbar
1261
+ assert cbar is not None
1262
+ assert cbar .norm .Ncmap == cbar .norm .N # type: ignore[attr-defined] # Exists, debatable if public though.
1259
1263
1260
1264
1261
1265
class Common2dMixin :
@@ -2532,7 +2536,7 @@ def test_default_labels(self) -> None:
2532
2536
2533
2537
# Leftmost column should have array name
2534
2538
for ax in g .axs [:, 0 ]:
2535
- assert substring_in_axes (self .darray .name , ax )
2539
+ assert substring_in_axes (str ( self .darray .name ) , ax )
2536
2540
2537
2541
def test_test_empty_cell (self ) -> None :
2538
2542
g = (
@@ -2635,7 +2639,7 @@ def test_facetgrid(self) -> None:
2635
2639
(True , "continuous" , False , True ),
2636
2640
],
2637
2641
)
2638
- def test_add_guide (self , add_guide , hue_style , legend , colorbar ):
2642
+ def test_add_guide (self , add_guide , hue_style , legend , colorbar ) -> None :
2639
2643
meta_data = _infer_meta_data (
2640
2644
self .ds ,
2641
2645
x = "x" ,
@@ -2811,7 +2815,7 @@ def test_bad_args(
2811
2815
add_legend : bool | None ,
2812
2816
add_colorbar : bool | None ,
2813
2817
error_type : type [Exception ],
2814
- ):
2818
+ ) -> None :
2815
2819
with pytest .raises (error_type ):
2816
2820
self .ds .plot .scatter (
2817
2821
x = x , y = y , hue = hue , add_legend = add_legend , add_colorbar = add_colorbar
@@ -3011,20 +3015,22 @@ def test_ncaxis_notinstalled_line_plot(self) -> None:
3011
3015
@requires_matplotlib
3012
3016
class TestAxesKwargs :
3013
3017
@pytest .fixture (params = [1 , 2 , 3 ])
3014
- def data_array (self , request ):
3018
+ def data_array (self , request ) -> DataArray :
3015
3019
"""
3016
3020
Return a simple DataArray
3017
3021
"""
3018
3022
dims = request .param
3019
3023
if dims == 1 :
3020
3024
return DataArray (easy_array ((10 ,)))
3021
- if dims == 2 :
3025
+ elif dims == 2 :
3022
3026
return DataArray (easy_array ((10 , 3 )))
3023
- if dims == 3 :
3027
+ elif dims == 3 :
3024
3028
return DataArray (easy_array ((10 , 3 , 2 )))
3029
+ else :
3030
+ raise ValueError (f"No DataArray implemented for { dims = } ." )
3025
3031
3026
3032
@pytest .fixture (params = [1 , 2 ])
3027
- def data_array_logspaced (self , request ):
3033
+ def data_array_logspaced (self , request ) -> DataArray :
3028
3034
"""
3029
3035
Return a simple DataArray with logspaced coordinates
3030
3036
"""
@@ -3033,12 +3039,14 @@ def data_array_logspaced(self, request):
3033
3039
return DataArray (
3034
3040
np .arange (7 ), dims = ("x" ,), coords = {"x" : np .logspace (- 3 , 3 , 7 )}
3035
3041
)
3036
- if dims == 2 :
3042
+ elif dims == 2 :
3037
3043
return DataArray (
3038
3044
np .arange (16 ).reshape (4 , 4 ),
3039
3045
dims = ("y" , "x" ),
3040
3046
coords = {"x" : np .logspace (- 1 , 2 , 4 ), "y" : np .logspace (- 5 , - 1 , 4 )},
3041
3047
)
3048
+ else :
3049
+ raise ValueError (f"No DataArray implemented for { dims = } ." )
3042
3050
3043
3051
@pytest .mark .parametrize ("xincrease" , [True , False ])
3044
3052
def test_xincrease_kwarg (self , data_array , xincrease ) -> None :
@@ -3146,16 +3154,16 @@ def test_facetgrid_single_contour() -> None:
3146
3154
3147
3155
3148
3156
@requires_matplotlib
3149
- def test_get_axis_raises ():
3157
+ def test_get_axis_raises () -> None :
3150
3158
# test get_axis raises an error if trying to do invalid things
3151
3159
3152
3160
# cannot provide both ax and figsize
3153
3161
with pytest .raises (ValueError , match = "both `figsize` and `ax`" ):
3154
- get_axis (figsize = [4 , 4 ], size = None , aspect = None , ax = "something" )
3162
+ get_axis (figsize = [4 , 4 ], size = None , aspect = None , ax = "something" ) # type: ignore[arg-type]
3155
3163
3156
3164
# cannot provide both ax and size
3157
3165
with pytest .raises (ValueError , match = "both `size` and `ax`" ):
3158
- get_axis (figsize = None , size = 200 , aspect = 4 / 3 , ax = "something" )
3166
+ get_axis (figsize = None , size = 200 , aspect = 4 / 3 , ax = "something" ) # type: ignore[arg-type]
3159
3167
3160
3168
# cannot provide both size and figsize
3161
3169
with pytest .raises (ValueError , match = "both `figsize` and `size`" ):
@@ -3167,7 +3175,7 @@ def test_get_axis_raises():
3167
3175
3168
3176
# cannot provide axis and subplot_kws
3169
3177
with pytest .raises (ValueError , match = "cannot use subplot_kws with existing ax" ):
3170
- get_axis (figsize = None , size = None , aspect = None , ax = 1 , something_else = 5 )
3178
+ get_axis (figsize = None , size = None , aspect = None , ax = 1 , something_else = 5 ) # type: ignore[arg-type]
3171
3179
3172
3180
3173
3181
@requires_matplotlib
0 commit comments