Skip to content

Commit 6d2d465

Browse files
authored
Get mypy working again (#1457)
* Get mypy working again * Add test for new assert * Test error string * Inject get_data for testing * More typing stuff for mypy to pass * Ran black * Added more mypy stuff to plot
1 parent 6a6c8a9 commit 6d2d465

22 files changed

+162
-84
lines changed

axelrod/deterministic_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pickle
1616
from collections import UserDict
17-
from typing import List, Tuple
17+
from typing import List, Optional, Tuple
1818

1919
from axelrod import Classifiers
2020

@@ -104,7 +104,7 @@ class DeterministicCache(UserDict):
104104
methods to save/load the cache to/from a file.
105105
"""
106106

107-
def __init__(self, file_name: str = None) -> None:
107+
def __init__(self, file_name: Optional[str] = None) -> None:
108108
"""Initialize a new cache.
109109
110110
Parameters

axelrod/ecosystem.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
import random
15-
from typing import Callable, List
15+
from typing import Callable, List, Optional
1616

1717
from axelrod.result_set import ResultSet
1818

@@ -29,8 +29,8 @@ class Ecosystem(object):
2929
def __init__(
3030
self,
3131
results: ResultSet,
32-
fitness: Callable[[float], float] = None,
33-
population: List[int] = None,
32+
fitness: Optional[Callable[[float], float]] = None,
33+
population: Optional[List[int]] = None,
3434
) -> None:
3535
"""Create a new ecosystem.
3636

axelrod/fingerprint.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from collections import namedtuple
33
from tempfile import mkstemp
4-
from typing import Any, List, Union
4+
from typing import Any, List, Optional, Union
55

66
import dask.dataframe as dd
77
import matplotlib.pyplot as plt
@@ -280,10 +280,10 @@ def fingerprint(
280280
turns: int = 50,
281281
repetitions: int = 10,
282282
step: float = 0.01,
283-
processes: int = None,
284-
filename: str = None,
283+
processes: Optional[int] = None,
284+
filename: Optional[str] = None,
285285
progress_bar: bool = True,
286-
seed: int = None,
286+
seed: Optional[int] = None,
287287
) -> dict:
288288
"""Build and play the spatial tournament.
289289
@@ -358,7 +358,7 @@ def plot(
358358
self,
359359
cmap: str = "seismic",
360360
interpolation: str = "none",
361-
title: str = None,
361+
title: Optional[str] = None,
362362
colorbar: bool = True,
363363
labels: bool = True,
364364
) -> plt.Figure:
@@ -437,11 +437,11 @@ def fingerprint(
437437
self,
438438
turns: int = 50,
439439
repetitions: int = 1000,
440-
noise: float = None,
441-
processes: int = None,
442-
filename: str = None,
440+
noise: Optional[float] = None,
441+
processes: Optional[int] = None,
442+
filename: Optional[str] = None,
443443
progress_bar: bool = True,
444-
seed: int = None,
444+
seed: Optional[int] = None,
445445
) -> np.ndarray:
446446
"""Creates a spatial tournament to run the necessary matches to obtain
447447
fingerprint data.
@@ -556,7 +556,7 @@ def plot(
556556
self,
557557
cmap: str = "viridis",
558558
interpolation: str = "none",
559-
title: str = None,
559+
title: Optional[str] = None,
560560
colorbar: bool = True,
561561
labels: bool = True,
562562
display_names: bool = False,

axelrod/game.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Tuple, Union
33

44
import numpy as np
5+
import numpy.typing as npt
56

67
from axelrod import Action
78

@@ -20,7 +21,7 @@ class AsymmetricGame(object):
2021
"""
2122

2223
# pylint: disable=invalid-name
23-
def __init__(self, A: np.array, B: np.array) -> None:
24+
def __init__(self, A: npt.NDArray, B: npt.NDArray) -> None:
2425
"""
2526
Creates an asymmetric game from two matrices.
2627

axelrod/load_data_.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pathlib
22
import pkgutil
3-
from typing import Dict, List, Text, Tuple
3+
from typing import Callable, Dict, List, Optional, Tuple
44

55

66
def axl_filename(path: pathlib.Path) -> pathlib.Path:
@@ -20,12 +20,18 @@ def axl_filename(path: pathlib.Path) -> pathlib.Path:
2020
return axl_path / path
2121

2222

23-
def load_file(filename: str, directory: str) -> List[List[str]]:
23+
def load_file(
24+
filename: str,
25+
directory: str,
26+
get_data: Callable[[str, str], Optional[bytes]] = pkgutil.get_data,
27+
) -> List[List[str]]:
2428
"""Loads a data file stored in the Axelrod library's data subdirectory,
2529
likely for parameters for a strategy."""
2630

2731
path = str(pathlib.Path(directory) / filename)
28-
data_bytes = pkgutil.get_data(__name__, path)
32+
data_bytes = get_data(__name__, path)
33+
if data_bytes is None:
34+
raise FileNotFoundError(f"Some loader issue for path {path}")
2935
data = data_bytes.decode("UTF-8", "replace")
3036

3137
rows = []

axelrod/mock_player.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from itertools import cycle
2-
from typing import List
2+
from typing import List, Optional
33

44
from axelrod.action import Action
55
from axelrod.player import Player
@@ -14,7 +14,7 @@ class MockPlayer(Player):
1414

1515
name = "Mock Player"
1616

17-
def __init__(self, actions: List[Action] = None) -> None:
17+
def __init__(self, actions: Optional[List[Action]] = None) -> None:
1818
super().__init__()
1919
if not actions:
2020
actions = []

axelrod/moran.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ def __init__(
1818
self,
1919
players: List[Player],
2020
turns: int = DEFAULT_TURNS,
21-
prob_end: float = None,
21+
prob_end: Optional[float] = None,
2222
noise: float = 0,
2323
game: Game = None,
2424
deterministic_cache: DeterministicCache = None,
2525
mutation_rate: float = 0.0,
2626
mode: str = "bd",
2727
interaction_graph: Graph = None,
2828
reproduction_graph: Graph = None,
29-
fitness_transformation: Callable = None,
29+
fitness_transformation: Optional[Callable] = None,
3030
mutation_method="transition",
3131
stop_on_fixation=True,
3232
seed=None,
@@ -175,7 +175,7 @@ def set_players(self) -> None:
175175
self.populations = [self.population_distribution()]
176176

177177
def fitness_proportionate_selection(
178-
self, scores: List, fitness_transformation: Callable = None
178+
self, scores: List, fitness_transformation: Optional[Callable] = None
179179
) -> int:
180180
"""Randomly selects an individual proportionally to score.
181181
@@ -229,7 +229,7 @@ def mutate(self, index: int) -> Player:
229229
# Just clone the player
230230
return self.players[index].clone()
231231

232-
def death(self, index: int = None) -> int:
232+
def death(self, index: Optional[int] = None) -> int:
233233
"""
234234
Selects the player to be removed.
235235
@@ -258,7 +258,7 @@ def death(self, index: int = None) -> int:
258258
i = self.index[vertex]
259259
return i
260260

261-
def birth(self, index: int = None) -> int:
261+
def birth(self, index: Optional[int] = None) -> int:
262262
"""The birth event.
263263
264264
Parameters
@@ -349,7 +349,6 @@ def _matchup_indices(self) -> Set[Tuple[int, int]]:
349349
# The other calculations are unnecessary
350350
if self.mode == "db":
351351
source = self.index[self.dead]
352-
self.dead = None
353352
sources = sorted(self.interaction_graph.out_vertices(source))
354353
else:
355354
# birth-death is global

axelrod/plot.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pathlib
2-
from typing import List, Union
2+
from typing import Any, Callable, List, Optional, Union
33

44
import matplotlib
55
import matplotlib.pyplot as plt
@@ -10,7 +10,7 @@
1010
from .load_data_ import axl_filename
1111
from .result_set import ResultSet
1212

13-
titleType = List[str]
13+
titleType = str
1414
namesType = List[str]
1515
dataType = List[List[Union[int, float]]]
1616

@@ -25,8 +25,11 @@ def _violinplot(
2525
self,
2626
data: dataType,
2727
names: namesType,
28-
title: titleType = None,
29-
ax: matplotlib.axes.SubplotBase = None,
28+
title: Optional[titleType] = None,
29+
ax: Optional[matplotlib.axes.Axes] = None,
30+
get_figure: Callable[
31+
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
32+
] = lambda ax: ax.get_figure(),
3033
) -> matplotlib.figure.Figure:
3134
"""For making violinplots."""
3235

@@ -35,7 +38,11 @@ def _violinplot(
3538
else:
3639
ax = ax
3740

38-
figure = ax.get_figure()
41+
figure = get_figure(ax)
42+
if not isinstance(figure, matplotlib.figure.Figure):
43+
raise RuntimeError(
44+
"get_figure unexpectedly returned a non-figure object"
45+
)
3946
width = max(self.num_players / 3, 12)
4047
height = width / 2
4148
spacing = 4
@@ -50,7 +57,7 @@ def _violinplot(
5057
)
5158
ax.set_xticks(positions)
5259
ax.set_xticklabels(names, rotation=90)
53-
ax.set_xlim([0, spacing * (self.num_players + 1)])
60+
ax.set_xlim((0, spacing * (self.num_players + 1)))
5461
ax.tick_params(axis="both", which="both", labelsize=8)
5562
if title:
5663
ax.set_title(title)
@@ -76,7 +83,9 @@ def _boxplot_xticks_labels(self):
7683
return [str(n) for n in self.result_set.ranked_names]
7784

7885
def boxplot(
79-
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
86+
self,
87+
title: Optional[titleType] = None,
88+
ax: Optional[matplotlib.axes.Axes] = None,
8089
) -> matplotlib.figure.Figure:
8190
"""For the specific mean score boxplot."""
8291
data = self._boxplot_dataset
@@ -98,7 +107,9 @@ def _winplot_dataset(self):
98107
return wins, ranked_names
99108

100109
def winplot(
101-
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
110+
self,
111+
title: Optional[titleType] = None,
112+
ax: Optional[matplotlib.axes.Axes] = None,
102113
) -> matplotlib.figure.Figure:
103114
"""Plots the distributions for the number of wins for each strategy."""
104115

@@ -126,7 +137,9 @@ def _sdv_plot_dataset(self):
126137
return diffs, ranked_names
127138

128139
def sdvplot(
129-
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
140+
self,
141+
title: Optional[titleType] = None,
142+
ax: Optional[matplotlib.axes.Axes] = None,
130143
) -> matplotlib.figure.Figure:
131144
"""Score difference violin plots to visualize the distributions of how
132145
players attain their payoffs."""
@@ -143,7 +156,9 @@ def _lengthplot_dataset(self):
143156
]
144157

145158
def lengthplot(
146-
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
159+
self,
160+
title: Optional[titleType] = None,
161+
ax: Optional[matplotlib.axes.Axes] = None,
147162
) -> matplotlib.figure.Figure:
148163
"""For the specific match length boxplot."""
149164
data = self._lengthplot_dataset
@@ -174,9 +189,12 @@ def _payoff_heatmap(
174189
self,
175190
data: dataType,
176191
names: namesType,
177-
title: titleType = None,
178-
ax: matplotlib.axes.SubplotBase = None,
192+
title: Optional[titleType] = None,
193+
ax: Optional[matplotlib.axes.Axes] = None,
179194
cmap: str = "viridis",
195+
get_figure: Callable[
196+
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
197+
] = lambda ax: ax.get_figure(),
180198
) -> matplotlib.figure.Figure:
181199
"""Generic heatmap plot"""
182200

@@ -185,7 +203,11 @@ def _payoff_heatmap(
185203
else:
186204
ax = ax
187205

188-
figure = ax.get_figure()
206+
figure = get_figure(ax)
207+
if not isinstance(figure, matplotlib.figure.Figure):
208+
raise RuntimeError(
209+
"get_figure unexpectedly returned a non-figure object"
210+
)
189211
width = max(self.num_players / 4, 12)
190212
height = width
191213
figure.set_size_inches(width, height)
@@ -202,15 +224,19 @@ def _payoff_heatmap(
202224
return figure
203225

204226
def pdplot(
205-
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
227+
self,
228+
title: Optional[titleType] = None,
229+
ax: Optional[matplotlib.axes.Axes] = None,
206230
) -> matplotlib.figure.Figure:
207231
"""Payoff difference heatmap to visualize the distributions of how
208232
players attain their payoffs."""
209233
matrix, names = self._pdplot_dataset
210234
return self._payoff_heatmap(matrix, names, title=title, ax=ax)
211235

212236
def payoff(
213-
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
237+
self,
238+
title: Optional[titleType] = None,
239+
ax: Optional[matplotlib.axes.Axes] = None,
214240
) -> matplotlib.figure.Figure:
215241
"""Payoff heatmap to visualize the distributions of how
216242
players attain their payoffs."""
@@ -223,9 +249,12 @@ def payoff(
223249
def stackplot(
224250
self,
225251
eco,
226-
title: titleType = None,
252+
title: Optional[titleType] = None,
227253
logscale: bool = True,
228-
ax: matplotlib.axes.SubplotBase = None,
254+
ax: Optional[matplotlib.axes.Axes] = None,
255+
get_figure: Callable[
256+
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
257+
] = lambda ax: ax.get_figure(),
229258
) -> matplotlib.figure.Figure:
230259

231260
populations = eco.population_sizes
@@ -235,7 +264,11 @@ def stackplot(
235264
else:
236265
ax = ax
237266

238-
figure = ax.get_figure()
267+
figure = get_figure(ax)
268+
if not isinstance(figure, matplotlib.figure.Figure):
269+
raise RuntimeError(
270+
"get_figure unexpectedly returned a non-figure object"
271+
)
239272
turns = range(len(populations))
240273
pops = [
241274
[populations[iturn][ir] for iturn in turns]
@@ -247,7 +280,7 @@ def stackplot(
247280
ax.yaxis.set_label_position("right")
248281
ax.yaxis.labelpad = 25.0
249282

250-
ax.set_ylim([0.0, 1.0])
283+
ax.set_ylim((0.0, 1.0))
251284
ax.set_ylabel("Relative population size")
252285
ax.set_xlabel("Turn")
253286
if title is not None:

0 commit comments

Comments
 (0)