Skip to content

fix: Add a new graph traversal that allows more than one edge out of a node #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 45 additions & 32 deletions scrapegraphai/graphs/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from langchain_community.callbacks import get_openai_callback
from typing import Tuple
from collections import deque


class BaseGraph:
Expand All @@ -26,6 +27,8 @@ class BaseGraph:

Raises:
Warning: If the entry point node is not the first node in the list.
ValueError: If conditional_node does not have exactly two outgoing edges


Example:
>>> BaseGraph(
Expand All @@ -48,7 +51,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str):

self.nodes = nodes
self.edges = self._create_edges({e for e in edges})
self.entry_point = entry_point.node_name
self.entry_point = entry_point

if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
Expand All @@ -68,13 +71,16 @@ def _create_edges(self, edges: list) -> dict:

edge_dict = {}
for from_node, to_node in edges:
edge_dict[from_node.node_name] = to_node.node_name
if from_node in edge_dict:
edge_dict[from_node].append(to_node)
else:
edge_dict[from_node] = [to_node]
return edge_dict

def execute(self, initial_state: dict) -> Tuple[dict, list]:
"""
Executes the graph by traversing nodes starting from the entry point. The execution
follows the edges based on the result of each node's execution and continues until
Executes the graph by traversing nodes in breadth-first order starting from the entry point.
The execution follows the edges based on the result of each node's execution and continues until
it reaches a node with no outgoing edges.

Args:
Expand All @@ -84,7 +90,6 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
"""

current_node_name = self.nodes[0]
state = initial_state

# variables for tracking execution info
Expand All @@ -98,23 +103,22 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
"total_cost_USD": 0.0,
}

for index in self.nodes:

queue = deque([self.entry_point])
while queue:
current_node = queue.popleft()
curr_time = time.time()
current_node = index

with get_openai_callback() as cb:
with get_openai_callback() as callback:
result = current_node.execute(state)
node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time

cb = {
"node_name": index.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"node_name": current_node.node_name,
"total_tokens": callback.total_tokens,
"prompt_tokens": callback.prompt_tokens,
"completion_tokens": callback.completion_tokens,
"successful_requests": callback.successful_requests,
"total_cost_USD": callback.total_cost,
"exec_time": node_exec_time,
}

Expand All @@ -128,21 +132,30 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
cb_total["successful_requests"] += cb["successful_requests"]
cb_total["total_cost_USD"] += cb["total_cost_USD"]

if current_node.node_type == "conditional_node":
current_node_name = result
elif current_node_name in self.edges:
current_node_name = self.edges[current_node_name]
else:
current_node_name = None

exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})
if current_node in self.edges:
current_node_connections = self.edges[current_node]
if current_node.node_type == 'conditional_node':
# Assert that there are exactly two out edges from the conditional node
if len(current_node_connections) != 2:
raise ValueError(f"Conditional node should have exactly two out connections {current_node_connections.node_name}")
if result["next_node"] == 0:
queue.append(current_node_connections[0])
else:
queue.append(current_node_connections[1])
# remove the conditional node result
del result["next_node"]
else:
queue.extend(node for node in current_node_connections)


exec_info.append({
"node_name": "TOTAL RESULT",
"total_tokens": cb_total["total_tokens"],
"prompt_tokens": cb_total["prompt_tokens"],
"completion_tokens": cb_total["completion_tokens"],
"successful_requests": cb_total["successful_requests"],
"total_cost_USD": cb_total["total_cost_USD"],
"exec_time": total_exec_time,
})

return state, exec_info
31 changes: 10 additions & 21 deletions scrapegraphai/nodes/conditional_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,33 @@ class ConditionalNode(BaseNode):
This node type is used to implement branching logic within the graph, allowing
for dynamic paths based on the data available in the current state.

It is expected that exactly two edges are created out of this node.
The first node is chosen for execution if the key exists and has a non-empty value,
and the second node is chosen if the key does not exist or is empty.

Attributes:
key_name (str): The name of the key in the state to check for its presence.
next_nodes (list): A list of two node instances. The first node is chosen
for execution if the key exists and has a non-empty value,
and the second node is chosen if the key does not exist or
is empty.

Args:
key_name (str): The name of the key to check in the graph's state. This is
used to determine the path the graph's execution should take.
next_nodes (list): A list containing exactly two node instances, specifying
the next nodes to execute based on the condition's outcome.
node_name (str, optional): The unique identifier name for the node. Defaults
to "ConditionalNode".

Raises:
ValueError: If next_nodes does not contain exactly two elements, indicating
a misconfiguration in specifying the conditional paths.
"""

def __init__(self, key_name: str, next_nodes: list, node_name="ConditionalNode"):
def __init__(self, key_name: str, node_name="ConditionalNode"):
"""
Initializes the node with the key to check and the next node names based on the condition.

Args:
key_name (str): The name of the key to check in the state.
next_nodes (list): A list containing exactly two names of the next nodes.
The first is used if the key exists, the second if it does not.

Raises:
ValueError: If next_nodes does not contain exactly two elements.
"""

super().__init__(node_name, "conditional_node")
self.key_name = key_name
if len(next_nodes) != 2:
raise ValueError("next_nodes must contain exactly two elements.")
self.next_nodes = next_nodes

def execute(self, state: dict) -> str:
def execute(self, state: dict) -> dict:
"""
Checks if the specified key is present in the state and decides the next node accordingly.

Expand All @@ -64,5 +51,7 @@ def execute(self, state: dict) -> str:
"""

if self.key_name in state and len(state[self.key_name]) > 0:
return self.next_nodes[0].node_name
return self.next_nodes[1].node_name
state["next_node"] = 0
else:
state["next_node"] = 1
return state
Loading