Skip to content

Commit 28ce6fd

Browse files
authored
Merge pull request #139 from ImogenBits/bugfix
Fix default scoring function if generator's Solution score is 0
2 parents 8d21543 + e58c078 commit 28ce6fd

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

algobattle/problem.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ def default_score(
209209
try:
210210
return max(0, min(1, sol_score / gen_score))
211211
except ZeroDivisionError:
212-
return float(sol_score < 0)
212+
# if generator scored 0 then the solver will have achieved an equal or better score
213+
# i.e. the Fight's score is simply 1 regardless of its solution score.
214+
return 1
213215
else:
214216
return max(0, min(1, solution.score(instance, Role.solver)))
215217

tests/test_util.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
11
"""Tests for all util functions."""
2+
from math import inf
23
import unittest
34

45
from algobattle.battle import Battle, Iterated, Averaged
6+
from algobattle.problem import InstanceModel, SolutionModel, default_score
7+
from algobattle.util import Role
8+
9+
10+
class DummyInstance(InstanceModel): # noqa: D101
11+
@property
12+
def size(self) -> int:
13+
return 1
14+
15+
16+
class DummySolution(SolutionModel[DummyInstance]): # noqa: D101
17+
val: float
18+
19+
def score(self, instance: DummyInstance, role: Role) -> float:
20+
return self.val
521

622

723
class Utiltests(unittest.TestCase):
@@ -12,6 +28,35 @@ def test_default_battle_types(self):
1228
self.assertEqual(Battle.all()["Iterated"], Iterated)
1329
self.assertEqual(Battle.all()["Averaged"], Averaged)
1430

31+
def test_default_fight_score(self):
32+
"""Tests the default fight scoring function."""
33+
instance = DummyInstance()
34+
scores = [
35+
(0, 0, 1),
36+
(0, 2, 1),
37+
(0, 4, 1),
38+
(0, inf, 1),
39+
(2, 0, 0),
40+
(2, 2, 1),
41+
(2, 4, 1),
42+
(2, inf, 1),
43+
(4, 0, 0),
44+
(4, 2, 0.5),
45+
(4, 4, 1),
46+
(4, inf, 1),
47+
(inf, 0, 0),
48+
(inf, 2, 0),
49+
(inf, 4, 0),
50+
(inf, inf, 1),
51+
]
52+
for gen, sol, score in scores:
53+
self.assertEqual(
54+
default_score(
55+
instance, generator_solution=DummySolution(val=gen), solver_solution=DummySolution(val=sol)
56+
),
57+
score,
58+
)
59+
1560

1661
if __name__ == "__main__":
1762
unittest.main()

0 commit comments

Comments
 (0)