diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py index 867d774f..c0effda8 100644 --- a/scrapegraphai/graphs/base_graph.py +++ b/scrapegraphai/graphs/base_graph.py @@ -6,6 +6,7 @@ import warnings from langchain_community.callbacks import get_openai_callback from typing import Tuple +from collections import deque class BaseGraph: @@ -26,6 +27,8 @@ class BaseGraph: Raises: Warning: If the entry point node is not the first node in the list. + ValueError: If conditional_node does not have exactly two outgoing edges + Example: >>> BaseGraph( @@ -48,7 +51,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str): self.nodes = nodes self.edges = self._create_edges({e for e in edges}) - self.entry_point = entry_point.node_name + self.entry_point = entry_point if nodes[0].node_name != entry_point.node_name: # raise a warning if the entry point is not the first node in the list @@ -68,13 +71,16 @@ def _create_edges(self, edges: list) -> dict: edge_dict = {} for from_node, to_node in edges: - edge_dict[from_node.node_name] = to_node.node_name + if from_node in edge_dict: + edge_dict[from_node].append(to_node) + else: + edge_dict[from_node] = [to_node] return edge_dict def execute(self, initial_state: dict) -> Tuple[dict, list]: """ - Executes the graph by traversing nodes starting from the entry point. The execution - follows the edges based on the result of each node's execution and continues until + Executes the graph by traversing nodes in breadth-first order starting from the entry point. + The execution follows the edges based on the result of each node's execution and continues until it reaches a node with no outgoing edges. Args: @@ -84,7 +90,6 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]: Tuple[dict, list]: A tuple containing the final state and a list of execution info. """ - current_node_name = self.nodes[0] state = initial_state # variables for tracking execution info @@ -98,23 +103,22 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]: "total_cost_USD": 0.0, } - for index in self.nodes: - + queue = deque([self.entry_point]) + while queue: + current_node = queue.popleft() curr_time = time.time() - current_node = index - - with get_openai_callback() as cb: + with get_openai_callback() as callback: result = current_node.execute(state) node_exec_time = time.time() - curr_time total_exec_time += node_exec_time cb = { - "node_name": index.node_name, - "total_tokens": cb.total_tokens, - "prompt_tokens": cb.prompt_tokens, - "completion_tokens": cb.completion_tokens, - "successful_requests": cb.successful_requests, - "total_cost_USD": cb.total_cost, + "node_name": current_node.node_name, + "total_tokens": callback.total_tokens, + "prompt_tokens": callback.prompt_tokens, + "completion_tokens": callback.completion_tokens, + "successful_requests": callback.successful_requests, + "total_cost_USD": callback.total_cost, "exec_time": node_exec_time, } @@ -128,21 +132,30 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]: cb_total["successful_requests"] += cb["successful_requests"] cb_total["total_cost_USD"] += cb["total_cost_USD"] - if current_node.node_type == "conditional_node": - current_node_name = result - elif current_node_name in self.edges: - current_node_name = self.edges[current_node_name] - else: - current_node_name = None - - exec_info.append({ - "node_name": "TOTAL RESULT", - "total_tokens": cb_total["total_tokens"], - "prompt_tokens": cb_total["prompt_tokens"], - "completion_tokens": cb_total["completion_tokens"], - "successful_requests": cb_total["successful_requests"], - "total_cost_USD": cb_total["total_cost_USD"], - "exec_time": total_exec_time, - }) + if current_node in self.edges: + current_node_connections = self.edges[current_node] + if current_node.node_type == 'conditional_node': + # Assert that there are exactly two out edges from the conditional node + if len(current_node_connections) != 2: + raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}") + if result["next_node"] == 0: + queue.append(current_node_connections[0]) + else: + queue.append(current_node_connections[1]) + # remove the conditional node result + del result["next_node"] + else: + queue.extend(node for node in current_node_connections) + + + exec_info.append({ + "node_name": "TOTAL RESULT", + "total_tokens": cb_total["total_tokens"], + "prompt_tokens": cb_total["prompt_tokens"], + "completion_tokens": cb_total["completion_tokens"], + "successful_requests": cb_total["successful_requests"], + "total_cost_USD": cb_total["total_cost_USD"], + "exec_time": total_exec_time, + }) return state, exec_info diff --git a/scrapegraphai/nodes/conditional_node.py b/scrapegraphai/nodes/conditional_node.py index 4ee2da85..33731a9d 100644 --- a/scrapegraphai/nodes/conditional_node.py +++ b/scrapegraphai/nodes/conditional_node.py @@ -13,46 +13,33 @@ class ConditionalNode(BaseNode): This node type is used to implement branching logic within the graph, allowing for dynamic paths based on the data available in the current state. + It is expected thar exactly two edges are created out of this node. + The first node is chosen for execution if the key exists and has a non-empty value, + and the second node is chosen if the key does not exist or is empty. + Attributes: key_name (str): The name of the key in the state to check for its presence. - next_nodes (list): A list of two node instances. The first node is chosen - for execution if the key exists and has a non-empty value, - and the second node is chosen if the key does not exist or - is empty. Args: key_name (str): The name of the key to check in the graph's state. This is used to determine the path the graph's execution should take. - next_nodes (list): A list containing exactly two node instances, specifying - the next nodes to execute based on the condition's outcome. node_name (str, optional): The unique identifier name for the node. Defaults to "ConditionalNode". - Raises: - ValueError: If next_nodes does not contain exactly two elements, indicating - a misconfiguration in specifying the conditional paths. """ - def __init__(self, key_name: str, next_nodes: list, node_name="ConditionalNode"): + def __init__(self, key_name: str, node_name="ConditionalNode"): """ Initializes the node with the key to check and the next node names based on the condition. Args: key_name (str): The name of the key to check in the state. - next_nodes (list): A list containing exactly two names of the next nodes. - The first is used if the key exists, the second if it does not. - - Raises: - ValueError: If next_nodes does not contain exactly two elements. """ super().__init__(node_name, "conditional_node") self.key_name = key_name - if len(next_nodes) != 2: - raise ValueError("next_nodes must contain exactly two elements.") - self.next_nodes = next_nodes - def execute(self, state: dict) -> str: + def execute(self, state: dict) -> dict: """ Checks if the specified key is present in the state and decides the next node accordingly. @@ -64,5 +51,7 @@ def execute(self, state: dict) -> str: """ if self.key_name in state and len(state[self.key_name]) > 0: - return self.next_nodes[0].node_name - return self.next_nodes[1].node_name + state["next_node"] = 0 + else: + state["next_node"] = 1 + return state