Skip to content

Commit 3840d68

Browse files
refactor: split quiz and judge
1 parent 7ae7955 commit 3840d68

File tree

5 files changed

+103
-34
lines changed

5 files changed

+103
-34
lines changed

generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@
7373

7474
graph_gen.insert(data, args.data_type)
7575

76-
graph_gen.judge(re_judge=True, max_samples=3)
76+
graph_gen.quiz(max_samples=3)
77+
78+
graph_gen.judge(re_judge=True)
7779

7880
graph_gen.traverse()
7981
with open(os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml"), "w", encoding='utf-8') as f:

graphgen/graphgen.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from dataclasses import dataclass
88
from tqdm.asyncio import tqdm as tqdm_async
99

10-
from .operators import *
1110
from models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
1211
from utils import create_event_loop, logger, compute_content_hash
1312
from models.storage.base_storage import StorageNameSpace
13+
from .operators import *
1414

1515

1616
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -81,7 +81,8 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ
8181
compute_content_hash(dp["content"], prefix="chunk-"): {
8282
**dp,
8383
'full_doc_id': doc_key
84-
} for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"], self.chunk_overlap_size, self.chunk_size)
84+
} for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"],
85+
self.chunk_overlap_size, self.chunk_size)
8586
}
8687
inserting_chunks.update(chunks)
8788
_add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
@@ -156,21 +157,29 @@ async def async_insert(self, data: Union[List[list], List[dict]], data_type: str
156157

157158
async def _insert_done(self):
158159
tasks = []
159-
for storage_instance in [self.full_docs_storage, self.text_chunks_storage, self.graph_storage, self.wiki_storage]:
160+
for storage_instance in [self.full_docs_storage, self.text_chunks_storage,
161+
self.graph_storage, self.wiki_storage]:
160162
if storage_instance is None:
161163
continue
162164
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
163165
await asyncio.gather(*tasks)
164166

165-
def judge(self, re_judge=False, max_samples=1):
167+
def quiz(self, max_samples=1):
166168
loop = create_event_loop()
167-
loop.run_until_complete(self.async_judge(re_judge, max_samples))
169+
loop.run_until_complete(self.async_quiz(max_samples))
170+
171+
async def async_quiz(self, max_samples=1):
172+
await quiz_relations(self.teacher_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
173+
await self.rephrase_storage.index_done_callback()
168174

169-
async def async_judge(self, re_judge=False, max_samples=1):
175+
def judge(self, re_judge=False):
176+
loop = create_event_loop()
177+
loop.run_until_complete(self.async_judge(re_judge))
178+
179+
async def async_judge(self, re_judge=False):
170180
_update_relations = await judge_relations(self.teacher_llm_client, self.student_llm_client,
171-
self.graph_storage, self.rephrase_storage, re_judge, max_samples)
181+
self.graph_storage, self.rephrase_storage, re_judge)
172182
await _update_relations.index_done_callback()
173-
await self.rephrase_storage.index_done_callback()
174183

175184
def traverse(self):
176185
loop = create_event_loop()

graphgen/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from .extract_kg import extract_kg
2+
from .quiz_relations import quiz_relations
23
from .judge_relations import judge_relations
34
from .search_wikipedia import search_wikipedia
45
from .traverse_graph import traverse_graph_by_edge
56

67
__all__ = [
78
"extract_kg",
9+
"quiz_relations",
810
"judge_relations",
911
"search_wikipedia",
1012
"traverse_graph_by_edge"

graphgen/operators/judge_relations.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ async def judge_relations(
1212
graph_storage: NetworkXStorage,
1313
rephrase_storage: JsonKVStorage,
1414
re_judge: bool = False,
15-
max_samples: int = 1,
1615
max_concurrent: int = 1000) -> NetworkXStorage:
1716
"""
1817
Get all edges and judge them
@@ -22,7 +21,6 @@ async def judge_relations(
2221
:param graph_storage: graph storage instance
2322
:param rephrase_storage: rephrase storage instance
2423
:param re_judge: re-judge the relations
25-
:param max_samples: max samples for each edge
2624
:param max_concurrent: max concurrent
2725
:return:
2826
"""
@@ -38,34 +36,14 @@ async def _judge_single_relation(
3836
edge_data = edge[2]
3937

4038
if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
41-
logger.info(f"Edge {source_id} -> {target_id} already judged, loss: {edge_data['loss']}, skip")
39+
logger.info("Edge %s -> %s already judged, loss: %s, skip", source_id, target_id, edge_data["loss"])
4240
return source_id, target_id, edge_data
4341

4442
description = edge_data["description"]
45-
language = "English" if detect_main_language(description) == "en" else "Chinese"
4643

4744
try:
48-
# 如果在rephrase_storage中已经存在,直接取出
4945
descriptions = await rephrase_storage.get_by_id(description)
50-
if not descriptions:
51-
# 多次采样,取平均
52-
descriptions = [(description, 'yes')]
53-
for i in range(max_samples):
54-
if i > 0:
55-
new_description = await teacher_llm_client.generate_answer(
56-
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(input_sentence=description),
57-
temperature=1
58-
)
59-
descriptions.append((new_description, 'yes'))
60-
new_anti_description = await teacher_llm_client.generate_answer(
61-
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(input_sentence=description),
62-
temperature=1
63-
)
64-
descriptions.append((new_anti_description, 'no'))
65-
66-
descriptions = list(set(descriptions))
67-
68-
await rephrase_storage.upsert({description: descriptions})
46+
assert descriptions is not None
6947

7048
judgements = []
7149
gts = [gt for _, gt in descriptions]
@@ -81,7 +59,7 @@ async def _judge_single_relation(
8159

8260
edge_data["loss"] = loss
8361
except Exception as e: # pylint: disable=broad-except
84-
logger.error(f"Error in judging relation {source_id} -> {target_id}: {e}")
62+
logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
8563
logger.info("Use default loss 0.1")
8664
edge_data["loss"] = -math.log(0.1)
8765

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import asyncio
2+
3+
from tqdm.asyncio import tqdm as tqdm_async
4+
from models import JsonKVStorage, OpenAIModel, NetworkXStorage
5+
from utils import logger, detect_main_language
6+
from templates import DESCRIPTION_REPHRASING_PROMPT
7+
8+
9+
async def quiz_relations(
10+
teacher_llm_client: OpenAIModel,
11+
graph_storage: NetworkXStorage,
12+
rephrase_storage: JsonKVStorage,
13+
max_samples: int = 1,
14+
max_concurrent: int = 1000) -> JsonKVStorage:
15+
"""
16+
Get all edges and quiz them
17+
18+
:param teacher_llm_client: generate statements
19+
:param graph_storage: graph storage instance
20+
:param rephrase_storage: rephrase storage instance
21+
:param max_samples: max samples for each edge
22+
:param max_concurrent: max concurrent
23+
:return:
24+
"""
25+
26+
semaphore = asyncio.Semaphore(max_concurrent)
27+
28+
async def _quiz_single_relation(
29+
edge: tuple,
30+
):
31+
async with semaphore:
32+
source_id = edge[0]
33+
target_id = edge[1]
34+
edge_data = edge[2]
35+
36+
description = edge_data["description"]
37+
language = "English" if detect_main_language(description) == "en" else "Chinese"
38+
39+
try:
40+
# 如果在rephrase_storage中已经存在,直接取出
41+
descriptions = await rephrase_storage.get_by_id(description)
42+
if not descriptions:
43+
# 多次采样,取平均
44+
descriptions = [(description, 'yes')]
45+
for i in range(max_samples):
46+
if i > 0:
47+
new_description = await teacher_llm_client.generate_answer(
48+
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(input_sentence=description),
49+
temperature=1
50+
)
51+
descriptions.append((new_description, 'yes'))
52+
new_anti_description = await teacher_llm_client.generate_answer(
53+
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(input_sentence=description),
54+
temperature=1
55+
)
56+
descriptions.append((new_anti_description, 'no'))
57+
58+
descriptions = list(set(descriptions))
59+
except Exception as e: # pylint: disable=broad-except
60+
logger.error(f"Error when quizzing edge {source_id} -> {target_id}: {e}")
61+
descriptions = [(description, 'yes')]
62+
63+
await rephrase_storage.upsert({description: descriptions})
64+
65+
return {description: descriptions}
66+
67+
68+
edges = await graph_storage.get_all_edges()
69+
70+
results = []
71+
for result in tqdm_async(
72+
asyncio.as_completed([_quiz_single_relation(edge) for edge in edges]),
73+
total=len(edges),
74+
desc="Quizzing relations"
75+
):
76+
results.append(await result)
77+
78+
return rephrase_storage

0 commit comments

Comments
 (0)