Skip to content

Commit 7823e6e

Browse files
committed
Pipeline rerun + exception on cyclic graph (for now)
1 parent edb9af7 commit 7823e6e

File tree

4 files changed

+60
-2
lines changed

4 files changed

+60
-2
lines changed

src/neo4j_genai/components/rag.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ async def process(self, prompt: str) -> dict[str, Any]:
3636
pipe.add_component("generate", LLM())
3737
pipe.connect("retrieve", "augment", {"context": "retrieve.context"})
3838
pipe.connect("augment", "generate", {"prompt": "augment.prompt"})
39+
pipe.connect("generate", "retrieve", {"prompt": "augment.prompt"})
3940

4041
query = "my question"
4142
print(

src/neo4j_genai/core/graph.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,41 @@ def previous_edges(self, node: Node) -> list[Edge]:
7070

7171
def __contains__(self, node: Node) -> bool:
7272
return node.name in self._nodes
73+
74+
def _check_stack(self, v, visited, stack):
75+
visited[v.name] = True
76+
stack[v.name] = True
77+
for linked_vertices in self.next_edges(v):
78+
next_node = linked_vertices.end
79+
if not visited[next_node.name]:
80+
if self._check_stack(next_node, visited, stack):
81+
return True
82+
elif stack[next_node.name]:
83+
return True
84+
stack[v.name] = False
85+
return False
86+
87+
def __is_cyclic(self) -> bool:
88+
visited = {node: False for node in self._nodes}
89+
stack = {node: False for node in self._nodes}
90+
for node_name, node in self._nodes.items():
91+
if not visited[node_name]:
92+
if self._check_stack(node, visited, stack):
93+
return True
94+
return False
95+
96+
def dfs(self, visited, node):
97+
if node in visited:
98+
return True
99+
else:
100+
for edge in self.next_edges(node):
101+
neighbour = edge.end
102+
if self.dfs(visited | {node}, neighbour):
103+
return True
104+
return False
105+
106+
def is_cyclic(self):
107+
for node in self._nodes.values():
108+
if self.dfs(set(), node):
109+
return True
110+
return False

src/neo4j_genai/core/pipeline.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def get_input_defs_from_parents(self) -> dict[str, str]:
8181
input_defs.update(**prev_edge_data)
8282
return input_defs
8383

84+
def reinitialize(self) -> None:
85+
self.status = RunStatus.SCHEDULED
86+
8487
async def run(
8588
self, data: dict[str, Any], callback: Callable[[Any, Any, Any], Awaitable[Any]]
8689
) -> None:
@@ -163,7 +166,6 @@ def __init__(self, store: Optional[Store] = None) -> None:
163166
super().__init__()
164167
self._store = store or InMemoryStore()
165168
self._final_results = InMemoryStore()
166-
self.orchestrator = Orchestrator(self)
167169

168170
def add_component(self, name: str, component: Component) -> None:
169171
task = TaskNode(name, component, self)
@@ -178,6 +180,8 @@ def connect( # type: ignore
178180
start_node = self.get_node_by_name(start_component_name, raise_exception=True)
179181
end_node = self.get_node_by_name(end_component_name, raise_exception=True)
180182
super().connect(start_node, end_node, data={"input_defs": input_defs})
183+
if self.is_cyclic():
184+
raise Exception("Cyclic graph")
181185

182186
def add_result_for_component(
183187
self, name: str, result: dict[str, Any], is_final: bool = False
@@ -208,6 +212,14 @@ def get_component_inputs(
208212
component_inputs[input_def] = value
209213
return component_inputs
210214

215+
def reinitialize(self):
216+
self._store.empty()
217+
self._final_results.empty()
218+
for task in self._nodes.values():
219+
task.reinitialize() # type: ignore
220+
211221
async def run(self, data: dict[str, Any]) -> dict[str, Any]:
212-
await self.orchestrator.run(data)
222+
self.reinitialize()
223+
orchestrator = Orchestrator(self)
224+
await orchestrator.run(data)
213225
return self._final_results.all()

src/neo4j_genai/core/stores.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def all(self) -> dict[str, Any]:
4242
"""
4343
raise NotImplementedError()
4444

45+
def empty(self) -> None:
46+
"""Remove everything from store"""
47+
raise NotImplementedError()
48+
4549

4650
class InMemoryStore(Store):
4751
"""Simple in-memory store.
@@ -67,3 +71,6 @@ def find_all(self, pattern: str) -> list[Any]:
6771

6872
def all(self) -> dict[str, Any]:
6973
return self._data
74+
75+
def empty(self) -> None:
76+
self._data = {}

0 commit comments

Comments
 (0)