diff --git a/examples/groq/smart_scraper_multi_cond_groq.py b/examples/groq/smart_scraper_multi_cond_groq.py
new file mode 100644
index 00000000..7e81cfd2
--- /dev/null
+++ b/examples/groq/smart_scraper_multi_cond_groq.py
@@ -0,0 +1,42 @@
+"""
+Basic example of scraping pipeline using SmartScraperMultiCondGraph with Groq
+"""
+
+import os
+import json
+from dotenv import load_dotenv
+from scrapegraphai.graphs import SmartScraperMultiCondGraph
+
+load_dotenv()
+
+# ************************************************
+# Define the configuration for the graph
+# ************************************************
+
+groq_key = os.getenv("GROQ_APIKEY")
+
+graph_config = {
+    "llm": {
+        "model": "groq/gemma-7b-it",
+        "api_key": groq_key,
+        "temperature": 0
+    },
+    "headless": False
+}
+
+# *******************************************************
+# Create the SmartScraperMultiCondGraph instance and run it
+# *******************************************************
+
+multiple_search_graph = SmartScraperMultiCondGraph(
+    prompt="Who is Marco Perini?",
+    source=[
+        "https://perinim.github.io/",
+        "https://perinim.github.io/cv/"
+    ],
+    schema=None,
+    config=graph_config
+)
+
+result = multiple_search_graph.run()
+print(json.dumps(result, indent=4))
diff --git a/requirements.txt b/requirements.txt
index 8a29f1c8..e6e5d4d7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,3 +18,4 @@ undetected-playwright>=0.3.0
 google>=3.0.0
 semchunk>=1.0.1
 langchain-ollama>=0.1.3
+simpleeval>=0.9.13
\ No newline at end of file
diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py
index b5ffcc47..3415af3e 100644
--- a/scrapegraphai/graphs/__init__.py
+++ b/scrapegraphai/graphs/__init__.py
@@ -26,4 +26,5 @@
 from .screenshot_scraper_graph import ScreenshotScraperGraph
 from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
 from .code_generator_graph import CodeGeneratorGraph
+from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph
 from .depth_search_graph import DepthSearchGraph
diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py
index 05f9773c..5fa9ff34 100644
--- a/scrapegraphai/graphs/base_graph.py
+++ b/scrapegraphai/graphs/base_graph.py
@@ -59,6 +59,8 @@ def __init__(self, nodes: list, edges: list, entry_point: str,
             # raise a warning if the entry point is not the first node in the list
             warnings.warn(
                 "Careful! The entry point node is different from the first node in the graph.")
+
+        self._set_conditional_node_edges()
 
         # Burr configuration
         self.use_burr = use_burr
@@ -77,9 +79,24 @@ def _create_edges(self, edges: list) -> dict:
         edge_dict = {}
         for from_node, to_node in edges:
-            edge_dict[from_node.node_name] = to_node.node_name
+            if from_node.node_type != 'conditional_node':
+                edge_dict[from_node.node_name] = to_node.node_name
         return edge_dict
 
+    def _set_conditional_node_edges(self):
+        """
+        Sets the true_node_name and false_node_name for each ConditionalNode.
+        """
+        for node in self.nodes:
+            if node.node_type == 'conditional_node':
+                # Find outgoing edges from this ConditionalNode
+                outgoing_edges = [(from_node, to_node) for from_node, to_node in self.raw_edges if from_node.node_name == node.node_name]
+                if len(outgoing_edges) != 2:
+                    raise ValueError(f"ConditionalNode '{node.node_name}' must have exactly two outgoing edges.")
+                # Assign true_node_name and false_node_name
+                node.true_node_name = outgoing_edges[0][1].node_name
+                node.false_node_name = outgoing_edges[1][1].node_name
+
     def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
         """
         Executes the graph by traversing nodes starting from the
@@ -201,7 +218,12 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
                     cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
 
             if current_node.node_type == "conditional_node":
-                current_node_name = result
+                node_names = {node.node_name for node in self.nodes}
+                if result in node_names:
+                    current_node_name = result
+                else:
+                    raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")
+
             elif current_node_name in self.edges:
                 current_node_name = self.edges[current_node_name]
             else:
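The two base_graph.py hunks above establish the routing contract: a ConditionalNode is excluded from the ordinary edge map, it must have exactly two outgoing edges (the first becomes the true branch, the second the false branch), and the name it returns from execute() must belong to a node in the graph. A minimal sketch of that contract with stand-in node objects (StubNode and the node names are purely illustrative, not the real classes):

```python
# Stand-in node objects to illustrate BaseGraph's two-edge convention.
class StubNode:
    def __init__(self, name, node_type="node"):
        self.node_name = name
        self.node_type = node_type

cond = StubNode("ConditionalNode", node_type="conditional_node")
merge = StubNode("MergeAnswersNode")
concat = StubNode("ConcatNode")
raw_edges = [(cond, merge), (cond, concat)]

# Mirrors _set_conditional_node_edges: first outgoing edge -> true branch, second -> false.
outgoing = [(src, dst) for src, dst in raw_edges if src.node_name == cond.node_name]
if len(outgoing) != 2:
    raise ValueError(f"ConditionalNode '{cond.node_name}' must have exactly two outgoing edges.")
true_branch, false_branch = outgoing[0][1].node_name, outgoing[1][1].node_name

# Mirrors the dispatch in _execute_standard: the returned name must exist in the graph.
returned = true_branch  # what ConditionalNode.execute would return when its condition holds
node_names = {n.node_name for n in (cond, merge, concat)}
assert returned in node_names
print(true_branch, false_branch)  # MergeAnswersNode ConcatNode
```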
diff --git a/scrapegraphai/graphs/smart_scraper_multi_cond_graph.py b/scrapegraphai/graphs/smart_scraper_multi_cond_graph.py
new file mode 100644
index 00000000..278e3905
--- /dev/null
+++ b/scrapegraphai/graphs/smart_scraper_multi_cond_graph.py
@@ -0,0 +1,130 @@
+"""
+SmartScraperMultiCondGraph Module with ConditionalNode
+"""
+from copy import deepcopy
+from typing import List, Optional
+from pydantic import BaseModel
+from .base_graph import BaseGraph
+from .abstract_graph import AbstractGraph
+from .smart_scraper_graph import SmartScraperGraph
+from ..nodes import (
+    GraphIteratorNode,
+    MergeAnswersNode,
+    ConcatAnswersNode,
+    ConditionalNode
+)
+from ..utils.copy import safe_deepcopy
+
+class SmartScraperMultiCondGraph(AbstractGraph):
+    """
+    SmartScraperMultiCondGraph is a scraping pipeline that scrapes a
+    list of URLs and generates answers to a given prompt.
+
+    Attributes:
+        prompt (str): The user prompt to search the internet.
+        llm_model (dict): The configuration for the language model.
+        embedder_model (dict): The configuration for the embedder model.
+        headless (bool): A flag to run the browser in headless mode.
+        verbose (bool): A flag to display the execution information.
+        model_token (int): The token limit for the language model.
+
+    Args:
+        prompt (str): The user prompt to search the internet.
+        source (List[str]): The source of the graph.
+        config (dict): Configuration parameters for the graph.
+        schema (Optional[BaseModel]): The schema for the graph output.
+
+    Example:
+        >>> search_graph = SmartScraperMultiCondGraph(
+        ...     "What is Chioggia famous for?",
+        ...     {"llm": {"model": "openai/gpt-3.5-turbo"}}
+        ... )
+        >>> result = search_graph.run()
+    """
+
+    def __init__(self, prompt: str, source: List[str],
+                 config: dict, schema: Optional[BaseModel] = None):
+
+        self.max_results = config.get("max_results", 3)
+        self.copy_config = safe_deepcopy(config)
+        self.copy_schema = deepcopy(schema)
+
+        super().__init__(prompt, config, source, schema)
+
+    def _create_graph(self) -> BaseGraph:
+        """
+        Creates the graph of nodes representing the workflow for web scraping and searching,
+        including a ConditionalNode to decide between merging or concatenating the results.
+
+        Returns:
+            BaseGraph: A graph instance representing the web scraping and searching workflow.
+        """
+
+        # Node that iterates over the URLs and collects results
+        graph_iterator_node = GraphIteratorNode(
+            input="user_prompt & urls",
+            output=["results"],
+            node_config={
+                "graph_instance": SmartScraperGraph,
+                "scraper_config": self.copy_config,
+            },
+            schema=self.copy_schema,
+            node_name="GraphIteratorNode"
+        )
+
+        # ConditionalNode to check if len(results) > 2
+        conditional_node = ConditionalNode(
+            input="results",
+            output=["results"],
+            node_name="ConditionalNode",
+            node_config={
+                'key_name': 'results',
+                'condition': 'len(results) > 2'
+            }
+        )
+
+        merge_answers_node = MergeAnswersNode(
+            input="user_prompt & results",
+            output=["answer"],
+            node_config={
+                "llm_model": self.llm_model,
+                "schema": self.copy_schema
+            },
+            node_name="MergeAnswersNode"
+        )
+
+        concat_node = ConcatAnswersNode(
+            input="results",
+            output=["answer"],
+            node_config={},
+            node_name="ConcatNode"
+        )
+
+        # Build the graph
+        return BaseGraph(
+            nodes=[
+                graph_iterator_node,
+                conditional_node,
+                merge_answers_node,
+                concat_node,
+            ],
+            edges=[
+                (graph_iterator_node, conditional_node),
+                (conditional_node, merge_answers_node),  # True node (len(results) > 2)
+                (conditional_node, concat_node),  # False node (len(results) <= 2)
+            ],
+            entry_point=graph_iterator_node,
+            graph_name=self.__class__.__name__
+        )
+
+    def run(self) -> str:
+        """
+        Executes the web scraping and searching process.
+
+        Returns:
+            str: The answer to the prompt.
+        """
+        inputs = {"user_prompt": self.prompt, "urls": self.source}
+        self.final_state, self.execution_info = self.graph.execute(inputs)
+
+        return self.final_state.get("answer", "No answer found.")
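The 'len(results) > 2' string in the ConditionalNode config above is not handed to Python's eval; it is evaluated with simpleeval against the graph state, with only whitelisted names available (len plus the state keys). A standalone sketch of that evaluation, mirroring the _evaluate_condition helper added to ConditionalNode further down (the state dict is invented for illustration):

```python
from simpleeval import EvalWithCompoundTypes, simple_eval

evaluator = EvalWithCompoundTypes()
evaluator.functions = {"len": len}

# Illustrative state: three per-URL answers collected by GraphIteratorNode.
state = {"results": ["answer 1", "answer 2", "answer 3"]}

names = evaluator.functions.copy()
names.update(state)

take_merge_branch = bool(simple_eval(
    "len(results) > 2",
    names=names,
    functions=evaluator.functions,
    operators=evaluator.operators,
))
print(take_merge_branch)  # True -> MergeAnswersNode branch; False -> ConcatNode branch
```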
diff --git a/scrapegraphai/nodes/__init__.py b/scrapegraphai/nodes/__init__.py
index edb195a5..72457b4f 100644
--- a/scrapegraphai/nodes/__init__.py
+++ b/scrapegraphai/nodes/__init__.py
@@ -27,6 +27,7 @@
 from .html_analyzer_node import HtmlAnalyzerNode
 from .generate_code_node import GenerateCodeNode
 from .search_node_with_context import SearchLinksWithContext
+from .conditional_node import ConditionalNode
 from .reasoning_node import ReasoningNode
 from .fetch_node_level_k import FetchNodeLevelK
 from .generate_answer_node_k_level import GenerateAnswerNodeKLevel
""" - #super().__init__(node_name, "node", input, output, 2, node_config) - pass + super().__init__(node_name, "conditional_node", input, output, 2, node_config) + + try: + self.key_name = self.node_config["key_name"] + except: + raise NotImplementedError("You need to provide key_name inside the node config") + + self.true_node_name = None + self.false_node_name = None + self.condition = self.node_config.get("condition", None) + + self.eval_instance = EvalWithCompoundTypes() + self.eval_instance.functions = {'len': len} def execute(self, state: dict) -> dict: """ @@ -47,4 +63,45 @@ def execute(self, state: dict) -> dict: str: The name of the next node to execute based on the presence of the key. """ - pass + if self.true_node_name is None or self.false_node_name is None: + raise ValueError("ConditionalNode's next nodes are not set properly.") + + # Evaluate the condition + if self.condition: + condition_result = self._evaluate_condition(state, self.condition) + else: + # Default behavior: check existence and non-emptiness of key_name + value = state.get(self.key_name) + condition_result = value is not None and value != '' + + # Return the appropriate next node name + if condition_result: + return self.true_node_name + else: + return self.false_node_name + + def _evaluate_condition(self, state: dict, condition: str) -> bool: + """ + Parses and evaluates the condition expression against the state. + + Args: + state (dict): The current state of the graph. + condition (str): The condition expression to evaluate. + + Returns: + bool: The result of the condition evaluation. + """ + # Combine state and allowed functions for evaluation context + eval_globals = self.eval_instance.functions.copy() + eval_globals.update(state) + + try: + result = simple_eval( + condition, + names=eval_globals, + functions=self.eval_instance.functions, + operators=self.eval_instance.operators + ) + return bool(result) + except Exception as e: + raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}") \ No newline at end of file