1
1
import pathlib
2
- from typing import List , Union
2
+ from typing import Any , Callable , List , Optional , Union
3
3
4
4
import matplotlib
5
5
import matplotlib .pyplot as plt
10
10
from .load_data_ import axl_filename
11
11
from .result_set import ResultSet
12
12
13
- titleType = List [ str ]
13
+ titleType = str
14
14
namesType = List [str ]
15
15
dataType = List [List [Union [int , float ]]]
16
16
@@ -25,8 +25,11 @@ def _violinplot(
25
25
self ,
26
26
data : dataType ,
27
27
names : namesType ,
28
- title : titleType = None ,
29
- ax : matplotlib .axes .SubplotBase = None ,
28
+ title : Optional [titleType ] = None ,
29
+ ax : Optional [matplotlib .axes .Axes ] = None ,
30
+ get_figure : Callable [
31
+ [matplotlib .axes .Axes ], Union [matplotlib .figure .Figure , Any , None ]
32
+ ] = lambda ax : ax .get_figure (),
30
33
) -> matplotlib .figure .Figure :
31
34
"""For making violinplots."""
32
35
@@ -35,7 +38,11 @@ def _violinplot(
35
38
else :
36
39
ax = ax
37
40
38
- figure = ax .get_figure ()
41
+ figure = get_figure (ax )
42
+ if not isinstance (figure , matplotlib .figure .Figure ):
43
+ raise RuntimeError (
44
+ "get_figure unexpectedly returned a non-figure object"
45
+ )
39
46
width = max (self .num_players / 3 , 12 )
40
47
height = width / 2
41
48
spacing = 4
@@ -50,7 +57,7 @@ def _violinplot(
50
57
)
51
58
ax .set_xticks (positions )
52
59
ax .set_xticklabels (names , rotation = 90 )
53
- ax .set_xlim ([ 0 , spacing * (self .num_players + 1 )] )
60
+ ax .set_xlim (( 0 , spacing * (self .num_players + 1 )) )
54
61
ax .tick_params (axis = "both" , which = "both" , labelsize = 8 )
55
62
if title :
56
63
ax .set_title (title )
@@ -76,7 +83,9 @@ def _boxplot_xticks_labels(self):
76
83
return [str (n ) for n in self .result_set .ranked_names ]
77
84
78
85
def boxplot (
79
- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
86
+ self ,
87
+ title : Optional [titleType ] = None ,
88
+ ax : Optional [matplotlib .axes .Axes ] = None ,
80
89
) -> matplotlib .figure .Figure :
81
90
"""For the specific mean score boxplot."""
82
91
data = self ._boxplot_dataset
@@ -98,7 +107,9 @@ def _winplot_dataset(self):
98
107
return wins , ranked_names
99
108
100
109
def winplot (
101
- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
110
+ self ,
111
+ title : Optional [titleType ] = None ,
112
+ ax : Optional [matplotlib .axes .Axes ] = None ,
102
113
) -> matplotlib .figure .Figure :
103
114
"""Plots the distributions for the number of wins for each strategy."""
104
115
@@ -126,7 +137,9 @@ def _sdv_plot_dataset(self):
126
137
return diffs , ranked_names
127
138
128
139
def sdvplot (
129
- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
140
+ self ,
141
+ title : Optional [titleType ] = None ,
142
+ ax : Optional [matplotlib .axes .Axes ] = None ,
130
143
) -> matplotlib .figure .Figure :
131
144
"""Score difference violin plots to visualize the distributions of how
132
145
players attain their payoffs."""
@@ -143,7 +156,9 @@ def _lengthplot_dataset(self):
143
156
]
144
157
145
158
def lengthplot (
146
- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
159
+ self ,
160
+ title : Optional [titleType ] = None ,
161
+ ax : Optional [matplotlib .axes .Axes ] = None ,
147
162
) -> matplotlib .figure .Figure :
148
163
"""For the specific match length boxplot."""
149
164
data = self ._lengthplot_dataset
@@ -174,9 +189,12 @@ def _payoff_heatmap(
174
189
self ,
175
190
data : dataType ,
176
191
names : namesType ,
177
- title : titleType = None ,
178
- ax : matplotlib .axes .SubplotBase = None ,
192
+ title : Optional [ titleType ] = None ,
193
+ ax : Optional [ matplotlib .axes .Axes ] = None ,
179
194
cmap : str = "viridis" ,
195
+ get_figure : Callable [
196
+ [matplotlib .axes .Axes ], Union [matplotlib .figure .Figure , Any , None ]
197
+ ] = lambda ax : ax .get_figure (),
180
198
) -> matplotlib .figure .Figure :
181
199
"""Generic heatmap plot"""
182
200
@@ -185,7 +203,11 @@ def _payoff_heatmap(
185
203
else :
186
204
ax = ax
187
205
188
- figure = ax .get_figure ()
206
+ figure = get_figure (ax )
207
+ if not isinstance (figure , matplotlib .figure .Figure ):
208
+ raise RuntimeError (
209
+ "get_figure unexpectedly returned a non-figure object"
210
+ )
189
211
width = max (self .num_players / 4 , 12 )
190
212
height = width
191
213
figure .set_size_inches (width , height )
@@ -202,15 +224,19 @@ def _payoff_heatmap(
202
224
return figure
203
225
204
226
def pdplot (
205
- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
227
+ self ,
228
+ title : Optional [titleType ] = None ,
229
+ ax : Optional [matplotlib .axes .Axes ] = None ,
206
230
) -> matplotlib .figure .Figure :
207
231
"""Payoff difference heatmap to visualize the distributions of how
208
232
players attain their payoffs."""
209
233
matrix , names = self ._pdplot_dataset
210
234
return self ._payoff_heatmap (matrix , names , title = title , ax = ax )
211
235
212
236
def payoff (
213
- self , title : titleType = None , ax : matplotlib .axes .SubplotBase = None
237
+ self ,
238
+ title : Optional [titleType ] = None ,
239
+ ax : Optional [matplotlib .axes .Axes ] = None ,
214
240
) -> matplotlib .figure .Figure :
215
241
"""Payoff heatmap to visualize the distributions of how
216
242
players attain their payoffs."""
@@ -223,9 +249,12 @@ def payoff(
223
249
def stackplot (
224
250
self ,
225
251
eco ,
226
- title : titleType = None ,
252
+ title : Optional [ titleType ] = None ,
227
253
logscale : bool = True ,
228
- ax : matplotlib .axes .SubplotBase = None ,
254
+ ax : Optional [matplotlib .axes .Axes ] = None ,
255
+ get_figure : Callable [
256
+ [matplotlib .axes .Axes ], Union [matplotlib .figure .Figure , Any , None ]
257
+ ] = lambda ax : ax .get_figure (),
229
258
) -> matplotlib .figure .Figure :
230
259
231
260
populations = eco .population_sizes
@@ -235,7 +264,11 @@ def stackplot(
235
264
else :
236
265
ax = ax
237
266
238
- figure = ax .get_figure ()
267
+ figure = get_figure (ax )
268
+ if not isinstance (figure , matplotlib .figure .Figure ):
269
+ raise RuntimeError (
270
+ "get_figure unexpectedly returned a non-figure object"
271
+ )
239
272
turns = range (len (populations ))
240
273
pops = [
241
274
[populations [iturn ][ir ] for iturn in turns ]
@@ -247,7 +280,7 @@ def stackplot(
247
280
ax .yaxis .set_label_position ("right" )
248
281
ax .yaxis .labelpad = 25.0
249
282
250
- ax .set_ylim ([ 0.0 , 1.0 ] )
283
+ ax .set_ylim (( 0.0 , 1.0 ) )
251
284
ax .set_ylabel ("Relative population size" )
252
285
ax .set_xlabel ("Turn" )
253
286
if title is not None :
0 commit comments