Skip to content

Commit ca5821a

Browse files
VinciGit00 and DiTo97 committed
feat: merge graphs node implementation
Co-Authored-By: Federico Minutoli <40361744+DiTo97@users.noreply.github.com>
1 parent 0e85e8f commit ca5821a

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

scrapegraphai/graphs/explore_graph.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,10 +89,10 @@ def _create_graph(self) -> BaseGraph:
8989
"schema": self.schema,
9090
}
9191
)
92-
92+
9393
search_link_node = SearchLinkNode(
9494
input="doc",
95-
output=[{"link": "description"}],
95+
output=[{"relevant_links"}],
9696
node_config={
9797
"llm_model": self.llm_model,
9898
}
@@ -115,7 +115,7 @@ def _create_graph(self) -> BaseGraph:
115115
entry_point=fetch_node
116116
)
117117

118-
def run(self) -> str:
118+
def run(self) -> tuple[str, dict]:
119119
"""
120120
Executes the scraping process and returns the answer to the prompt.
121121
@@ -125,4 +125,5 @@ def run(self) -> str:
125125
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
126126
self.final_state, self.execution_info = self.graph.execute(inputs)
127127

128-
return self.final_state.get("answer", "No answer found.")
128+
return (self.final_state.get("answer", "No answer found."),
129+
self.final_state.get("relevant_links", dict()))

scrapegraphai/nodes/merge_explore_graphs_node.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
# Imports from standard library
66
from typing import List, Optional
7-
8-
# Imports from Langchain
7+
from functools import reduce
8+
import operator
99
from langchain.prompts import PromptTemplate
1010
from langchain_core.output_parsers import JsonOutputParser
1111
from tqdm import tqdm
@@ -68,20 +68,51 @@ def execute(self, state: dict) -> dict:
6868

6969
self.logger.info(f"--- Executing {self.node_name} Node ---")
7070

71-
template_answer = ""
71+
# merge the answers in one string
72+
73+
74+
# Initialize the output parser
75+
if self.node_config.get("schema", None) is not None:
76+
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
77+
else:
78+
output_parser = JsonOutputParser()
79+
80+
format_instructions = output_parser.get_format_instructions()
81+
82+
template_answer = """
83+
You are a website scraper and you have just scraped some content from multiple websites.\n
84+
You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n
85+
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
86+
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
87+
OUTPUT INSTRUCTIONS: {format_instructions}\n
88+
USER PROMPT: {user_prompt}\n
89+
WEBSITE CONTENT: {website_content}
90+
"""
91+
92+
input_keys = self.get_input_keys(state)
93+
# Fetching data from the state based on the input keys
94+
input_data = [state[key] for key in input_keys]
95+
96+
user_prompt = input_data[0]
97+
#answers is a list of strings
98+
answers, relevant_links = zip(*input_data[1])
99+
100+
answers_str = ""
101+
for i, answer in enumerate(answers):
102+
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"
72103

73104
answers = str(state.get("answer"))
74105
relevant_links = str(state.get("relevant_links"))
75106
answer = {}
76107

77108
merge_prompt = PromptTemplate(
78109
template=template_answer,
79-
#input_variables=["context", "question"],
80-
#partial_variables={"format_instructions": format_instructions},
110+
input_variables=["context", "question"],
111+
partial_variables={"format_instructions": format_instructions},
81112
)
82113

83-
#answer = merge_prompt.invoke({"question": user_prompt})
114+
answer = merge_prompt.invoke({"question": user_prompt})
84115

85-
state.update({"relevant_links": "TODO"})
86-
state.update({"answer": "TODO"})
116+
state.update({"relevant_links": reduce(operator.ior, relevant_links, {})})
117+
state.update({"answer": answer})
87118
return state

0 commit comments

Comments
 (0)