|
1 | 1 | from SPARQLWrapper import SPARQLWrapper, JSON
|
| 2 | +from utils import * |
| 3 | + |
2 | 4 | SPARQLPATH = "http://192.168.80.12:8890/sparql" # depend on your own internal address and port, shown in Freebase folder's readme.md
|
3 | 5 |
|
4 | 6 | # pre-defined sparqls
|
@@ -42,3 +44,216 @@ def id2entity_name_or_type(entity_id):
|
42 | 44 | return "UnName_Entity"
|
43 | 45 | else:
|
44 | 46 | return results["results"]["bindings"][0]['tailEntity']['value']
|
| 47 | + |
| 48 | +from freebase_func import * |
| 49 | +from prompt_list import * |
| 50 | +import json |
| 51 | +import time |
| 52 | +import openai |
| 53 | +import re |
| 54 | +from prompt_list import * |
| 55 | +from rank_bm25 import BM25Okapi |
| 56 | +from sentence_transformers import util |
| 57 | +from sentence_transformers import SentenceTransformer |
| 58 | + |
| 59 | + |
def clean_relations(string, entity_id, head_relations):
    """Parse '{relation (Score: x.y)}' entries out of an LLM response.

    Returns (True, list of relation dicts) on success, or (False, reason)
    when the output is malformed or contains no usable relations.
    """
    pattern = r"{\s*(?P<relation>[^()]+)\s+\(Score:\s+(?P<score>[0-9.]+)\)}"
    parsed = []
    for match in re.finditer(pattern, string):
        rel_name = match.group("relation").strip()
        # Entries with ';' jam several relations together — skip them.
        if ';' in rel_name:
            continue
        raw_score = match.group("score")
        if not rel_name or not raw_score:
            return False, "output uncompleted.."
        try:
            numeric_score = float(raw_score)
        except ValueError:
            return False, "Invalid score"
        parsed.append({
            "entity": entity_id,
            "relation": rel_name,
            "score": numeric_score,
            "head": rel_name in head_relations,
        })
    if not parsed:
        return False, "No relations found"
    return True, parsed
| 81 | + |
| 82 | + |
def if_all_zero(topn_scores):
    """Return True iff every score equals zero (vacuously True when empty)."""
    for value in topn_scores:
        if value != 0:
            return False
    return True
| 85 | + |
| 86 | + |
def clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations):
    """Package BM25 / sentence-embedding retrieval output as relation dicts.

    When every retrieval score is zero, fall back to a uniform distribution
    so downstream pruning still has usable weights. Always returns
    (True, list).
    """
    if if_all_zero(topn_scores):
        topn_scores = [float(1 / len(topn_scores))] * len(topn_scores)
    cleaned = [
        {
            "entity": entity_id,
            "relation": rel,
            "score": rel_score,
            "head": rel in head_relations,
        }
        for rel, rel_score in zip(topn_relations, topn_scores)
    ]
    return True, cleaned
| 99 | + |
| 100 | + |
def construct_relation_prune_prompt(question, entity_name, total_relations, args):
    """Build the LLM prompt asking it to keep the top ``args.width`` relations."""
    header = extract_relation_prompt % (args.width, args.width)
    relation_list = '; '.join(total_relations)
    return header + question + '\nTopic Entity: ' + entity_name + '\nRelations: ' + relation_list + "\nA: "
| 103 | + |
| 104 | + |
def construct_entity_score_prompt(question, relation, entity_candidates):
    """Build the LLM prompt asking for scores over the candidate entities."""
    candidate_str = "; ".join(entity_candidates)
    return score_entity_candidates_prompt.format(question, relation) + candidate_str + '\nScore: '
| 107 | + |
| 108 | + |
def relation_search_prune(entity_id, entity_name, pre_relations, pre_head, question, args):
    """Fetch the head/tail relations of *entity_id* from the KG and prune
    them to the most promising ones for *question*.

    Pruning is delegated to an LLM, BM25, or a SentenceTransformer encoder
    depending on ``args.prune_tools``. Returns a list of scored relation
    dicts, or [] when the pruner output could not be parsed (format error
    or too small max_length).
    """
    head_relations = replace_relation_prefix(execurte_sparql(sparql_head_relations % (entity_id)))
    tail_relations = replace_relation_prefix(execurte_sparql(sparql_tail_relations % (entity_id)))

    if args.remove_unnecessary_rel:
        head_relations = [rel for rel in head_relations if not abandon_rels(rel)]
        tail_relations = [rel for rel in tail_relations if not abandon_rels(rel)]

    # Drop the relations already explored in the previous hop, on the side
    # we arrived from.
    if pre_head:
        tail_relations = list(set(tail_relations) - set(pre_relations))
    else:
        head_relations = list(set(head_relations) - set(pre_relations))

    head_relations = list(set(head_relations))
    tail_relations = list(set(tail_relations))
    # Sort so the relation order in the prompt is always deterministic.
    total_relations = sorted(head_relations + tail_relations)

    if args.prune_tools == "llm":
        prompt = construct_relation_prune_prompt(question, entity_name, total_relations, args)
        result = run_llm(prompt, args.temperature_exploration, args.max_length, args.opeani_api_keys, args.LLM_type)
        flag, retrieve_relations_with_scores = clean_relations(result, entity_id, head_relations)
    elif args.prune_tools == "bm25":
        topn_relations, topn_scores = compute_bm25_similarity(question, total_relations, args.width)
        flag, retrieve_relations_with_scores = clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations)
    else:
        encoder = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')
        topn_relations, topn_scores = retrieve_top_docs(question, total_relations, encoder, args.width)
        flag, retrieve_relations_with_scores = clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations)

    # On parse failure the caller gets an empty list rather than an error.
    return retrieve_relations_with_scores if flag else []
| 150 | + |
| 151 | + |
def entity_search(entity, relation, head=True):
    """Follow *relation* from *entity* and return the connected entity ids.

    head=True walks entity -> relation -> ?tail; head=False walks
    ?head -> relation -> entity. Only machine ids (the "m." prefix) are
    kept.
    """
    if head:
        query = sparql_tail_entities_extract % (entity, relation)
    else:
        query = sparql_head_entities_extract % (entity, relation)
    results = execurte_sparql(query)

    candidate_ids = replace_entities_prefix(results)
    return [eid for eid in candidate_ids if eid.startswith("m.")]
| 164 | + |
| 165 | + |
def entity_score(question, entity_candidates_id, score, relation, args):
    """Score each candidate entity reached via *relation* for *question*.

    Returns (scores, entity_candidates, entity_candidates_id); each
    per-candidate score is weighted by the incoming relation *score*.
    """
    # Resolve ids to surface names; id2entity_name_or_type yields
    # "UnName_Entity" when no label is found.
    entity_candidates = [id2entity_name_or_type(entity_id) for entity_id in entity_candidates_id]
    # If no candidate has a usable name, spread the relation score uniformly.
    if all_unknown_entity(entity_candidates):
        return [1/len(entity_candidates) * score] * len(entity_candidates), entity_candidates, entity_candidates_id
    # NOTE(review): presumably this drops "UnName_Entity" names, but the ids
    # list is not filtered in parallel, so names and ids can fall out of
    # positional alignment from here on — confirm del_unknown_entity's
    # semantics against its definition.
    entity_candidates = del_unknown_entity(entity_candidates)
    if len(entity_candidates) == 1:
        return [score], entity_candidates, entity_candidates_id
    if len(entity_candidates) == 0:
        return [0.0], entity_candidates, entity_candidates_id

    # make sure the id and entity are in the same order
    zipped_lists = sorted(zip(entity_candidates, entity_candidates_id))
    entity_candidates, entity_candidates_id = zip(*zipped_lists)
    entity_candidates = list(entity_candidates)
    entity_candidates_id = list(entity_candidates_id)
    if args.prune_tools == "llm":
        prompt = construct_entity_score_prompt(question, relation, entity_candidates)

        result = run_llm(prompt, args.temperature_exploration, args.max_length, args.opeani_api_keys, args.LLM_type)
        # The LLM returns one score per candidate; scale by the relation score.
        return [float(x) * score for x in clean_scores(result, entity_candidates)], entity_candidates, entity_candidates_id

    elif args.prune_tools == "bm25":
        topn_entities, topn_scores = compute_bm25_similarity(question, entity_candidates, args.width)
    else:
        model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')
        topn_entities, topn_scores = retrieve_top_docs(question, entity_candidates, model, args.width)
    # All-zero retrieval scores would zero out every candidate; fall back to
    # a uniform distribution instead.
    if if_all_zero(topn_scores):
        topn_scores = [float(1/len(topn_scores))] * len(topn_scores)
    return [float(x) * score for x in topn_scores], topn_entities, entity_candidates_id
| 195 | + |
| 196 | + |
def update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head):
    """Append this relation's candidate entities (plus bookkeeping columns)
    onto the running per-depth accumulator lists.

    The accumulators are mutated in place and also returned. An empty
    candidate set is recorded with the "[FINISH]" / "[FINISH_ID]" sentinel
    pair so the search still leaves a trace at this depth.
    """
    if not entity_candidates:
        entity_candidates.append("[FINISH]")
        entity_candidates_id = ["[FINISH_ID]"]
    n = len(entity_candidates)
    total_candidates.extend(entity_candidates)
    total_scores.extend(scores)
    total_relations.extend([entity['relation']] * n)
    total_entities_id.extend(entity_candidates_id)
    total_topic_entities.extend([entity['entity']] * n)
    total_head.extend([entity['head']] * n)
    return total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head
| 211 | + |
| 212 | + |
def half_stop(question, cluster_chain_of_entities, depth, args):
    """Terminate exploration early: answer from the evidence gathered so far
    and persist the result to the dataset's jsonl file."""
    print("No new knowledge added during search depth %d, stop searching." % depth)
    final_answer = generate_answer(question, cluster_chain_of_entities, args)
    save_2_jsonl(question, final_answer, cluster_chain_of_entities, file_name=args.dataset)
| 217 | + |
| 218 | + |
def generate_answer(question, cluster_chain_of_entities, args):
    """Ask the LLM for a final answer given the question and all retrieved
    knowledge triplets (flattened one chain per line)."""
    triplet_lines = []
    for sublist in cluster_chain_of_entities:
        for chain in sublist:
            triplet_lines.append(', '.join(str(x) for x in chain))
    chain_prompt = '\n'.join(triplet_lines)
    prompt = answer_prompt + question + '\n' + "\nKnowledge Triplets: " + chain_prompt + 'A: '
    return run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type)
| 225 | + |
| 226 | + |
def entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args):
    """Keep the top-``args.width`` scoring candidate entities and assemble
    the (topic, relation, candidate) triplet chains for this depth.

    Returns (flag, cluster_chain_of_entities, entities_id, relations,
    heads); flag is False when every surviving candidate scored 0.
    """
    rows = sorted(
        zip(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores),
        key=lambda row: row[5],
        reverse=True,
    )[:args.width]
    # Zero-scored rows carry no evidence; drop them.
    kept = [row for row in rows if row[5] != 0]
    if not kept:
        return False, [], [], [], []
    entities_id, relations, candidates, tops, heads, scores = (list(col) for col in zip(*kept))

    # Resolve topic-entity ids to readable names for the reasoning prompt.
    tops = [id2entity_name_or_type(eid) for eid in tops]
    cluster_chain_of_entities = [[(tops[i], relations[i], candidates[i]) for i in range(len(candidates))]]
    return True, cluster_chain_of_entities, entities_id, relations, heads
| 242 | + |
| 243 | + |
def reasoning(question, cluster_chain_of_entities, args):
    """Ask the LLM whether the gathered triplets suffice to answer the
    question. Returns (sufficient?, raw LLM response)."""
    chain_prompt = '\n'.join(
        ', '.join(str(x) for x in chain)
        for sublist in cluster_chain_of_entities
        for chain in sublist
    )
    prompt = prompt_evaluate + question + "\nKnowledge Triplets: " + chain_prompt + 'A: '

    response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type)

    verdict = extract_answer(response)
    return (True, response) if if_true(verdict) else (False, response)
| 256 | + |
| 257 | + |
| 258 | + |
| 259 | + |
0 commit comments