Skip to content

Commit 72f839f

Browse files
authored
Clean up llm 20 qs interpreter (#258)
Clean up LLM 20 Qs Interpreter
1 parent edfa1be commit 72f839f

File tree

1 file changed

+88
-118
lines changed

1 file changed

+88
-118
lines changed

kaggle_environments/envs/llm_20_questions/llm_20_questions.py

Lines changed: 88 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@
2020
tokenizer = None
2121
model_initialized = False
2222

23+
ERROR = "ERROR"
24+
DONE = "DONE"
25+
INACTIVE = "INACTIVE"
26+
ACTIVE = "ACTIVE"
27+
28+
GUESS = "guess"
29+
ASK = "ask"
30+
GUESSER = "guesser"
31+
# Role key for the answering agent. It must differ from GUESSER, otherwise the
# two entries of the `agents` dict share one key and the later one wins.
ANSWERER = "answerer"
32+
2333
keywords_list = json.loads(KEYWORDS_JSON)
2434
keyword_cat = random.choice(keywords_list)
2535
category = keyword_cat["category"]
@@ -42,12 +52,12 @@ def guesser_agent(obs):
4252
)
4353

4454
prompt = ""
45-
if obs.turnType == "ask":
55+
if obs.turnType == ASK:
4656
prompt = "{}{}".format(
4757
info_prompt.format(q_a_thread=q_a_thread),
4858
questions_prompt
4959
)
50-
elif obs.turnType == "guess":
60+
elif obs.turnType == GUESS:
5161
prompt = "{}{}".format(
5262
info_prompt.format(q_a_thread=q_a_thread),
5363
guess_prompt
@@ -73,22 +83,82 @@ def answerer_agent(obs):
7383
return ""
7484

7585

76-
agents = {"guesser": guesser_agent, "answerer": answerer_agent}
86+
agents = {GUESSER: guesser_agent, ANSWERER: answerer_agent}
87+
88+
def guesser_action(active, inactive, step):
    """Record the guesser's action and report whether the keyword was guessed.

    An empty action marks the active agent as ERROR. Otherwise the action is
    appended to the ``questions`` or ``guesses`` list of both agents'
    observations, depending on the active agent's ``turnType``.

    If the action is non-empty and ``keyword_guessed`` accepts it, both agents
    receive the reward ``20 - int(step / 3)``, are marked DONE, and have the
    keyword and category written into their observations.

    Args:
        active: agent state whose ``action`` is being processed.
        inactive: the paired agent state, updated alongside ``active``.
        step: current step number, used only to compute the reward.

    Returns:
        True if the keyword was guessed on this call, False otherwise.
    """
    pair = (active, inactive)

    if not active.action:
        active.status = ERROR
    elif active.observation.turnType == ASK:
        for agent in pair:
            agent.observation.questions.append(active.action)
    elif active.observation.turnType == GUESS:
        for agent in pair:
            agent.observation.guesses.append(active.action)

    # The keyword check is independent of the turn type handled above.
    if not (active.action and keyword_guessed(active.action)):
        return False

    score = 20 - int(step / 3)
    for agent in pair:
        agent.reward = score
    for agent in pair:
        agent.status = DONE
    # Reveal the answer to both agents once the round is over.
    for agent in pair:
        agent.observation.keyword = keyword
        agent.observation.category = category
    return True
110+
111+
def answerer_action(active, inactive):
    """Normalise the answerer's reply and append it to both agents' answers.

    The keyword and category are first written into the active agent's
    observation. The reply is then reduced to one of four strings:

    * ``"none"``  - the action was empty (active agent marked ERROR)
    * ``"yes"``   - the lower-cased reply contains ``"yes"``
    * ``"no"``    - otherwise, the lower-cased reply contains ``"no"``
    * ``"maybe"`` - anything else (active agent marked ERROR)

    Args:
        active: agent state whose ``action`` holds the raw reply.
        inactive: the paired agent state, which receives the same answer.
    """
    active.observation.keyword = keyword
    active.observation.category = category

    reply = active.action
    if not reply:
        answer, failed = "none", True
    else:
        lowered = reply.lower()
        # NOTE(review): these are substring tests, so "yes" wins when both
        # appear and "no" also matches inside longer words - confirm intended.
        if "yes" in lowered:
            answer, failed = "yes", False
        elif "no" in lowered:
            answer, failed = "no", False
        else:
            answer, failed = "maybe", True

    if failed:
        active.status = ERROR

    for agent in (active, inactive):
        agent.observation.answers.append(answer)
127+
128+
def increment_turn(active, inactive, step, guessed):
    """Advance the pair to its next turn, or end the game on the final step.

    * On step 59 without a successful guess, both agents get reward -1, are
      marked DONE, and have the keyword and category revealed.
    * After a GUESS turn the active agent's turn type becomes ASK; statuses
      are left unchanged in this branch.
    * After an ASK turn the turn type becomes GUESS and the active/inactive
      statuses are swapped.
    * For any other turn type the statuses are swapped.

    Args:
        active: agent state that just acted.
        inactive: the paired agent state.
        step: current step number; 59 is treated as the last step.
        guessed: whether the keyword was guessed during this step.
    """
    # NOTE(review): the two status-swapping branches overwrite any ERROR (or
    # DONE) status written to `active` earlier in the step - confirm intended.
    if step == 59 and not guessed:
        for agent in (active, inactive):
            agent.observation.keyword = keyword
            agent.observation.category = category
        for agent in (active, inactive):
            agent.reward = -1
        for agent in (active, inactive):
            agent.status = DONE
    elif active.observation.turnType == GUESS:
        active.observation.turnType = ASK
    elif active.observation.turnType == ASK:
        active.observation.turnType = GUESS
        active.status = INACTIVE
        inactive.status = ACTIVE
    else:
        active.status = INACTIVE
        inactive.status = ACTIVE
77147

78148

79149
def interpreter(state, env):
80150
if env.done:
81151
return state
82152

83153
# Isolate the active and inactive agents.
84-
active1 = state[0] if state[0].status == "ACTIVE" else state[1]
85-
inactive1 = state[0] if state[0].status == "INACTIVE" else state[1]
86-
active2 = state[2] if state[2].status == "ACTIVE" else state[3]
87-
inactive2 = state[2] if state[2].status == "INACTIVE" else state[3]
88-
if active1.status == "DONE" and inactive1.status == "DONE":
154+
active1 = state[0] if state[0].status == ACTIVE else state[1]
155+
inactive1 = state[0] if state[0].status == INACTIVE else state[1]
156+
active2 = state[2] if state[2].status == ACTIVE else state[3]
157+
inactive2 = state[2] if state[2].status == INACTIVE else state[3]
158+
if active1.status == DONE and inactive1.status == DONE:
89159
active1 = None
90160
inactive1 = None
91-
if active2.status == "DONE" or inactive2.status == "DONE":
161+
if active2.status == DONE or inactive2.status == DONE:
92162
active2 = None
93163
inactive2 = None
94164
if active1 is None and inactive1 is None and active2 is None and inactive2 is None:
@@ -98,119 +168,19 @@ def interpreter(state, env):
98168

99169
if active1 is not None:
100170
guessed = False
101-
if active1.observation.role == "guesser":
102-
if not active1.action:
103-
active1.status = "ERROR"
104-
elif active1.observation.turnType == "ask":
105-
active1.observation.questions.append(active1.action)
106-
inactive1.observation.questions.append(active1.action)
107-
elif active1.observation.turnType == "guess":
108-
active1.observation.guesses.append(active1.action)
109-
inactive1.observation.guesses.append(active1.action)
110-
if active1.action and keyword_guessed(active1.action):
111-
guessed = True
112-
score = 20 - int(step / 3)
113-
active1.reward = score
114-
inactive1.reward = score
115-
active1.status = "DONE"
116-
inactive1.status = "DONE"
117-
active1.observation.keyword = keyword
118-
active1.observation.category = category
119-
inactive1.observation.keyword = keyword
120-
inactive1.observation.category = category
121-
else:
122-
active1.observation.keyword = keyword
123-
active1.observation.category = category
124-
response = active1.action
125-
if not response:
126-
response = "none"
127-
active1.status = "ERROR"
128-
elif response.lower().__contains__("yes"):
129-
response = "yes"
130-
elif response.lower().__contains__("no"):
131-
response = "no"
132-
else:
133-
response = "maybe"
134-
active1.status = "ERROR"
135-
active1.observation.answers.append(response)
136-
inactive1.observation.answers.append(response)
137-
138-
if step == 59 and not guessed:
139-
active1.observation.keyword = keyword
140-
active1.observation.category = category
141-
inactive1.observation.keyword = keyword
142-
inactive1.observation.category = category
143-
active1.reward = -1
144-
inactive1.reward = -1
145-
active1.status = "DONE"
146-
inactive1.status = "DONE"
147-
elif active1.observation.turnType == "guess":
148-
active1.observation.turnType = "ask"
149-
elif active1.observation.turnType == "ask":
150-
active1.observation.turnType = "guess"
151-
active1.status = "INACTIVE"
152-
inactive1.status = "ACTIVE"
171+
if active1.observation.role == GUESSER:
172+
guessed = guesser_action(active1, inactive1, step)
153173
else:
154-
active1.status = "INACTIVE"
155-
inactive1.status = "ACTIVE"
174+
answerer_action(active1, inactive1)
175+
increment_turn(active1, inactive1, step, guessed)
156176

157177
if active2 is not None:
158178
guessed = False
159-
if active2.observation.role == "guesser":
160-
if not active2.action:
161-
active2.status = "ERROR"
162-
elif active2.observation.turnType == "ask":
163-
active2.observation.questions.append(active2.action)
164-
inactive2.observation.questions.append(active2.action)
165-
elif active2.observation.turnType == "guess":
166-
active2.observation.guesses.append(active2.action)
167-
inactive2.observation.guesses.append(active2.action)
168-
if active2.action and keyword_guessed(active2.action):
169-
guessed = True
170-
score = 20 - int(step / 3)
171-
active2.reward = score
172-
inactive2.reward = score
173-
active2.status = "DONE"
174-
inactive2.status = "DONE"
175-
active2.observation.keyword = keyword
176-
active2.observation.category = category
177-
inactive2.observation.keyword = keyword
178-
inactive2.observation.category = category
179-
else:
180-
active2.observation.keyword = keyword
181-
active2.observation.category = category
182-
response = active2.action
183-
if not response:
184-
reponse = "none"
185-
active2.status = "ERROR"
186-
elif response.lower().__contains__("yes"):
187-
response = "yes"
188-
elif response.lower().__contains__("no"):
189-
response = "no"
190-
else:
191-
reponse = "maybe"
192-
active2.status = "ERROR"
193-
active2.observation.answers.append(response)
194-
inactive2.observation.answers.append(response)
195-
196-
if step == 59 and not guessed:
197-
active2.observation.keyword = keyword
198-
active2.observation.category = category
199-
inactive2.observation.keyword = keyword
200-
inactive2.observation.category = category
201-
active2.reward = -1
202-
inactive2.reward = -1
203-
active2.status = "DONE"
204-
inactive2.status = "DONE"
205-
elif active2.observation.turnType == "guess":
206-
active2.observation.turnType = "ask"
207-
elif active2.observation.turnType == "ask":
208-
active2.observation.turnType = "guess"
209-
active2.status = "INACTIVE"
210-
inactive2.status = "ACTIVE"
179+
if active2.observation.role == GUESSER:
180+
guessed = guesser_action(active2, inactive2, step)
211181
else:
212-
active2.status = "INACTIVE"
213-
inactive2.status = "ACTIVE"
182+
answerer_action(active2, inactive2)
183+
increment_turn(active2, inactive2, step, guessed)
214184

215185
return state
216186

@@ -219,7 +189,7 @@ def renderer(state, env):
219189

220190
for s in state:
221191
print("role: ", s.observation.role)
222-
if s.observation.role == "guesser":
192+
if s.observation.role == GUESSER:
223193
transcript = ""
224194
for i in range(0, len(s.observation.guesses)):
225195
transcript = "{}Q: {} A: {}\nG: {}\n".format(

0 commit comments

Comments
 (0)