Skip to content

Commit bcb0013

Browse files
Merge pull request #897 from Axelrod-Python/issue-894
Fix bug in plot.py and add missing test
2 parents f971c78 + 08734c8 commit bcb0013

File tree

2 files changed

+201
-72
lines changed

2 files changed

+201
-72
lines changed

axelrod/plot.py

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tqdm
44
import warnings
55

6+
from distutils.version import LooseVersion
67
from typing import List, Union
78

89
matplotlib_installed = True
@@ -24,13 +25,12 @@
2425
namesType = List[str]
2526
dataType = List[List[Union[int, float]]]
2627

27-
def default_cmap() -> str:
28+
29+
def default_cmap(version: str = "2.0") -> str:
2830
"""Sets a default matplotlib colormap based on the version."""
29-
s = matplotlib.__version__.split('.')
30-
if int(s[0]) >= 1 and int(s[1]) >= 5:
31-
return "viridis"
32-
else:
33-
return 'YlGnBu'
31+
if LooseVersion(version) >= "1.5":
32+
return 'viridis'
33+
return 'YlGnBu'
3434

3535

3636
class Plot(object):
@@ -40,7 +40,10 @@ def __init__(self, result_set: ResultSet) -> None:
4040
self.nplayers = self.result_set.nplayers
4141
self.players = self.result_set.players
4242

43-
def _violinplot(self, data: dataType, names: namesType, title: titleType =None, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
43+
def _violinplot(
44+
self, data: dataType, names: namesType, title: titleType = None,
45+
ax: matplotlib.axes.SubplotBase = None
46+
) -> matplotlib.figure.Figure:
4447
"""For making violinplots."""
4548
if not self.matplotlib_installed: # pragma: no cover
4649
return None
@@ -56,13 +59,15 @@ def _violinplot(self, data: dataType, names: namesType, title: titleType =None,
5659
spacing = 4
5760
positions = spacing * arange(1, self.nplayers + 1, 1)
5861
figure.set_size_inches(width, height)
59-
plt.violinplot(data, positions=positions, widths=spacing / 2,
60-
showmedians=True, showextrema=False)
61-
plt.xticks(positions, names, rotation=90)
62-
plt.xlim(0, spacing * (self.nplayers + 1))
63-
plt.tick_params(axis='both', which='both', labelsize=8)
62+
ax.violinplot(data, positions=positions, widths=spacing / 2,
63+
showmedians=True, showextrema=False)
64+
ax.set_xticks(positions)
65+
ax.set_xticklabels(names, rotation=90)
66+
ax.set_xlim([0, spacing * (self.nplayers + 1)])
67+
ax.tick_params(axis='both', which='both', labelsize=8)
6468
if title:
65-
plt.title(title)
69+
ax.set_title(title)
70+
plt.tight_layout()
6671
return figure
6772

6873
# Box and Violin plots for mean score, score differences, wins, and match
@@ -81,11 +86,13 @@ def _boxplot_xticks_locations(self):
8186
def _boxplot_xticks_labels(self):
8287
return [str(n) for n in self.result_set.ranked_names]
8388

84-
def boxplot(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
89+
def boxplot(
90+
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
91+
) -> matplotlib.figure.Figure:
8592
"""For the specific mean score boxplot."""
8693
data = self._boxplot_dataset
8794
names = self._boxplot_xticks_labels
88-
figure = self._violinplot(data, names, title=title, ax=ax)
95+
figure = self._violinplot(data, names, title=title, ax=ax)
8996
return figure
9097

9198
@property
@@ -100,7 +107,9 @@ def _winplot_dataset(self):
100107
ranked_names = [str(self.players[x[-1]]) for x in medians]
101108
return wins, ranked_names
102109

103-
def winplot(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
110+
def winplot(
111+
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
112+
) -> matplotlib.figure.Figure:
104113
"""Plots the distributions for the number of wins for each strategy."""
105114
if not self.matplotlib_installed: # pragma: no cover
106115
return None
@@ -125,7 +134,9 @@ def _sdv_plot_dataset(self):
125134
ranked_names = [str(self.players[i]) for i in ordering]
126135
return diffs, ranked_names
127136

128-
def sdvplot(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
137+
def sdvplot(
138+
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
139+
) -> matplotlib.figure.Figure:
129140
"""Score difference violinplots to visualize the distributions of how
130141
players attain their payoffs."""
131142
diffs, ranked_names = self._sdv_plot_dataset
@@ -139,7 +150,9 @@ def _lengthplot_dataset(self):
139150
for length in rep[playeri]] for playeri in
140151
self.result_set.ranking]
141152

142-
def lengthplot(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
153+
def lengthplot(
154+
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
155+
) -> matplotlib.figure.Figure:
143156
"""For the specific match length boxplot."""
144157
data = self._lengthplot_dataset
145158
names = self._boxplot_xticks_labels
@@ -165,7 +178,10 @@ def _pdplot_dataset(self):
165178
ranked_names = [str(players[i]) for i in ordering]
166179
return matrix, ranked_names
167180

168-
def _payoff_heatmap(self, data: dataType, names: namesType, title: titleType =None, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
181+
def _payoff_heatmap(
182+
self, data: dataType, names: namesType, title: titleType = None,
183+
ax: matplotlib.axes.SubplotBase = None
184+
) -> matplotlib.figure.Figure:
169185
"""Generic heatmap plot"""
170186
if not self.matplotlib_installed: # pragma: no cover
171187
return None
@@ -179,28 +195,31 @@ def _payoff_heatmap(self, data: dataType, names: namesType, title: titleType =No
179195
width = max(self.nplayers / 4, 12)
180196
height = width
181197
figure.set_size_inches(width, height)
182-
cmap = default_cmap()
198+
matplotlib_version = matplotlib.__version__
199+
cmap = default_cmap(matplotlib_version)
183200
mat = ax.matshow(data, cmap=cmap)
184-
plt.xticks(range(self.result_set.nplayers))
185-
plt.yticks(range(self.result_set.nplayers))
201+
ax.set_xticks(range(self.result_set.nplayers))
202+
ax.set_yticks(range(self.result_set.nplayers))
186203
ax.set_xticklabels(names, rotation=90)
187204
ax.set_yticklabels(names)
188-
plt.tick_params(axis='both', which='both', labelsize=16)
205+
ax.tick_params(axis='both', which='both', labelsize=16)
189206
if title:
190-
plt.xlabel(title)
191-
# Make the colorbar match up with the plot
192-
divider = make_axes_locatable(plt.gca())
193-
cax = divider.append_axes("right", "5%", pad="3%")
194-
plt.colorbar(mat, cax=cax)
207+
ax.set_xlabel(title)
208+
figure.colorbar(mat, ax=ax)
209+
plt.tight_layout()
195210
return figure
196211

197-
def pdplot(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None):
212+
def pdplot(
213+
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
214+
) -> matplotlib.figure.Figure:
198215
"""Payoff difference heatmap to visualize the distributions of how
199216
players attain their payoffs."""
200217
matrix, names = self._pdplot_dataset
201218
return self._payoff_heatmap(matrix, names, title=title, ax=ax)
202219

203-
def payoff(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None):
220+
def payoff(
221+
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
222+
) -> matplotlib.figure.Figure:
204223
"""Payoff heatmap to visualize the distributions of how
205224
players attain their payoffs."""
206225
data = self._payoff_dataset
@@ -209,7 +228,10 @@ def payoff(self, title: titleType =None, ax: matplotlib.axes.SubplotBase =None):
209228

210229
# Ecological Plot
211230

212-
def stackplot(self, eco, title: titleType =None, logscale: bool =True, ax: matplotlib.axes.SubplotBase =None) -> matplotlib.figure.Figure:
231+
def stackplot(
232+
self, eco, title: titleType = None, logscale: bool = True,
233+
ax: matplotlib.axes.SubplotBase =None
234+
) -> matplotlib.figure.Figure:
213235
if not self.matplotlib_installed: # pragma: no cover
214236
return None
215237

@@ -222,26 +244,31 @@ def stackplot(self, eco, title: titleType =None, logscale: bool =True, ax: matpl
222244

223245
figure = ax.get_figure()
224246
turns = range(len(populations))
225-
pops = [[populations[iturn][ir] for iturn in turns] for ir in self.result_set.ranking]
247+
pops = [
248+
[populations[iturn][ir] for iturn in turns]
249+
for ir in self.result_set.ranking
250+
]
226251
ax.stackplot(turns, *pops)
227252

228253
ax.yaxis.tick_left()
229254
ax.yaxis.set_label_position("right")
230255
ax.yaxis.labelpad = 25.0
231256

232-
plt.ylim([0.0, 1.0])
233-
plt.ylabel('Relative population size')
234-
plt.xlabel('Turn')
257+
ax.set_ylim([0.0, 1.0])
258+
ax.set_ylabel('Relative population size')
259+
ax.set_xlabel('Turn')
235260
if title is not None:
236-
plt.title(title)
261+
ax.set_title(title)
237262

238-
trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
263+
trans = transforms.blended_transform_factory(
264+
ax.transAxes, ax.transData)
239265
ticks = []
240266
for i, n in enumerate(self.result_set.ranked_names):
241267
x = -0.01
242268
y = (i + 0.5) * 1 / self.result_set.nplayers
243-
ax.annotate(n, xy=(x, y), xycoords=trans, clip_on=False,
244-
va='center', ha='right', fontsize=5)
269+
ax.annotate(
270+
n, xy=(x, y), xycoords=trans, clip_on=False, va='center',
271+
ha='right', fontsize=5)
245272
ticks.append(y)
246273
ax.set_yticks(ticks)
247274
ax.tick_params(direction='out')
@@ -250,10 +277,13 @@ def stackplot(self, eco, title: titleType =None, logscale: bool =True, ax: matpl
250277
if logscale:
251278
ax.set_xscale('log')
252279

280+
plt.tight_layout()
253281
return figure
254282

255-
def save_all_plots(self, prefix: str ="axelrod", title_prefix: str ="axelrod",
256-
filetype: str ="svg", progress_bar: bool =True):
283+
def save_all_plots(
284+
self, prefix: str ="axelrod", title_prefix: str ="axelrod",
285+
filetype: str ="svg", progress_bar: bool = True
286+
) -> None:
257287
"""
258288
A method to save all plots to file.
259289

0 commit comments

Comments
 (0)