Skip to content

Commit 6907b1d

Browse files
authored
Allows other values than NaN for unplayable levels in SMB (#196)
1 parent 8ac4ca6 commit 6907b1d

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/poli/objective_repository/super_mario_bros/isolated_function.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,14 @@ def __init__(
6666
alphabet: List[str] = smb_info.alphabet,
6767
max_time: int = 30,
6868
visualize: bool = False,
69+
value_on_unplayable: float = np.NaN,
6970
):
7071
self.alphabet = alphabet
7172
self.alphabet_s_to_i = {s: i for i, s in enumerate(alphabet)}
7273
self.alphabet_i_to_s = {i: s for i, s in enumerate(alphabet)}
7374
self.max_time = max_time
7475
self.visualize = visualize
76+
self.value_on_unplayable = value_on_unplayable
7577

7678
def __call__(self, x: np.ndarray, context=None) -> np.ndarray:
7779
"""Computes number of jumps in a given latent code x."""
@@ -95,7 +97,7 @@ def __call__(self, x: np.ndarray, context=None) -> np.ndarray:
9597
if res["marioStatus"] == 1:
9698
jumps = res["jumpActionsPerformed"]
9799
else:
98-
jumps = np.nan
100+
jumps = self.value_on_unplayable
99101

100102
jumps_for_all_levels.append(jumps)
101103

src/poli/objective_repository/super_mario_bros/register.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
self,
6565
max_time: int = 30,
6666
visualize: bool = False,
67+
value_on_unplayable: float = np.NaN,
6768
batch_size: int = None,
6869
parallelize: bool = False,
6970
num_workers: int = None,
@@ -91,8 +92,9 @@ def __init__(
9192
evaluation_budget=evaluation_budget,
9293
)
9394
self.force_isolation = force_isolation
94-
self.max_time = max_time
95+
self.max_time = int(max_time)
9596
self.visualize = visualize
97+
self.value_on_unplayable = value_on_unplayable
9698
_ = get_inner_function(
9799
isolated_function_name="super_mario_bros__isolated",
98100
class_name="SMBIsolatedLogic",
@@ -101,6 +103,7 @@ def __init__(
101103
alphabet=smb_info.alphabet,
102104
max_time=self.max_time,
103105
visualize=self.visualize,
106+
value_on_unplayable=self.value_on_unplayable,
104107
)
105108

106109
def _black_box(self, x: np.ndarray, context=None) -> np.ndarray:
@@ -114,6 +117,7 @@ def _black_box(self, x: np.ndarray, context=None) -> np.ndarray:
114117
alphabet=smb_info.alphabet,
115118
max_time=self.max_time,
116119
visualize=self.visualize,
120+
value_on_unplayable=self.value_on_unplayable,
117121
)
118122
return inner_function(x, context)
119123

@@ -142,6 +146,7 @@ def create(
142146
self,
143147
max_time: int = 30,
144148
visualize: bool = False,
149+
value_on_unplayable: float = np.NaN,
145150
seed: int = None,
146151
batch_size: int = None,
147152
parallelize: bool = False,
@@ -182,6 +187,7 @@ def create(
182187
f = SuperMarioBrosBlackBox(
183188
max_time=max_time,
184189
visualize=visualize,
190+
value_on_unplayable=value_on_unplayable,
185191
batch_size=batch_size,
186192
parallelize=parallelize,
187193
num_workers=num_workers,

0 commit comments

Comments
 (0)