
Commit e9f66df

feat: add multi-hop generation
1 parent 2b80fd5 commit e9f66df

File tree: 7 files changed (+185, -9 lines)


configs/graphgen_config.yaml

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,4 @@
-qa_form: atomic
+qa_form: multi_hop
 data_type: raw
 input_file: resources/examples/raw_demo.jsonl
 tokenizer: cl100k_base
@@ -10,9 +10,9 @@ traverse_strategy:
     - medium
     - medium
   edge_sampling: max_loss
-  expand_method: max_tokens
+  expand_method: max_width
   isolated_node_strategy: ignore
-  max_depth: 5
-  max_extra_edges: 5
+  max_depth: 1
+  max_extra_edges: 2
   max_tokens: 256
 web_search: false
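Taken together with the TraverseStrategy change below, this config is roughly equivalent to constructing the strategy as follows. A minimal sketch, assuming the keys under traverse_strategy map one-to-one onto TraverseStrategy fields (only qa_form and expand_method are confirmed by the dataclass diff in this commit).

# Sketch only: field names other than qa_form and expand_method are assumed
# to mirror the YAML keys under traverse_strategy.
from models import TraverseStrategy

strategy = TraverseStrategy(
    qa_form="multi_hop",        # was "atomic"
    expand_method="max_width",  # expand by edge count rather than token budget
    max_depth=1,
    max_extra_edges=2,
    max_tokens=256,
    edge_sampling="max_loss",
    isolated_node_strategy="ignore",
)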

graphgen/graphgen.py

Lines changed: 8 additions & 1 deletion
@@ -5,13 +5,14 @@
 import time
 from typing import List, cast, Union
 from dataclasses import dataclass
+
 from tqdm.asyncio import tqdm as tqdm_async

 from models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
 from models.storage.base_storage import StorageNameSpace
 from utils import create_event_loop, logger, compute_content_hash
 from .operators import (extract_kg, search_wikipedia, quiz, judge_statement, traverse_graph_by_edge,
-                        traverse_graph_atomically)
+                        traverse_graph_atomically, traverse_graph_for_multi_hop)


 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -195,6 +196,12 @@ async def async_traverse(self):
                                                          self.graph_storage,
                                                          self.traverse_strategy,
                                                          self.text_chunks_storage)
+            elif self.traverse_strategy.qa_form == "multi_hop":
+                results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client,
+                                                             self.tokenizer_instance,
+                                                             self.graph_storage,
+                                                             self.traverse_strategy,
+                                                             self.text_chunks_storage)
             else:
                 results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
                                                        self.graph_storage, self.traverse_strategy, self.text_chunks_storage)
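Since all three traversal operators are called with the same five arguments, the growing if/elif chain could also be written as a lookup table. A minimal sketch of that alternative (hypothetical helper, not part of this commit):

# Hypothetical refactor of the dispatch above; assumes the three operators
# share the call signature used in async_traverse.
from graphgen.operators import (traverse_graph_atomically, traverse_graph_by_edge,
                                traverse_graph_for_multi_hop)

TRAVERSAL_BY_QA_FORM = {
    "atomic": traverse_graph_atomically,
    "multi_hop": traverse_graph_for_multi_hop,
}

async def _dispatch_traverse(self):
    traverse_fn = TRAVERSAL_BY_QA_FORM.get(self.traverse_strategy.qa_form,
                                           traverse_graph_by_edge)
    return await traverse_fn(self.synthesizer_llm_client,
                             self.tokenizer_instance,
                             self.graph_storage,
                             self.traverse_strategy,
                             self.text_chunks_storage)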

graphgen/operators/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -2,13 +2,14 @@
 from .quiz import quiz
 from .judge import judge_statement
 from .search_wikipedia import search_wikipedia
-from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically
+from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically, traverse_graph_for_multi_hop

 __all__ = [
     "extract_kg",
     "quiz",
     "judge_statement",
     "search_wikipedia",
     "traverse_graph_by_edge",
-    "traverse_graph_atomically"
+    "traverse_graph_atomically",
+    "traverse_graph_for_multi_hop"
 ]

graphgen/operators/traverse_graph.py

Lines changed: 108 additions & 1 deletion
@@ -3,7 +3,7 @@
 from tqdm.asyncio import tqdm as tqdm_async

 from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
-from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT
+from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT, MULTI_HOP_GENERATION_PROMPT
 from utils import detect_main_language, compute_content_hash, logger
 from graphgen.operators.split_graph import get_batches_with_strategy

@@ -399,3 +399,110 @@ async def _generate_question(
         except Exception as e:  # pylint: disable=broad-except
             logger.error("Error occurred while generating questions: %s", e)
     return results
+
+async def traverse_graph_for_multi_hop(
+    llm_client: OpenAIModel,
+    tokenizer: Tokenizer,
+    graph_storage: NetworkXStorage,
+    traverse_strategy: TraverseStrategy,
+    text_chunks_storage: JsonKVStorage,
+    max_concurrent: int = 1000
+) -> dict:
+    """
+    Traverse the graph for multi-hop
+
+    :param llm_client
+    :param tokenizer
+    :param graph_storage
+    :param traverse_strategy
+    :param text_chunks_storage
+    :param max_concurrent
+    :return: question and answer
+    """
+    assert traverse_strategy.qa_form == "multi_hop"
+
+    semaphore = asyncio.Semaphore(max_concurrent)
+
+    results = {}
+    edges = list(await graph_storage.get_all_edges())
+    nodes = list(await graph_storage.get_all_nodes())
+
+    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
+
+    processing_batches = await get_batches_with_strategy(
+        nodes,
+        edges,
+        graph_storage,
+        traverse_strategy
+    )
+
+    processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order)
+
+    async def _process_single_batch(
+        _process_batch: tuple
+    ) -> dict:
+        async with semaphore:
+            try:
+                language = "Chinese" if detect_main_language(_process_batch[0][0]['description']) == "zh" else "English"
+
+                _process_nodes = _process_batch[0]
+                _process_edges = _process_batch[1]
+
+                entities = [
+                    f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
+                ]
+
+                relations = [
+                    f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+                    for _process_edge in _process_edges
+                ]
+
+                entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
+                relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
+
+                prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
+                    entities=entities_str,
+                    relationships=relations_str
+                )
+
+                context = await llm_client.generate_answer(prompt)
+
+                # post-process the context
+                if "Question:" in context and "Answer:" in context:
+                    question = context.split("Question:")[1].split("Answer:")[0].strip()
+                    answer = context.split("Answer:")[1].strip()
+                elif "问题:" in context and "答案:" in context:
+                    question = context.split("问题:")[1].split("答案:")[0].strip()
+                    answer = context.split("答案:")[1].strip()
+                else:
+                    return {}
+
+                question = question.strip("\"")
+                answer = answer.strip("\"")
+
+                logger.info("Question: %s", question)
+                logger.info("Answer: %s", answer)
+
+                return {
+                    compute_content_hash(question): {
+                        "question": question,
+                        "answer": answer,
+                        "loss": get_average_loss(_process_batch),
+                        "difficulty": _process_batch[2],
+                    }
+                }
+
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error("Error occurred while processing batch: %s", e)
+                return {}
+
+    for result in tqdm_async(
+        asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
+        total=len(processing_batches),
+        desc="Processing batches"
+    ):
+        try:
+            results.update(await result)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error("Error occurred while processing batches: %s", e)
+    return results
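The fragile part of the new operator is parsing the raw completion back into a question/answer pair. The same logic, pulled out into a standalone helper (hypothetical name parse_qa) for illustration:

# Minimal sketch of the post-processing above: split on the English or Chinese
# Question/Answer markers, strip surrounding quotes, or return None if neither
# marker pair is present.
from typing import Optional, Tuple

def parse_qa(context: str) -> Optional[Tuple[str, str]]:
    for q_marker, a_marker in (("Question:", "Answer:"), ("问题:", "答案:")):
        if q_marker in context and a_marker in context:
            question = context.split(q_marker)[1].split(a_marker)[0].strip().strip('"')
            answer = context.split(a_marker)[1].strip().strip('"')
            return question, answer
    return None

# parse_qa("Question: What does eating apples provide?\nAnswer: Vitamin C")
# -> ("What does eating apples provide?", "Vitamin C")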

models/strategy/travserse_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 @dataclass
 class TraverseStrategy(BaseStrategy):
     # 生成的QA形式:原子、多跳、开放性
-    qa_form: str = "atomic"
+    qa_form: str = "multi_hop"  # "atomic" or "multi_hop" or "open"
     # 最大边数和最大token数方法中选择一个生效
     expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
     # 单向拓展还是双向拓展

templates/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
 from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
 from .question_generation import QUESTION_GENERATION_PROMPT
+from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT

templates/multi_hop_generation.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# pylint: disable=C0301
+
+TEMPLATE_ZH: str = """请基于以下知识子图生成多跳推理问题和答案。你将获得一个知识子图,其中包含一系列实体、关系和事实。你的任务是提出一个问题,该问题需要经过多次推理才能回答。问题的答案应该是从给定的知识子图中推断出来的。确保问题的难度适中,需要多步推理才能回答。
+
+例如:
+########
+--实体--
+1. 苹果
+2. 水果
+3. 维生素C
+########
+--关系--
+1. 苹果-水果:苹果是一种水果
+2. 水果-维生素C:水果中富含维生素C
+########
+问题:通过吃苹果补充的什么物质,有助于维持健康?
+答案:维生素C
+########
+
+#########
+--实体--
+{entities}
+#########
+--关系--
+{relationships}
+#########
+直接输出生成的问题和答案,请不要直接复制示例问题和答案,不要输出无关内容。
+"""
+
+TEMPLATE_EN: str = """Please generate a multi-hop reasoning question and answer based on the following knowledge subgraph. You will be provided with a knowledge subgraph that contains a series of entities, relations, and facts. Your task is to generate a question that requires multiple steps of reasoning to answer. The answer to the question should be inferred from the given knowledge subgraph. Ensure that the question is of moderate difficulty and requires multiple steps of reasoning to answer.
+
+For example:
+########
+--Entities--
+1. Apple
+2. Fruit
+3. Vitamin C
+########
+--Relations--
+1. Apple-Fruit: Apple is a type of fruit
+2. Fruit-Vitamin C: Fruits are rich in Vitamin C
+########
+Question: What substance, obtained through eating apples, helps maintain health?
+Answer: Vitamin C
+########
+
+########
+--Entities--
+{entities}
+########
+--Relations--
+{relationships}
+########
+Output the generated question and answer directly, please do not copy the example question and answer directly, and do not provide irrelevant information.
+"""
+
+MULTI_HOP_GENERATION_PROMPT = {
+    "English": TEMPLATE_EN,
+    "Chinese": TEMPLATE_ZH
+}
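For reference, traverse_graph_for_multi_hop fills this template by numbering the entity and relation descriptions, joining them with newlines, and substituting them into the {entities} and {relationships} placeholders. A small usage sketch with made-up strings:

# Usage sketch; the entity/relation strings below are illustrative only.
from templates import MULTI_HOP_GENERATION_PROMPT

entities = ["Apple: a common fruit", "Vitamin C: a nutrient found in fruit"]
relations = ["Apple -- Fruit: Apple is a type of fruit",
             "Fruit -- Vitamin C: Fruits are rich in Vitamin C"]

entities_str = "\n".join(f"{i + 1}. {e}" for i, e in enumerate(entities))
relations_str = "\n".join(f"{i + 1}. {r}" for i, r in enumerate(relations))

prompt = MULTI_HOP_GENERATION_PROMPT["English"].format(
    entities=entities_str,
    relationships=relations_str,
)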
