Skip to content

Commit a769e3f

Browse files
authored
End the game early in case of a timeout (#259)
End the game early in case of timeout or non yes/no answer.
1 parent 72f839f commit a769e3f

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

kaggle_environments/envs/llm_20_questions/llm_20_questions.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
DONE = "DONE"
2525
INACTIVE = "INACTIVE"
2626
ACTIVE = "ACTIVE"
27+
TIMEOUT = "TIMEOUT"
2728

2829
GUESS = "guess"
2930
ASK = "ask"
@@ -98,43 +99,40 @@ def guesser_action(active, inactive, step):
9899
if active.action and keyword_guessed(active.action):
99100
guessed = True
100101
score = 20 - int(step / 3)
101-
active.reward = score
102-
inactive.reward = score
103-
active.status = DONE
104-
inactive.status = DONE
105-
active.observation.keyword = keyword
106-
active.observation.category = category
102+
end_game(active, inactive, score, DONE, DONE)
103+
return guessed
104+
105+
def end_game(active, inactive, reward, status, inactive_status):
106+
active.observation.keyword = keyword
107+
active.observation.category = category
107108
inactive.observation.keyword = keyword
108109
inactive.observation.category = category
109-
return guessed
110+
active.reward = reward
111+
inactive.reward = reward
112+
active.status = status
113+
inactive.status = inactive_status
114+
110115

111116
def answerer_action(active, inactive):
112117
active.observation.keyword = keyword
113118
active.observation.category = category
114119
response = active.action
115120
if not response:
116121
response = "none"
117-
active.status = ERROR
122+
end_game(active, inactive, -1, ERROR, DONE)
118123
elif "yes" in response.lower():
119124
response = "yes"
120125
elif "no" in response.lower():
121126
response = "no"
122127
else:
123128
response = "maybe"
124-
active.status = ERROR
129+
end_game(active, inactive, -1, ERROR, DONE)
125130
active.observation.answers.append(response)
126131
inactive.observation.answers.append(response)
127132

128133
def increment_turn(active, inactive, step, guessed):
129134
if step == 59 and not guessed:
130-
active.observation.keyword = keyword
131-
active.observation.category = category
132-
inactive.observation.keyword = keyword
133-
inactive.observation.category = category
134-
active.reward = -1
135-
inactive.reward = -1
136-
active.status = DONE
137-
inactive.status = DONE
135+
end_game(active, inactive, -1, DONE, DONE)
138136
elif active.observation.turnType == "guess":
139137
active.observation.turnType = "ask"
140138
elif active.observation.turnType == "ask":
@@ -166,21 +164,33 @@ def interpreter(state, env):
166164

167165
step = state[0].observation.step
168166

167+
end_early = (active1 and active1.status) in (TIMEOUT, ERROR) or (active2 and active2.status in (TIMEOUT, ERROR))
168+
169169
if active1 is not None:
170170
guessed = False
171171
if active1.observation.role == GUESSER:
172172
guessed = guesser_action(active1, inactive1, step)
173173
else:
174174
answerer_action(active1, inactive1)
175-
increment_turn(active1, inactive1, step, guessed)
175+
if active1.status in (TIMEOUT, ERROR):
176+
end_game(active1, inactive1, -1, active1.status, DONE)
177+
elif end_early:
178+
end_game(active1, inactive1, 0, DONE, DONE)
179+
else:
180+
increment_turn(active1, inactive1, step, guessed)
176181

177182
if active2 is not None:
178183
guessed = False
179184
if active2.observation.role == GUESSER:
180185
guessed = guesser_action(active2, inactive2, step)
181186
else:
182187
answerer_action(active2, inactive2)
183-
increment_turn(active2, inactive2, step, guessed)
188+
if active2.status in (TIMEOUT, ERROR):
189+
end_game(active2, inactive2, -1, active2.status, DONE)
190+
elif end_early:
191+
end_game(active2, inactive2, 0, DONE, DONE)
192+
else:
193+
increment_turn(active2, inactive2, step, guessed)
184194

185195
return state
186196

0 commit comments

Comments
 (0)