Skip to content

feat: add conditional node to the smart_scraper_graph #754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ undetected-playwright>=0.3.0
semchunk>=1.0.1
langchain-ollama>=0.1.3
simpleeval>=0.9.13
googlesearch-python>=1.2.5
googlesearch-python>=1.2.5
async_timeout>=4.0.3
7 changes: 6 additions & 1 deletion scrapegraphai/graphs/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def _set_conditional_node_edges(self):
raise ValueError(f"ConditionalNode '{node.node_name}' must have exactly two outgoing edges.")
# Assign true_node_name and false_node_name
node.true_node_name = outgoing_edges[0][1].node_name
node.false_node_name = outgoing_edges[1][1].node_name
try:
node.false_node_name = outgoing_edges[1][1].node_name
except:
node.false_node_name = None

def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
"""
Expand Down Expand Up @@ -221,6 +224,8 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
node_names = {node.node_name for node in self.nodes}
if result in node_names:
current_node_name = result
elif result is None:
current_node_name = None
else:
raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")

Expand Down
138 changes: 83 additions & 55 deletions scrapegraphai/graphs/smart_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
SmartScraperGraph Module
"""
from typing import Optional
import logging
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from ..nodes import (
FetchNode,
ParseNode,
ReasoningNode,
GenerateAnswerNode
GenerateAnswerNode,
ConditionalNode
)
from ..prompts import REGEN_ADDITIONAL_INFO

class SmartScraperGraph(AbstractGraph):
"""
Expand Down Expand Up @@ -89,6 +90,28 @@ def _create_graph(self) -> BaseGraph:
}
)

cond_node = None
regen_node = None
if self.config.get("reattempt") is True:
cond_node = ConditionalNode(
input="answer",
output=["answer"],
node_name="ConditionalNode",
node_config={
"key_name": "answer",
"condition": 'not answer or answer=="NA"',
}
)
regen_node = GenerateAnswerNode(
input="user_prompt & answer",
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": REGEN_ADDITIONAL_INFO,
"schema": self.schema,
}
)

if self.config.get("html_mode") is False:
parse_node = ParseNode(
input="doc",
Expand All @@ -99,6 +122,7 @@ def _create_graph(self) -> BaseGraph:
}
)

reasoning_node = None
if self.config.get("reasoning"):
reasoning_node = ReasoningNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
Expand All @@ -109,68 +133,72 @@ def _create_graph(self) -> BaseGraph:
"schema": self.schema,
}
)

# Define the graph variation configurations
# (html_mode, reasoning, reattempt)
graph_variation_config = {
(False, True, False): {
"nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node],
"edges": [(fetch_node, parse_node), (parse_node, reasoning_node), (reasoning_node, generate_answer_node)]
},
(True, True, False): {
"nodes": [fetch_node, reasoning_node, generate_answer_node],
"edges": [(fetch_node, reasoning_node), (reasoning_node, generate_answer_node)]
},
(True, False, False): {
"nodes": [fetch_node, generate_answer_node],
"edges": [(fetch_node, generate_answer_node)]
},
(False, False, False): {
"nodes": [fetch_node, parse_node, generate_answer_node],
"edges": [(fetch_node, parse_node), (parse_node, generate_answer_node)]
},
(False, True, True): {
"nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, parse_node), (parse_node, reasoning_node), (reasoning_node, generate_answer_node),
(generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)]
},
(True, True, True): {
"nodes": [fetch_node, reasoning_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, reasoning_node), (reasoning_node, generate_answer_node),
(generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)]
},
(True, False, True): {
"nodes": [fetch_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, generate_answer_node), (generate_answer_node, cond_node),
(cond_node, regen_node), (cond_node, None)]
},
(False, False, True): {
"nodes": [fetch_node, parse_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, parse_node), (parse_node, generate_answer_node),
(generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)]
}
}

if self.config.get("html_mode") is False and self.config.get("reasoning") is True:

return BaseGraph(
nodes=[
fetch_node,
parse_node,
reasoning_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, reasoning_node),
(reasoning_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

elif self.config.get("html_mode") is True and self.config.get("reasoning") is True:
# Get the current conditions
html_mode = self.config.get("html_mode", False)
reasoning = self.config.get("reasoning", False)
reattempt = self.config.get("reattempt", False)

return BaseGraph(
nodes=[
fetch_node,
reasoning_node,
generate_answer_node,
],
edges=[
(fetch_node, reasoning_node),
(reasoning_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)
# Retrieve the appropriate graph configuration
config = graph_variation_config.get((html_mode, reasoning, reattempt))

elif self.config.get("html_mode") is True and self.config.get("reasoning") is False:
if config:
return BaseGraph(
nodes=[
fetch_node,
generate_answer_node,
],
edges=[
(fetch_node, generate_answer_node)
],
nodes=config["nodes"],
edges=config["edges"],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

# Default return if no conditions match
return BaseGraph(
nodes=[
fetch_node,
parse_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

nodes=[fetch_node, parse_node, generate_answer_node],
edges=[(fetch_node, parse_node), (parse_node, generate_answer_node)],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

def run(self) -> str:
"""
Executes the scraping process and returns the answer to the prompt.
Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/nodes/conditional_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def execute(self, state: dict) -> dict:
str: The name of the next node to execute based on the presence of the key.
"""

if self.true_node_name is None or self.false_node_name is None:
if self.true_node_name is None:
raise ValueError("ConditionalNode's next nodes are not set properly.")

if self.condition:
Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .generate_answer_node_prompts import (TEMPLATE_CHUNKS,
TEMPLATE_NO_CHUNKS,
TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD,
TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD)
TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD, REGEN_ADDITIONAL_INFO)
from .generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
TEMPLATE_NO_CHUKS_CSV,
TEMPLATE_MERGE_CSV)
Expand Down
4 changes: 4 additions & 0 deletions scrapegraphai/prompts/generate_answer_node_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,7 @@
USER QUESTION: {question}\n
WEBSITE CONTENT: {context}\n
"""

# Extra prompt text for a second answer-generation attempt. SmartScraperGraph
# passes it as the "additional_info" entry of the regeneration
# GenerateAnswerNode's node_config when the "reattempt" option is enabled.
# The literal is not a raw string, so each "\n" below is an escape that adds
# a newline on top of the physical line break that follows it.
# NOTE(review): "informations" should read "information"; left untouched here
# because this is runtime prompt text -- fix it in a deliberate prompt change.
REGEN_ADDITIONAL_INFO = """
You are a scraper and you have just failed to scrape the requested information from a website. \n
I want you to try again and provide the missing informations. \n"""
Loading