Skip to content

Commit e007fbd

Browse files
committed
fix pop bugs
1 parent 162d64f commit e007fbd

File tree

5 files changed

+267
-279
lines changed

5 files changed

+267
-279
lines changed

ToG/freebase_func.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from SPARQLWrapper import SPARQLWrapper, JSON
2+
from utils import *
3+
24
SPARQLPATH = "http://192.168.80.12:8890/sparql" # depend on your own internal address and port, shown in Freebase folder's readme.md
35

46
# pre-defined sparqls
@@ -42,3 +44,216 @@ def id2entity_name_or_type(entity_id):
4244
return "UnName_Entity"
4345
else:
4446
return results["results"]["bindings"][0]['tailEntity']['value']
47+
48+
from freebase_func import *
49+
from prompt_list import *
50+
import json
51+
import time
52+
import openai
53+
import re
54+
from prompt_list import *
55+
from rank_bm25 import BM25Okapi
56+
from sentence_transformers import util
57+
from sentence_transformers import SentenceTransformer
58+
59+
60+
def clean_relations(string, entity_id, head_relations):
    """Parse "{relation (Score: x)}" entries out of an LLM pruning response.

    Returns (True, list of {entity, relation, score, head} dicts) on success,
    or (False, error message) when the output is malformed or yields no
    usable relations.
    """
    pattern = r"{\s*(?P<relation>[^()]+)\s+\(Score:\s+(?P<score>[0-9.]+)\)}"
    parsed = []
    for m in re.finditer(pattern, string):
        rel = m.group("relation").strip()
        # Entries containing ';' look like merged multi-relation output; skip them.
        if ';' in rel:
            continue
        raw_score = m.group("score")
        if not rel or not raw_score:
            return False, "output uncompleted.."
        try:
            value = float(raw_score)
        except ValueError:
            return False, "Invalid score"
        parsed.append({
            "entity": entity_id,
            "relation": rel,
            "score": value,
            "head": rel in head_relations,  # True when relation is outgoing from the entity
        })
    if not parsed:
        return False, "No relations found"
    return True, parsed
81+
82+
83+
def if_all_zero(topn_scores):
    """Return True when every score in *topn_scores* equals zero (or the list is empty)."""
    for s in topn_scores:
        if s != 0:
            return False
    return True
85+
86+
87+
def clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations):
    """Package retriever-ranked relations into the standard relation dicts.

    When every retriever score is zero, fall back to a uniform distribution
    so downstream pruning still has usable weights. Always returns
    (True, relations) to mirror clean_relations' (flag, payload) shape.
    """
    if all(score == 0 for score in topn_scores):
        topn_scores = [1.0 / len(topn_scores)] * len(topn_scores)
    relations = [
        {"entity": entity_id,
         "relation": relation,
         "score": rel_score,
         "head": relation in head_relations}
        for relation, rel_score in zip(topn_relations, topn_scores)
    ]
    return True, relations
99+
100+
101+
def construct_relation_prune_prompt(question, entity_name, total_relations, args):
    """Build the LLM prompt that asks for the top-`args.width` relevant relations."""
    relation_list = '; '.join(total_relations)
    header = extract_relation_prompt % (args.width, args.width)
    return header + question + '\nTopic Entity: ' + entity_name + '\nRelations: ' + relation_list + "\nA: "
103+
104+
105+
def construct_entity_score_prompt(question, relation, entity_candidates):
    """Build the LLM prompt that scores candidate entities for *relation*."""
    candidate_str = "; ".join(entity_candidates)
    return score_entity_candidates_prompt.format(question, relation) + candidate_str + '\nScore: '
107+
108+
109+
def relation_search_prune(entity_id, entity_name, pre_relations, pre_head, question, args):
    """Fetch candidate relations around *entity_id* and prune them to the top-k.

    Head (outgoing) and tail (incoming) relations are gathered via SPARQL,
    optionally filtered, de-duplicated, and then pruned with the configured
    tool (`llm`, `bm25`, or a sentence encoder). Returns a list of
    {entity, relation, score, head} dicts, or [] when the pruner output
    could not be parsed (format error or truncated generation).
    """
    head_relations = replace_relation_prefix(
        execurte_sparql(sparql_head_relations % (entity_id)))
    tail_relations = replace_relation_prefix(
        execurte_sparql(sparql_tail_relations % (entity_id)))

    if args.remove_unnecessary_rel:
        head_relations = [r for r in head_relations if not abandon_rels(r)]
        tail_relations = [r for r in tail_relations if not abandon_rels(r)]

    # Drop relations already explored on the matching side of the triple.
    if pre_head:
        tail_relations = list(set(tail_relations) - set(pre_relations))
    else:
        head_relations = list(set(head_relations) - set(pre_relations))

    head_relations = list(set(head_relations))
    tail_relations = list(set(tail_relations))
    total_relations = head_relations + tail_relations
    total_relations.sort()  # keep prompt ordering deterministic across runs

    if args.prune_tools == "llm":
        prompt = construct_relation_prune_prompt(question, entity_name, total_relations, args)
        result = run_llm(prompt, args.temperature_exploration, args.max_length,
                         args.opeani_api_keys, args.LLM_type)
        flag, retrieve_relations_with_scores = clean_relations(result, entity_id, head_relations)
    elif args.prune_tools == "bm25":
        topn_relations, topn_scores = compute_bm25_similarity(question, total_relations, args.width)
        flag, retrieve_relations_with_scores = clean_relations_bm25_sent(
            topn_relations, topn_scores, entity_id, head_relations)
    else:
        model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')
        topn_relations, topn_scores = retrieve_top_docs(question, total_relations, model, args.width)
        flag, retrieve_relations_with_scores = clean_relations_bm25_sent(
            topn_relations, topn_scores, entity_id, head_relations)

    # [] signals a format error or too-small max_length to the caller.
    return retrieve_relations_with_scores if flag else []
150+
151+
152+
def entity_search(entity, relation, head=True):
    """Return the Freebase MIDs linked to *entity* through *relation*.

    head=True follows the relation forward (entity as subject, fetching
    tails); head=False follows it backward (entity as object, fetching
    heads). Only "m."-prefixed machine ids are kept.
    """
    if head:
        query = sparql_tail_entities_extract % (entity, relation)
    else:
        query = sparql_head_entities_extract % (entity, relation)
    raw_entities = execurte_sparql(query)
    entity_ids = replace_entities_prefix(raw_entities)
    return [eid for eid in entity_ids if eid.startswith("m.")]
164+
165+
166+
def entity_score(question, entity_candidates_id, score, relation, args):
    """Score candidate entities for *relation*, weighted by the parent *score*.

    Returns (scores, entity_names, entity_ids).
    NOTE(review): after del_unknown_entity() the name list can shrink while
    entity_candidates_id is left untouched, so names and ids may fall out of
    alignment — original behavior kept as-is; confirm against the callers.
    """
    entity_candidates = [id2entity_name_or_type(eid) for eid in entity_candidates_id]
    if all_unknown_entity(entity_candidates):
        # Nothing resolvable: spread the parent score uniformly.
        uniform = 1 / len(entity_candidates) * score
        return [uniform] * len(entity_candidates), entity_candidates, entity_candidates_id
    entity_candidates = del_unknown_entity(entity_candidates)
    if len(entity_candidates) == 1:
        return [score], entity_candidates, entity_candidates_id
    if len(entity_candidates) == 0:
        return [0.0], entity_candidates, entity_candidates_id

    # Sort names and ids together so the prompt order is deterministic.
    pairs = sorted(zip(entity_candidates, entity_candidates_id))
    entity_candidates = [name for name, _ in pairs]
    entity_candidates_id = [eid for _, eid in pairs]

    if args.prune_tools == "llm":
        prompt = construct_entity_score_prompt(question, relation, entity_candidates)
        result = run_llm(prompt, args.temperature_exploration, args.max_length,
                         args.opeani_api_keys, args.LLM_type)
        scaled = [float(x) * score for x in clean_scores(result, entity_candidates)]
        return scaled, entity_candidates, entity_candidates_id
    elif args.prune_tools == "bm25":
        topn_entities, topn_scores = compute_bm25_similarity(question, entity_candidates, args.width)
    else:
        model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')
        topn_entities, topn_scores = retrieve_top_docs(question, entity_candidates, model, args.width)
    if if_all_zero(topn_scores):
        topn_scores = [float(1 / len(topn_scores))] * len(topn_scores)
    # NOTE(review): topn_entities may be a reordered subset while the full
    # (sorted) entity_candidates_id is returned — verify downstream alignment.
    return [float(x) * score for x in topn_scores], topn_entities, entity_candidates_id
195+
196+
197+
def update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head):
    """Append one relation-expansion result to the running candidate pools.

    *entity* is a {entity, relation, head, score} dict produced by relation
    pruning. A dead-end expansion (no candidates) is recorded with the
    "[FINISH]" / "[FINISH_ID]" sentinels so chain bookkeeping stays
    consistent. Returns the six updated pools (mutated in place).
    """
    if not entity_candidates:
        # Dead end: insert placeholder so this branch is still tracked.
        entity_candidates.append("[FINISH]")
        entity_candidates_id = ["[FINISH_ID]"]
    n = len(entity_candidates)
    total_candidates += entity_candidates
    total_scores += scores
    total_relations += [entity['relation']] * n
    total_entities_id += entity_candidates_id
    total_topic_entities += [entity['entity']] * n
    total_head += [entity['head']] * n
    return total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head
211+
212+
213+
def half_stop(question, cluster_chain_of_entities, depth, args):
    """Give up deepening the search at *depth* and emit an answer now.

    Generates an answer from the triplet chains collected so far and
    persists the result to the per-dataset JSONL file.
    """
    print("No new knowledge added during search depth %d, stop searching." % depth)
    answer = generate_answer(question, cluster_chain_of_entities, args)
    save_2_jsonl(question, answer, cluster_chain_of_entities, file_name=args.dataset)
217+
218+
219+
def generate_answer(question, cluster_chain_of_entities, args):
    """Ask the LLM for a final answer given the retrieved triplet chains."""
    triplet_lines = []
    for sublist in cluster_chain_of_entities:
        for chain in sublist:
            triplet_lines.append(', '.join(str(x) for x in chain))
    chain_prompt = '\n'.join(triplet_lines)
    # NOTE(review): no newline before 'A: ' — kept to match the original prompt text.
    prompt = answer_prompt + question + '\n' + "\nKnowledge Triplets: " + chain_prompt + 'A: '
    return run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type)
225+
226+
227+
def entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args):
    """Keep the top-`args.width` scored candidates and drop zero-score ones.

    Returns (found, cluster_chain_of_entities, entities_id, relations, heads);
    found is False (with empty lists) when nothing with a nonzero score
    survives the cut.
    """
    # Sort all parallel lists together by score, descending, then take top-k.
    rows = sorted(
        zip(total_entities_id, total_relations, total_candidates,
            total_topic_entities, total_head, total_scores),
        key=lambda row: row[5],
        reverse=True,
    )[:args.width]
    kept = [row for row in rows if row[5] != 0]
    if not kept:
        return False, [], [], [], []
    entities_id, relations, candidates, tops, heads, scores = map(list, zip(*kept))

    # Resolve topic-entity ids to display names for the chain output.
    tops = [id2entity_name_or_type(eid) for eid in tops]
    cluster_chain_of_entities = [[(tops[i], relations[i], candidates[i]) for i in range(len(candidates))]]
    return True, cluster_chain_of_entities, entities_id, relations, heads
242+
243+
244+
def reasoning(question, cluster_chain_of_entities, args):
    """Ask the LLM whether the collected triplets suffice to answer *question*.

    Returns (True, response) when the model judges the evidence sufficient,
    otherwise (False, response).
    """
    triplets = '\n'.join(
        ', '.join(str(x) for x in chain)
        for sublist in cluster_chain_of_entities
        for chain in sublist
    )
    prompt = prompt_evaluate + question + "\nKnowledge Triplets: " + triplets + 'A: '

    response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type)

    result = extract_answer(response)
    return (True, response) if if_true(result) else (False, response)
256+
257+
258+
259+

ToG/main_freebase.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from tqdm import tqdm
22
import argparse
33
from utils import *
4+
from freebase_func import *
45
import random
56
from client import *
67

@@ -37,6 +38,10 @@
3738
question = data[question_string]
3839
topic_entity = data['topic_entity']
3940
cluster_chain_of_entities = []
41+
if len(topic_entity) == 0:
42+
results = generate_without_explored_paths(question, args)
43+
save_2_jsonl(question, results, [], file_name=args.dataset)
44+
continue
4045
pre_relations = []
4146
pre_heads= [-1] * len(topic_entity)
4247
flag_printed = False
@@ -87,8 +92,13 @@
8792
break
8893
else:
8994
print("depth %d still not find the answer." % depth)
90-
topic_entity = {entity: id2entity_name_or_type(entity) for entity in entities_id}
91-
continue
95+
flag_finish, entities_id = if_finish_list(entities_id)
96+
if flag_finish:
97+
half_stop(question, cluster_chain_of_entities, depth, args)
98+
flag_printed = True
99+
else:
100+
topic_entity = {entity: id2entity_name_or_type(entity) for entity in entities_id}
101+
continue
92102
else:
93103
half_stop(question, cluster_chain_of_entities, depth, args)
94104
flag_printed = True

ToG/main_wiki.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
from wiki_func import *
55
from client import *
6+
from utils import *
67

78

89
if __name__ == '__main__':
@@ -39,6 +40,10 @@
3940
question = data[question_string]
4041
topic_entity = data['qid_topic_entity']
4142
cluster_chain_of_entities = []
43+
if len(topic_entity) == 0:
44+
results = generate_without_explored_paths(question, args)
45+
save_2_jsonl(question, results, [], file_name=args.dataset)
46+
continue
4247
pre_relations = []
4348
pre_heads= [-1] * len(topic_entity)
4449
flag_printed = False
@@ -105,8 +110,13 @@
105110
break
106111
else:
107112
print("depth %d still not find the answer." % depth)
108-
topic_entity = {qid: topic for qid, topic in zip(entities_id, [wiki_client.query_all("qid2label", entity).pop() for entity in entities_id])}
109-
continue
113+
flag_finish, entities_id = if_finish_list(entities_id)
114+
if flag_finish:
115+
half_stop(question, cluster_chain_of_entities, depth, args)
116+
flag_printed = True
117+
else:
118+
topic_entity = {qid: topic for qid, topic in zip(entities_id, [wiki_client.query_all("qid2label", entity).pop() for entity in entities_id])}
119+
continue
110120
else:
111121
half_stop(question, cluster_chain_of_entities, depth, args)
112122
flag_printed = True

0 commit comments

Comments
 (0)