Skip to content

Commit 8d0033e

Browse files
committed
Add tests for state_distribution and these tests currently pass
1 parent e1fc7e2 commit 8d0033e

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

axelrod/interaction_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def compute_normalised_state_distribution(interactions):
101101
"""
102102
Returns the normalized count of each state for a set of interactions.
103103
"""
104+
if len(interactions) == 0:
105+
return None
106+
104107
normalized_count = Counter(interactions)
105108
total = sum(normalized_count.values(), 0.0)
106109
# By starting the sum with 0.0 we make sure total is a floating point value,

axelrod/tests/unit/test_interaction_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import unittest
33
import tempfile
4+
from collections import Counter
45
import axelrod
56
import axelrod.interaction_utils as iu
67

@@ -17,6 +18,14 @@ class TestMatch(unittest.TestCase):
1718
winners = [False, 0, 1, None]
1819
cooperations = [(1, 1), (0, 2), (2, 1), None]
1920
normalised_cooperations = [(.5, .5), (0, 1), (1, .5), None]
21+
state_distribution = [Counter({('C', 'D'): 1, ('D', 'C'): 1}),
22+
Counter({('D', 'C'): 2}),
23+
Counter({('C', 'C'): 1, ('C', 'D'): 1}),
24+
None]
25+
normalised_state_distribution = [Counter({('C', 'D'): 0.5, ('D', 'C'): 0.5}),
26+
Counter({('D', 'C'): 1.0}),
27+
Counter({('C', 'C'): 0.5, ('C', 'D'): 0.5}),
28+
None]
2029
sparklines = [ u'█ \n █', u' \n██', u'██\n█ ', None ]
2130

2231

@@ -46,6 +55,14 @@ def test_compute_normalised_cooperations(self):
4655
for inter, coop in zip(self.interactions, self.normalised_cooperations):
4756
self.assertEqual(coop, iu.compute_normalised_cooperation(inter))
4857

58+
def test_compute_state_distribution(self):
59+
for inter, dist in zip(self.interactions, self.state_distribution):
60+
self.assertEqual(dist, iu.compute_state_distribution(inter))
61+
62+
def test_compute_normalised_state_distribution(self):
63+
for inter, dist in zip(self.interactions, self.normalised_state_distribution):
64+
self.assertEqual(dist, iu.compute_normalised_state_distribution(inter))
65+
4966
def test_compute_sparklines(self):
5067
for inter, spark in zip(self.interactions, self.sparklines):
5168
self.assertEqual(spark, iu.compute_sparklines(inter))

0 commit comments

Comments
 (0)